Replace RowVectorBatch with MatStorageT

KVCache: add ctor required for MatStorageT, remove Create; bf_pre_ffw_rms_out -> pre_ffw_rms_out
optimize_test: larger vocab_size requires more steps
shared.h: Remove unused u128 type
Activations: correctly set matrix rows instead of passing them as an arg
ops: pass Mat instead of pointers/sizes; vectorize LayerNorm; support any weight type
mat: add OverrideRows, used by SetBatchSize
PiperOrigin-RevId: 757790736
Author: Jan Wassenberg, 2025-05-12 09:15:03 -07:00 (committed by Copybara-Service)
Parent: cf7dd80c17
Commit: 45ad847a41
39 changed files with 949 additions and 917 deletions
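
The core of this change, summarized by the message above, is that activation matrices become `MatStorageT` members allocated once for the maximum batch size; `SetBatchSize` then calls `OverrideRows` to adjust only the logical row count, so ops can read the row count from the matrix itself instead of taking it as a separate argument (see `SumHeads()` losing its `num_interleaved` parameter below). The following is a minimal standalone sketch of that pattern; `MiniMat` and `MiniActivations` are hypothetical stand-ins, not the real `MatStorageT`/`Activations` API (which also handles `MatPadding`, typed storage, and allocator alignment).

```cpp
// Minimal sketch of the "allocate once, override rows per batch" pattern.
// MiniMat is a hypothetical stand-in for MatStorageT.
#include <cassert>
#include <cstddef>
#include <cstdio>
#include <vector>

class MiniMat {
 public:
  MiniMat(const char* name, size_t rows, size_t cols)
      : name_(name), capacity_rows_(rows), rows_(rows), cols_(cols),
        storage_(rows * cols, 0.0f) {}

  // Logically shrink (or restore) the row count without reallocating.
  void OverrideRows(size_t rows) {
    assert(rows <= capacity_rows_);
    rows_ = rows;
  }

  const char* Name() const { return name_; }
  size_t Rows() const { return rows_; }
  size_t Cols() const { return cols_; }
  float* Row(size_t r) { return storage_.data() + r * cols_; }

 private:
  const char* name_;
  size_t capacity_rows_;  // allocation size, fixed at construction
  size_t rows_;           // current logical row count
  size_t cols_;
  std::vector<float> storage_;
};

// Mirrors Activations::SetBatchSize: one call adjusts every per-token matrix,
// and callees read Rows() instead of receiving a batch size argument.
struct MiniActivations {
  MiniMat x{"x", /*max_batch=*/16, /*model_dim=*/64};
  MiniMat logits{"logits", /*max_batch=*/16, /*vocab=*/256};

  void SetBatchSize(size_t batch_size) {
    x.OverrideRows(batch_size);
    logits.OverrideRows(batch_size);
  }
};

int main() {
  MiniActivations acts;  // allocated once at the maximum batch size
  acts.SetBatchSize(3);  // later batches only override the logical row count
  printf("%s: %zu x %zu\n", acts.x.Name(), acts.x.Rows(), acts.x.Cols());
  return 0;
}
```

This is also why the ops layer can now take a `Mat` rather than pointer-plus-size pairs: the dimensions travel with the data.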


@ -415,6 +415,7 @@ cc_library(
hdrs = ["gemma/kv_cache.h"], hdrs = ["gemma/kv_cache.h"],
deps = [ deps = [
":configs", ":configs",
":mat",
"@highway//:hwy", "@highway//:hwy",
], ],
) )
@ -425,6 +426,7 @@ cc_library(
deps = [ deps = [
":args", ":args",
":basics", ":basics",
":mat",
":ops", # matmul.h ":ops", # matmul.h
"//io", "//io",
"@highway//:hwy", "@highway//:hwy",


@ -38,8 +38,8 @@ struct ForwardLayer {
att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)), att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)),
attention_out( attention_out(
MakePacked<T>("attention_out", seq_len, config.model_dim)), MakePacked<T>("attention_out", seq_len, config.model_dim)),
bf_pre_ffw_rms_out( pre_ffw_rms_out(
MakePacked<T>("bf_preFF_rms_out", seq_len, config.model_dim)), MakePacked<T>("preFF_rms_out", seq_len, config.model_dim)),
ffw_hidden( ffw_hidden(
MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)), MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)),
ffw_hidden_gated( ffw_hidden_gated(
@ -53,7 +53,7 @@ struct ForwardLayer {
MatStorageT<T> att_out; MatStorageT<T> att_out;
MatStorageT<T> att_post1; MatStorageT<T> att_post1;
MatStorageT<T> attention_out; MatStorageT<T> attention_out;
MatStorageT<T> bf_pre_ffw_rms_out; MatStorageT<T> pre_ffw_rms_out;
MatStorageT<T> ffw_hidden; MatStorageT<T> ffw_hidden;
MatStorageT<T> ffw_hidden_gated; MatStorageT<T> ffw_hidden_gated;
const LayerConfig& layer_config; const LayerConfig& layer_config;


@ -170,8 +170,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
const ForwardLayer<float>& forward, const ForwardLayer<float>& forward,
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward, LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
const RowVectorBatch<float>& inv_timescale, const MatStorageT<float>& inv_timescale, hwy::ThreadPool& pool) {
hwy::ThreadPool& pool) {
const LayerConfig& config = weights.layer_config; const LayerConfig& config = weights.layer_config;
const size_t model_dim = config.model_dim; const size_t model_dim = config.model_dim;
const size_t qkv_dim = config.qkv_dim; const size_t qkv_dim = config.qkv_dim;
@ -207,15 +206,14 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
} }
} }
MatMulVJP(weights.gating_einsum_w.Packed(), MatMulVJP(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(), backward.ffw_hidden.Packed(), model_dim, ff_hidden_dim * 2,
model_dim, ff_hidden_dim * 2, num_tokens, num_tokens, grad.gating_einsum_w.Packed(),
grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(), backward.pre_ffw_rms_out.Packed(), pool);
pool); RMSNormVJP(weights.pre_ffw_norm_scale.Packed(),
RMSNormVJP( forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(), model_dim, num_tokens, grad.pre_ffw_norm_scale.Packed(),
backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens, backward.attention_out.Packed(), pool);
grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(next_layer_grad + pos * model_dim, AddFrom(next_layer_grad + pos * model_dim,
@ -275,7 +273,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) { for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv = float* HWY_RESTRICT b_kv =
backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim; backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim;
Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos); Rope(b_kv, qkv_dim, inv_timescale.PackedScale1(), -pos);
} }
for (size_t head = 0; head < heads; ++head) { for (size_t head = 0; head < heads; ++head) {
@ -283,7 +281,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
float* HWY_RESTRICT b_q = float* HWY_RESTRICT b_q =
backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim; backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim;
MulByConst(query_scale, b_q, qkv_dim); MulByConst(query_scale, b_q, qkv_dim);
Rope(b_q, qkv_dim, inv_timescale.Const(), -pos); Rope(b_q, qkv_dim, inv_timescale.PackedScale1(), -pos);
} }
} }
@ -342,7 +340,7 @@ void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
const ForwardPass<float>& forward, const ForwardPass<float>& forward,
ModelWeightsPtrs<T>& grad, ModelWeightsPtrs<T>& grad,
ForwardPass<float>& backward, ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale, MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
const ModelConfig& config = weights.weights_config; const ModelConfig& config = weights.weights_config;
const size_t kVocabSize = config.vocab_size; const size_t kVocabSize = config.vocab_size;


@ -42,7 +42,7 @@ void CrossEntropyLossBackwardPassT(const Prompt& prompt,
const ForwardPass<float>& forward, const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad, ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward, ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale, MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward, CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
inv_timescale, pool); inv_timescale, pool);
@ -62,7 +62,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ForwardPass<float>& forward, const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad, ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward, ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale, MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)( return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
prompt, weights, forward, grad, backward, inv_timescale, pool); prompt, weights, forward, grad, backward, inv_timescale, pool);


@ -29,7 +29,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ForwardPass<float>& forward, const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad, ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward, ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale, MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool); hwy::ThreadPool& pool);
} // namespace gcpp } // namespace gcpp


@ -218,16 +218,15 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(), GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(),
backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens); backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens);
MatMulVJPT(weights.gating_einsum_w.Packed(), MatMulVJPT(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(), backward.ffw_hidden.Packed(), grad.gating_einsum_w.Packed(),
grad.gating_einsum_w.Packed(), backward.pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
backward.bf_pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
num_tokens); num_tokens);
RMSNormVJPT( RMSNormVJPT(weights.pre_ffw_norm_scale.Packed(),
weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(), forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
backward.bf_pre_ffw_rms_out.Packed(), grad.pre_ffw_norm_scale.Packed(), grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(),
backward.attention_out.Packed(), model_dim, num_tokens); model_dim, num_tokens);
AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim); AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim);


@ -202,7 +202,7 @@ void TestEndToEnd() {
ReverseSequenceSampler training_task({0, 0, 1, 1}); ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(3, gen); std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
RowVectorBatch<float> inv_timescale = CreateInvTimescale( MatStorageT<float> inv_timescale = CreateInvTimescale(
ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim, ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope); config.layer_configs[0].post_qk == PostQKType::HalfRope);
for (const Prompt& prompt : batch) { for (const Prompt& prompt : batch) {
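
Several hunks here pass the output of `CreateInvTimescale` (now a `MatStorageT<float>`) into `Rope`. As a reference point, a hedged sketch of the usual RoPE inverse-timescale computation follows; the exact `PostQKType::HalfRope` handling and the 1e6 base used for global layers should be checked against `ops.h`, so this only illustrates the standard formula.

```cpp
// Sketch of RoPE inverse-timescale precomputation (standard formula). The
// real CreateInvTimescale returns a MatStorageT<float>.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

std::vector<float> InvTimescales(size_t qkv_dim, bool half_rope,
                                 double base = 10000.0) {
  // Assumption: HalfRope rotates only half of the head dimensions.
  const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
  std::vector<float> inv(rope_dim / 2);
  for (size_t i = 0; i < inv.size(); ++i) {
    const double exponent = static_cast<double>(2 * i) / rope_dim;
    inv[i] = static_cast<float>(1.0 / std::pow(base, exponent));
  }
  return inv;
}

int main() {
  // Rope(x, dim, inv_timescale, pos) rotates element pairs by
  // angle = pos * inv_timescale[i]; the VJP above reuses it with -pos
  // to undo the rotation.
  const std::vector<float> inv = InvTimescales(/*qkv_dim=*/256, false);
  printf("inv_timescale[0]=%f, [last]=%g\n", inv.front(), inv.back());
  return 0;
}
```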


@ -74,7 +74,7 @@ void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim; const size_t offset = pos * model_dim;
RMSNorm(x + offset, weights, output + offset, model_dim); RMSNorm(x + offset, weights, 0, output + offset, model_dim);
} }
} }
@ -100,7 +100,7 @@ template <typename T>
void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights, void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
ForwardLayer<float>& activations, size_t num_tokens, ForwardLayer<float>& activations, size_t num_tokens,
float* HWY_RESTRICT output, float* HWY_RESTRICT output,
const RowVectorBatch<float>& inv_timescale, const MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
const LayerConfig& config = weights.layer_config; const LayerConfig& config = weights.layer_config;
const size_t model_dim = config.model_dim; const size_t model_dim = config.model_dim;
@ -125,14 +125,14 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT k = float* HWY_RESTRICT k =
activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim; activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(k, kQKVDim, inv_timescale.Const(), pos); Rope(k, kQKVDim, inv_timescale.PackedScale1(), pos);
} }
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads; const size_t head = task % kHeads;
const size_t pos = task / kHeads; const size_t pos = task / kHeads;
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim; activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, inv_timescale.Const(), pos); Rope(q, kQKVDim, inv_timescale.PackedScale1(), pos);
MulByConst(query_scale, q, kQKVDim); MulByConst(query_scale, q, kQKVDim);
}); });
@ -194,11 +194,11 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(), ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(),
activations.attention_out.Packed(), model_dim, num_tokens, activations.attention_out.Packed(), model_dim, num_tokens,
activations.bf_pre_ffw_rms_out.Packed(), pool); activations.pre_ffw_rms_out.Packed(), pool);
const size_t kFFHiddenDim = config.ff_hidden_dim; const size_t kFFHiddenDim = config.ff_hidden_dim;
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim, MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
activations.bf_pre_ffw_rms_out.Packed() + pos * model_dim, activations.pre_ffw_rms_out.Packed() + pos * model_dim,
activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool); activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool);
} }
for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size, size_t context_size,
const ModelWeightsPtrs<T>& weights, const ModelWeightsPtrs<T>& weights,
ForwardPass<float>& forward, ForwardPass<float>& forward,
const RowVectorBatch<float>& inv_timescale, const MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
const ModelConfig& config = weights.weights_config; const ModelConfig& config = weights.weights_config;
const size_t vocab_size = config.vocab_size; const size_t vocab_size = config.vocab_size;


@ -38,7 +38,7 @@ namespace HWY_NAMESPACE {
float CrossEntropyLossForwardPassT(const Prompt& prompt, float CrossEntropyLossForwardPassT(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights, const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward, ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale, MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size, return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
weights, forward, inv_timescale, pool); weights, forward, inv_timescale, pool);
@ -56,7 +56,7 @@ HWY_EXPORT(CrossEntropyLossForwardPassT);
float CrossEntropyLossForwardPass(const Prompt& prompt, float CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights, const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward, ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale, MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
prompt, weights, forward, inv_timescale, pool); prompt, weights, forward, inv_timescale, pool);


@ -19,7 +19,7 @@
#include "backprop/activations.h" #include "backprop/activations.h"
#include "backprop/prompt.h" #include "backprop/prompt.h"
#include "gemma/weights.h" #include "gemma/weights.h"
#include "util/allocator.h" #include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp { namespace gcpp {
@ -27,7 +27,7 @@ namespace gcpp {
float CrossEntropyLossForwardPass(const Prompt& prompt, float CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights, const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward, ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale, MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool); hwy::ThreadPool& pool);
} // namespace gcpp } // namespace gcpp


@ -219,12 +219,11 @@ void ApplyLayer(const LayerWeightsPtrs<T>& weights,
RMSNormT(weights.pre_ffw_norm_scale.Packed(), RMSNormT(weights.pre_ffw_norm_scale.Packed(),
activations.attention_out.Packed(), activations.attention_out.Packed(),
activations.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens); activations.pre_ffw_rms_out.Packed(), model_dim, num_tokens);
MatMulT(weights.gating_einsum_w.Packed(), MatMulT(weights.gating_einsum_w.Packed(),
activations.bf_pre_ffw_rms_out.Packed(), activations.pre_ffw_rms_out.Packed(), activations.ffw_hidden.Packed(),
activations.ffw_hidden.Packed(), ff_hidden_dim * 2, model_dim, ff_hidden_dim * 2, model_dim, num_tokens);
num_tokens);
GatedGelu(activations.ffw_hidden.Packed(), GatedGelu(activations.ffw_hidden.Packed(),
activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens); activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens);


@ -62,9 +62,9 @@ TEST(OptimizeTest, GradientDescent) {
grad_m.ZeroInit(); grad_m.ZeroInit();
grad_v.ZeroInit(); grad_v.ZeroInit();
ForwardPass<float> forward(config), backward(config); ForwardPass<float> forward(config), backward(config);
KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16); KVCache kv_cache(config, /*prefill_tbatch_size=*/16);
RowVectorBatch<float> inv_timescale = CreateInvTimescale( MatStorageT<float> inv_timescale = CreateInvTimescale(
allocator, config.layer_configs[0].qkv_dim, allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope); config.layer_configs[0].post_qk == PostQKType::HalfRope);
@ -147,7 +147,7 @@ TEST(OptimizeTest, GradientDescent) {
printf("Num steps: %zu\n", steps); printf("Num steps: %zu\n", steps);
printf("Final weights:\n"); printf("Final weights:\n");
gemma.MutableWeights().LogWeightStatsF32(); gemma.MutableWeights().LogWeightStatsF32();
EXPECT_LT(steps, 50); EXPECT_LT(steps, 80);
EXPECT_EQ(num_ok, kBatchSize); EXPECT_EQ(num_ok, kBatchSize);
} }


@ -23,6 +23,7 @@
#include <stdio.h> #include <stdio.h>
#include <algorithm> // std::shuffle #include <algorithm> // std::shuffle
#include <array>
#include <random> #include <random>
#include "compression/distortion.h" #include "compression/distortion.h"
@ -104,7 +105,7 @@ struct TestPlateaus {
HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f); HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
} }
std::random_device rd; std::random_device rd; // NOLINT
std::mt19937 rng(rd()); std::mt19937 rng(rd());
std::shuffle(in.get(), in.get() + kGroupSize, rng); std::shuffle(in.get(), in.get() + kGroupSize, rng);
@ -151,7 +152,7 @@ struct TestRamp {
HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f); HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
} }
std::random_device rd; std::random_device rd; // NOLINT
std::mt19937 rng(rd()); std::mt19937 rng(rd());
std::shuffle(in.get(), in.get() + kGroupSize, rng); std::shuffle(in.get(), in.get() + kGroupSize, rng);
@ -246,7 +247,8 @@ struct TestOffset {
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32 auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total); auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen); auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total)); auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
HWY_ASSERT(in && dec1 && dec2 && nuq); HWY_ASSERT(in && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total); const auto nuq_span = MakeSpan(nuq.get(), total);
@ -296,7 +298,8 @@ struct TestUnalignedOffset {
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32 auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total); auto dec1 = hwy::AllocateAligned<T>(total);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total)); auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
auto dec2 = hwy::AllocateAligned<T>(num_decompressed); auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
HWY_ASSERT(in && dec1 && dec2 && nuq); HWY_ASSERT(in && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total); const auto nuq_span = MakeSpan(nuq.get(), total);
@ -347,7 +350,8 @@ struct TestDec2 {
auto dec0 = hwy::AllocateAligned<T>(total); auto dec0 = hwy::AllocateAligned<T>(total);
auto dec1 = hwy::AllocateAligned<T>(total); auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen); auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total)); auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq); HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total); const auto nuq_span = MakeSpan(nuq.get(), total);
@ -449,7 +453,8 @@ struct TestEncDec {
const size_t num = 4 * kGroupSize; const size_t num = 4 * kGroupSize;
auto in = hwy::AllocateAligned<float>(num); // Enc() requires f32 auto in = hwy::AllocateAligned<float>(num); // Enc() requires f32
auto out = hwy::AllocateAligned<T>(num); // already padded auto out = hwy::AllocateAligned<T>(num); // already padded
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num)); auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(num), hwy::VectorBytes()));
HWY_ASSERT(in && out && nuq); HWY_ASSERT(in && out && nuq);
const auto nuq_span = MakeSpan(nuq.get(), num); const auto nuq_span = MakeSpan(nuq.get(), num);
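
The recurring change in the tests above rounds the `NuqStream` allocation up to a whole number of SIMD vectors, presumably so vectorized decode can load a full vector that extends past `PackedEnd` without touching unowned memory. A minimal sketch of the rounding itself, with hypothetical sizes (the real code uses `hwy::RoundUpTo` and the runtime `hwy::VectorBytes()`):

```cpp
// Sketch of rounding an allocation up to a whole number of SIMD vectors, as
// the tests above now do via hwy::RoundUpTo(NuqStream::PackedEnd(n),
// hwy::VectorBytes()). Sizes below are hypothetical.
#include <cstddef>
#include <cstdio>

constexpr size_t RoundUpTo(size_t bytes, size_t align) {
  return (bytes + align - 1) / align * align;
}

int main() {
  const size_t packed_end = 4500;  // stand-in for NuqStream::PackedEnd(total)
  const size_t vector_bytes = 64;  // e.g. AVX-512; hwy::VectorBytes() at runtime
  printf("allocate %zu bytes\n", RoundUpTo(packed_end, vector_bytes));  // 4544
  return 0;
}
```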


@ -164,11 +164,11 @@ constexpr bool IsNuqStream() {
// weights for a model, but can be used for other purposes, such as types for // weights for a model, but can be used for other purposes, such as types for
// `WeightsPtrs`. When adding a new type that is supported, also // `WeightsPtrs`. When adding a new type that is supported, also
// update gemma.cc, weights.*, and add instantiations/new_one.cc. // update gemma.cc, weights.*, and add instantiations/new_one.cc.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 }; enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64 };
// These are used in `ModelConfig.Specifier`, hence the strings will not // These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added. // change, though new ones may be added.
static constexpr const char* kTypeStrings[] = { static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"}; "nuq", "f64", "c64"};
static constexpr size_t kNumTypes = static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {0, static constexpr size_t kTypeBits[] = {0,
@ -177,8 +177,7 @@ static constexpr size_t kTypeBits[] = {0,
8 * sizeof(SfpStream), 8 * sizeof(SfpStream),
4 /* NuqStream, actually 4.5 */, 4 /* NuqStream, actually 4.5 */,
8 * sizeof(double), 8 * sizeof(double),
8 * sizeof(std::complex<double>), 8 * sizeof(std::complex<double>)};
8 * sizeof(hwy::uint128_t)};
static inline bool EnumValid(Type type) { static inline bool EnumValid(Type type) {
return static_cast<size_t>(type) < kNumTypes; return static_cast<size_t>(type) < kNumTypes;
@ -200,8 +199,6 @@ Type TypeEnum() {
return Type::kF64; return Type::kF64;
} else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) { } else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
return Type::kC64; return Type::kC64;
} else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
return Type::kU128;
} else { } else {
HWY_DASSERT(false); HWY_DASSERT(false);
return Type::kUnknown; return Type::kUnknown;


@ -73,8 +73,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens); size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos, std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens); prompt.begin() + pos + num_tokens);
KVCache kv_cache = KVCache::Create(env.GetGemma()->GetModelConfig(), KVCache kv_cache(env.GetGemma()->GetModelConfig(),
env.MutableConfig().prefill_tbatch_size); env.MutableConfig().prefill_tbatch_size);
float entropy = ComputeCrossEntropy( float entropy = ComputeCrossEntropy(
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); *env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy; total_entropy += entropy;


@ -52,9 +52,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader,
const InferenceArgs& inference) const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading_args)), gemma_(loader, env_) { : env_(MakeMatMulEnv(threading_args)), gemma_(loader, env_) {
// Only allocate one for starters because GenerateBatch might not be called. // Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1); kv_caches_.push_back(
kv_caches_[0] = KVCache(gemma_.GetModelConfig(), inference.prefill_tbatch_size));
KVCache::Create(gemma_.GetModelConfig(), inference.prefill_tbatch_size);
InitGenerator(inference, gen_); InitGenerator(inference, gen_);
@ -131,15 +130,10 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
runtime_config_.decode_qbatch_size); runtime_config_.decode_qbatch_size);
} }
// Ensure we have one KVCache per query. // Ensure we have at least one KVCache per query.
if (kv_caches_.size() < num_queries) { while (kv_caches_.size() < num_queries) {
kv_caches_.resize(num_queries); kv_caches_.push_back(
} KVCache(gemma_.GetModelConfig(), runtime_config_.prefill_tbatch_size));
for (size_t i = 1; i < num_queries; ++i) {
if (kv_caches_[i].seq_len == 0) {
kv_caches_[i] = KVCache::Create(gemma_.GetModelConfig(),
runtime_config_.prefill_tbatch_size);
}
} }
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity}; gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
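
Per the commit message, `KVCache::Create` is removed in favor of a constructor, presumably because the cache's `MatStorageT` members must be initialized at construction. Call sites such as the ones above now construct caches directly, and `BatchQueryModel` appends until there is one cache per query (growing via `resize` would require a default constructor). A sketch of that call-site pattern with stand-in types:

```cpp
// Sketch of the call-site pattern after removing KVCache::Create: caches are
// constructed directly and appended until there is one per query. FakeConfig
// and FakeCache are stand-ins for gcpp::ModelConfig / gcpp::KVCache.
#include <cstddef>
#include <cstdio>
#include <vector>

struct FakeConfig {
  size_t seq_len = 4096;
};

struct FakeCache {
  FakeCache(const FakeConfig& config, size_t prefill_tbatch_size)
      : seq_len(config.seq_len), tbatch(prefill_tbatch_size) {}
  size_t seq_len;
  size_t tbatch;
};

void EnsureCaches(std::vector<FakeCache>& caches, const FakeConfig& config,
                  size_t num_queries, size_t prefill_tbatch_size) {
  // Growing via resize() would require a default constructor, so append.
  while (caches.size() < num_queries) {
    caches.push_back(FakeCache(config, prefill_tbatch_size));
  }
}

int main() {
  FakeConfig config;
  std::vector<FakeCache> caches;
  caches.push_back(FakeCache(config, /*prefill_tbatch_size=*/16));  // starter
  EnsureCaches(caches, config, /*num_queries=*/4, /*prefill_tbatch_size=*/16);
  printf("caches: %zu\n", caches.size());  // 4
  return 0;
}
```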


@ -53,8 +53,7 @@ int main(int argc, char** argv) {
// Instantiate model and KV Cache // Instantiate model and KV Cache
gcpp::MatMulEnv env(MakeMatMulEnv(threading)); gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::Gemma gemma(loader, env); gcpp::Gemma gemma(loader, env);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(gemma.GetModelConfig(), gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
inference.prefill_tbatch_size);
size_t generated = 0; size_t generated = 0;
// Initialize random number generator // Initialize random number generator


@ -39,11 +39,8 @@ class SimplifiedGemma {
threading_(threading), threading_(threading),
inference_(inference), inference_(inference),
env_(MakeMatMulEnv(threading_)), env_(MakeMatMulEnv(threading_)),
gemma_(loader_, env_) { gemma_(loader_, env_),
// Instantiate model and KV Cache kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
kv_cache_ = gcpp::KVCache::Create(gemma_.GetModelConfig(),
inference_.prefill_tbatch_size);
// Initialize random number generator // Initialize random number generator
std::random_device rd; std::random_device rd;
gen_.seed(rd()); gen_.seed(rd());


@ -23,106 +23,127 @@
#include "ops/ops.h" // CreateInvTimescale #include "ops/ops.h" // CreateInvTimescale
#include "util/allocator.h" // Allocator #include "util/allocator.h" // Allocator
#include "util/basics.h" // BF16 #include "util/basics.h" // BF16
#include "util/mat.h" // RowVectorBatch #include "util/mat.h" // MatStorageT
namespace gcpp { namespace gcpp {
struct Activations { struct Activations {
explicit Activations(const ModelConfig& config) Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env)
: weights_config(config), : weights_config(config),
layer_config(config.layer_configs[0]), layer_config(config.layer_configs[0]),
seq_len(config.seq_len), seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()) {} cache_pos_size(config.CachePosSize()),
is_griffin(layer_config.type ==
LayerAttentionType::kGriffinRecurrentBlock),
RowVectorBatch<float> x; // input x("x", Extents2D(batch_size, config.model_dim), pad_),
RowVectorBatch<float> q; // query, also KV if MHA. q("q",
RowVectorBatch<float> logits; Extents2D(batch_size, layer_config.heads * layer_config.QStride()),
pad_),
logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
// Attention pre_att_rms_out("pre_att_rms_out",
RowVectorBatch<float> pre_att_rms_out; Extents2D(batch_size, config.model_dim), pad_),
RowVectorBatch<float> att; // attention vector att("att", Extents2D(batch_size, layer_config.heads * config.seq_len),
RowVectorBatch<float> att_out; // attention output pad_),
// Accumulation of attention outputs over heads att_out(
RowVectorBatch<float> att_sums; "att_out",
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
pad_),
att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad_),
// Gated FFW pre_ffw_rms_out("pre_ffw_rms_out",
RowVectorBatch<BF16> bf_pre_ffw_rms_out; Extents2D(batch_size, config.model_dim), pad_),
RowVectorBatch<float> C1; C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
RowVectorBatch<float> C2; C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
RowVectorBatch<float> ffw_out; ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
// Griffin // No padding for Griffin because it does not always use Row().
RowVectorBatch<float> griffin_x; griffin_x("griffin_x",
RowVectorBatch<float> griffin_y; is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
RowVectorBatch<float> griffin_gate_x; MatPadding::kPacked),
RowVectorBatch<float> griffin_multiplier; griffin_y("griffin_y",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
griffin_gate_x(
"griffin_gate_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
griffin_multiplier(
"griffin_mul",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
// Rope inv_timescale(
RowVectorBatch<float> inv_timescale; CreateInvTimescale(env->ctx.allocator, layer_config.qkv_dim,
RowVectorBatch<float> inv_timescale_global; layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale(
env->ctx.allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
// Dynamic because no default ctor and only initialized in `Allocate`. env(env) {
MatMulEnv* env; HWY_ASSERT(batch_size != 0);
}
void SetBatchSize(size_t batch_size) {
x.OverrideRows(batch_size);
q.OverrideRows(batch_size);
logits.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size);
C1.OverrideRows(batch_size);
C2.OverrideRows(batch_size);
ffw_out.OverrideRows(batch_size);
if (is_griffin) {
griffin_x.OverrideRows(batch_size);
griffin_y.OverrideRows(batch_size);
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
}
}
PostQKType post_qk = PostQKType::Rope;
// And the config.
const ModelConfig& weights_config; const ModelConfig& weights_config;
const LayerConfig& layer_config; const LayerConfig& layer_config;
size_t seq_len; size_t seq_len;
size_t cache_pos_size = 0; size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
bool is_griffin = false;
const Extents2D none_ = Extents2D();
const MatPadding pad_ = MatPadding::kOdd;
void Allocate(size_t batch_size, MatMulEnv* env) { MatStorageT<float> x; // input
const Allocator& allocator = env->ctx.allocator; MatStorageT<float> q; // query, also KV if MHA.
MatStorageT<float> logits;
post_qk = layer_config.post_qk; // Attention
const size_t model_dim = weights_config.model_dim; MatStorageT<float> pre_att_rms_out;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim; MatStorageT<float> att; // attention vector
const size_t vocab_size = weights_config.vocab_size; MatStorageT<float> att_out; // attention output
const size_t qkv_dim = layer_config.qkv_dim; // Accumulation of attention outputs over heads
const size_t heads = layer_config.heads; MatStorageT<float> att_sums;
x = RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim)); // Gated FFW
q = RowVectorBatch<float>( MatStorageT<BF16> pre_ffw_rms_out;
allocator, Extents2D(batch_size, heads * layer_config.QStride())); MatStorageT<float> C1;
if (vocab_size > 0) { MatStorageT<float> C2;
logits = MatStorageT<float> ffw_out;
RowVectorBatch<float>(allocator, Extents2D(batch_size, vocab_size));
}
pre_att_rms_out = // Griffin
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim)); MatStorageT<float> griffin_x;
att = RowVectorBatch<float>( MatStorageT<float> griffin_y;
allocator, Extents2D(batch_size, heads * weights_config.seq_len)); MatStorageT<float> griffin_gate_x;
att_out = RowVectorBatch<float>(allocator, MatStorageT<float> griffin_multiplier;
Extents2D(batch_size, heads * qkv_dim));
att_sums =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
bf_pre_ffw_rms_out = // Rope
RowVectorBatch<BF16>(allocator, Extents2D(batch_size, model_dim)); MatStorageT<float> inv_timescale;
C1 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim)); MatStorageT<float> inv_timescale_global;
C2 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
ffw_out =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) { MatMulEnv* env;
griffin_x =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_y =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_gate_x =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_multiplier =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
}
inv_timescale = CreateInvTimescale(allocator, layer_config.qkv_dim,
post_qk == PostQKType::HalfRope);
inv_timescale_global = CreateInvTimescale(
allocator, qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
this->env = env;
}
}; };
} // namespace gcpp } // namespace gcpp


@ -46,8 +46,7 @@ ConversationData::ConversationData(const ModelConfig& model_config,
size_t prefill_tbatch_size) size_t prefill_tbatch_size)
: model_config_ref_(model_config), : model_config_ref_(model_config),
prefill_tbatch_size_(prefill_tbatch_size), prefill_tbatch_size_(prefill_tbatch_size),
kv_cache(std::make_unique<KVCache>( kv_cache(std::make_unique<KVCache>(model_config, prefill_tbatch_size)),
KVCache::Create(model_config, prefill_tbatch_size))),
abs_pos(0) {} abs_pos(0) {}
// ConversationData copy constructor implementation // ConversationData copy constructor implementation
@ -184,25 +183,28 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
inference_args.CopyTo(runtime_config); inference_args.CopyTo(runtime_config);
size_t prefix_end = 0; size_t prefix_end = 0;
const ModelConfig& model_config = model.GetModelConfig();
// generate // generate
std::vector<int> prompt; std::vector<int> prompt;
ImageTokens image_tokens; const size_t pool_dim = model_config.vit_config.pool_dim;
ImageTokens image_tokens(
"image_tokens",
image_data
? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim),
model_config.model_dim)
: Extents2D(0, 0),
MatPadding::kOdd);
if (image_data != nullptr) { if (image_data != nullptr) {
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim; HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA ||
image_tokens = model_config.wrapping == PromptWrapping::GEMMA_VLM);
ImageTokens(model.Env().ctx.allocator,
Extents2D(model.GetModelConfig().vit_config.seq_len /
(pool_dim * pool_dim),
model.GetModelConfig().model_dim));
HWY_ASSERT(model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA ||
model.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM);
Image image; Image image;
image.Set(image_width, image_height, static_cast<const float*>(image_data)); image.Set(image_width, image_height, static_cast<const float*>(image_data));
// We may need to resize the supplied image depending on whether we're using // We may need to resize the supplied image depending on whether we're using
// PaliGemma or Gemma 3. // PaliGemma or Gemma 3.
const size_t image_size = model.GetModelConfig().vit_config.image_size; const size_t image_size = model_config.vit_config.image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);
// Use the existing runtime_config defined earlier in the function. // Use the existing runtime_config defined earlier in the function.
@ -217,10 +219,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n", ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
LogDebug(ss.str().c_str()); LogDebug(ss.str().c_str());
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), prompt = WrapAndTokenize(
model.GetModelConfig().wrapping, model.Tokenizer(), model.ChatTemplate(), model_config.wrapping,
active_conversation->abs_pos, prompt_string, active_conversation->abs_pos, prompt_string, image_tokens.Rows());
image_tokens.BatchSize());
runtime_config.image_tokens = &image_tokens; runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size(); prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma. // The end of the prefix for prefix-LM style attention in Paligemma.
@ -230,7 +231,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// Text-only case (original logic) // Text-only case (original logic)
// Use abs_pos from the active conversation // Use abs_pos from the active conversation
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
model.GetModelConfig().wrapping, model_config.wrapping,
active_conversation->abs_pos, prompt_string); active_conversation->abs_pos, prompt_string);
prompt_size = prompt.size(); prompt_size = prompt.size();
} }
@ -251,7 +252,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// prepare for next turn // prepare for next turn
if (!inference_args.multiturn || if (!inference_args.multiturn ||
model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) { model_config.wrapping == PromptWrapping::PALIGEMMA) {
// If not multiturn, or Paligemma (which handles turns differently), // If not multiturn, or Paligemma (which handles turns differently),
// reset the *active* conversation's position. // reset the *active* conversation's position.
active_conversation->abs_pos = 0; active_conversation->abs_pos = 0;


@ -188,8 +188,8 @@ class GemmaContext {
// rewind to initial state. // rewind to initial state.
active_conversation->abs_pos = 0; active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object // Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create( active_conversation->kv_cache = std::make_unique<KVCache>(
model.GetModelConfig(), inference_args.prefill_tbatch_size)); model.GetModelConfig(), inference_args.prefill_tbatch_size);
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str()); LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else { } else {


@ -89,11 +89,11 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
// X / Y linear layers. // X / Y linear layers.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx); float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
TwoMatVecAdd(layer_weights->griffin.linear_x_w, TwoMatVecAdd(layer_weights->griffin.linear_x_w,
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim, layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
activations.pre_att_rms_out.Batch(batch_idx), activations.pre_att_rms_out.Row(batch_idx),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool); /*out0=*/x, /*out1=*/y, pool);
@ -103,17 +103,16 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
// Conv1D. // Conv1D.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = batch_start + batch_idx; const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
HWY_FULL(float) df; HWY_FULL(float) df;
HWY_DASSERT(model_dim % hn::Lanes(df) == 0); HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
const size_t layer_offset = layer * model_dim * (conv_1d_width - 1);
// cache[i] = input at time t-i. // cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth]; float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x; cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) { for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] = cache[i] =
kv_cache.conv1d_cache.get() + layer_offset + kv_cache.conv1d_cache.Row(layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim; ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
} }
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@ -140,12 +139,11 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
// RGLRU // RGLRU
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = batch_start + batch_idx; const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx); float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Batch(batch_idx); float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Batch(batch_idx); float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx);
float* HWY_RESTRICT rnn_state = float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer);
kv_cache.rglru_cache.get() + layer * model_dim;
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
const size_t kHeadDim = model_dim / heads; const size_t kHeadDim = model_dim / heads;
@ -193,8 +191,8 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
// Final linear layer. // Final linear layer.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
float* out_ptr = activations.att_sums.Batch(batch_idx); float* out_ptr = activations.att_sums.Row(batch_idx);
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x, MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr, layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
pool); pool);
@ -217,7 +215,7 @@ class GemmaAttention {
const float mul) { const float mul) {
// qk is either q or k, so qkv_dim is the length we operate on. // qk is either q or k, so qkv_dim is the length we operate on.
const size_t qkv_dim = layer_config_.qkv_dim; const size_t qkv_dim = layer_config_.qkv_dim;
const float* inv_timescale = activations_.inv_timescale.Const(); const float* inv_timescale = activations_.inv_timescale.Packed();
bool is_global_layer = bool is_global_layer =
activations_.weights_config.attention_window_sizes[layer] == activations_.weights_config.attention_window_sizes[layer] ==
activations_.seq_len; activations_.seq_len;
@ -227,7 +225,7 @@ class GemmaAttention {
activations_.weights_config.model == Model::GEMMA3_12B || activations_.weights_config.model == Model::GEMMA3_12B ||
activations_.weights_config.model == Model::GEMMA3_27B || activations_.weights_config.model == Model::GEMMA3_27B ||
activations_.weights_config.model == Model::GEMMA3_1B)) { activations_.weights_config.model == Model::GEMMA3_1B)) {
inv_timescale = activations_.inv_timescale_global.Const(); inv_timescale = activations_.inv_timescale_global.Packed();
} }
// PostQKType::Rope // PostQKType::Rope
(void)layer; (void)layer;
@ -249,11 +247,10 @@ class GemmaAttention {
const size_t heads = layer_config_.heads; const size_t heads = layer_config_.heads;
const size_t kv_heads = layer_config_.kv_heads; const size_t kv_heads = layer_config_.kv_heads;
const auto pre_att_rms_out = using WeightT = typename decltype(layer_weights_.qkv_einsum_w)::T;
ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out); ConstMat<WeightT> w_q1(layer_weights_.qkv_einsum_w.HasPtr()
auto w_q1 = layer_weights_.qkv_einsum_w.HasPtr() ? layer_weights_.qkv_einsum_w
? ConstMatFromWeights(layer_weights_.qkv_einsum_w) : layer_weights_.qkv_einsum_w1);
: ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim, // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows. // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
// We must shrink to the actual size because MatMul verifies // We must shrink to the actual size because MatMul verifies
@ -262,20 +259,19 @@ class GemmaAttention {
// computed in the second MatMul. // computed in the second MatMul.
const size_t w1_rows = heads * layer_config_.QStride(); const size_t w1_rows = heads * layer_config_.QStride();
w_q1.ShrinkRows(w1_rows); w_q1.ShrinkRows(w1_rows);
MatMul(pre_att_rms_out, w_q1, MatMul(activations_.pre_att_rms_out, w_q1,
/*add=*/nullptr, *activations_.env, /*add=*/nullptr, *activations_.env,
RowPtrFromBatch(allocator_, activations_.q)); RowPtrFromMat(allocator_, activations_.q));
if (is_mha_) { if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
} else { } else {
decltype(w_q1) w_q2; decltype(w_q1) w_q2(layer_weights_.qkv_einsum_w.HasPtr()
? layer_weights_.qkv_einsum_w
: layer_weights_.qkv_einsum_w2);
if (layer_weights_.qkv_einsum_w.HasPtr()) { if (layer_weights_.qkv_einsum_w.HasPtr()) {
w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w);
// Skip first half of the matrix. // Skip first half of the matrix.
w_q2.ofs = w_q2.Row(w1_rows); w_q2.ofs = w_q2.Row(w1_rows);
} else {
w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
} }
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v). // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim; const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
@ -290,13 +286,13 @@ class GemmaAttention {
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols); RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_); kv_rows.SetStride(cache_pos_size_);
MatMul(pre_att_rms_out, w_q2, MatMul(activations_.pre_att_rms_out, w_q2,
/*add=*/nullptr, *activations_.env, kv_rows); /*add=*/nullptr, *activations_.env, kv_rows);
} else { } else {
// Proceed row by row because there will be wraparound. // Proceed row by row because there will be wraparound.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) { ++interleaved_idx) {
const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx); const float* x = activations_.pre_att_rms_out.Row(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries_; const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_; const size_t batch_idx = interleaved_idx / num_queries_;
KVCache& kv_cache = kv_caches_[query_idx]; KVCache& kv_cache = kv_caches_[query_idx];
@ -327,15 +323,15 @@ class GemmaAttention {
// If MHA, copy computed K and V into KVCache. // If MHA, copy computed K and V into KVCache.
if (is_mha_) { if (is_mha_) {
const float* HWY_RESTRICT mha_kv = const float* HWY_RESTRICT mha_kv =
activations_.q.Batch(interleaved_idx) + head * q_stride_ + activations_.q.Row(interleaved_idx) + head * q_stride_ +
qkv_dim; qkv_dim;
hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv)); hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
} }
// Apply further processing to K. // Apply further processing to K.
if (layer_weights_.key_norm_scale.HasPtr()) { if (layer_weights_.key_norm_scale.HasPtr()) {
RMSNormInplace(layer_weights_.key_norm_scale.Row(0), kv, RMSNormInplace(layer_weights_.key_norm_scale.PackedScale1(),
qkv_dim); 0, kv, qkv_dim);
} }
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f); PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
}); });
@ -402,7 +398,8 @@ class GemmaAttention {
// Apply rope and scaling to Q. // Apply rope and scaling to Q.
if (layer_weights_.query_norm_scale.HasPtr()) { if (layer_weights_.query_norm_scale.HasPtr()) {
RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim); RMSNormInplace(layer_weights_.query_norm_scale.PackedScale1(), 0, q,
qkv_dim);
} }
PositionalEncodingQK(q, pos, layer_, query_scale); PositionalEncodingQK(q, pos, layer_, query_scale);
@ -435,13 +432,12 @@ class GemmaAttention {
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * q_stride_; activations_.q.Row(interleaved_idx) + head * q_stride_;
float* HWY_RESTRICT att = float* HWY_RESTRICT att =
activations_.att.Batch(interleaved_idx) + activations_.att.Row(interleaved_idx) +
head * activations_.seq_len; head * activations_.seq_len;
float* HWY_RESTRICT att_out = float* HWY_RESTRICT att_out =
activations_.att_out.Batch(interleaved_idx) + activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
head * qkv_dim;
// Make strided views into the kv cache entries for the current // Make strided views into the kv cache entries for the current
// query and head. // query and head.
@ -476,28 +472,25 @@ class GemmaAttention {
private: private:
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`). // head_dim (`qkv_dim`) into output (`layer_out`).
HWY_NOINLINE void SumHeads(const size_t num_interleaved) { HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.Attention.SumHeads"); PROFILER_ZONE("Gen.Attention.SumHeads");
// att_weights and att_out are concatenated heads, each of length // att_weights and att_out are concatenated heads, each of length
// layer_config_.qkv_dim. Thus the [num_interleaved, // layer_config_.qkv_dim. Thus the [num_interleaved,
// layer_config_.model_dim] matmul output is the sum over heads. Compare // layer_config_.model_dim] matmul output is the sum over heads. Compare
// gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
// encoded) // encoded)
HWY_DASSERT(layer_config_.model_dim > 0); HWY_DASSERT(layer_config_.model_dim != 0 && layer_config_.heads != 0 &&
HWY_DASSERT(layer_config_.heads > 0); layer_config_.qkv_dim != 0);
HWY_DASSERT(layer_config_.qkv_dim > 0);
HWY_DASSERT(layer_weights_.att_weights.HasPtr()); HWY_DASSERT(layer_weights_.att_weights.HasPtr());
HWY_DASSERT(activations_.att_out.All() != nullptr); HWY_DASSERT(activations_.att_out.HasPtr());
HWY_DASSERT(activations_.att_sums.All() != nullptr); HWY_DASSERT(activations_.att_sums.HasPtr());
const float* add = const float* add =
layer_weights_.layer_config.softmax_attn_output_biases layer_weights_.layer_config.softmax_attn_output_biases
? layer_weights_.attention_output_biases.PackedScale1() ? layer_weights_.attention_output_biases.PackedScale1()
: nullptr; : nullptr;
MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out), MatMul(activations_.att_out, layer_weights_.att_weights, add,
ConstMatFromWeights(layer_weights_.att_weights), add, *activations_.env, RowPtrFromMat(allocator_, activations_.att_sums));
*activations_.env,
RowPtrFromBatch(allocator_, activations_.att_sums));
} }
public: public:
@ -524,7 +517,7 @@ class GemmaAttention {
const size_t num_interleaved = num_tokens_ * num_queries_; const size_t num_interleaved = num_tokens_ * num_queries_;
ComputeQKV(num_interleaved); ComputeQKV(num_interleaved);
DotSoftmaxWeightedSum(num_interleaved); DotSoftmaxWeightedSum(num_interleaved);
SumHeads(num_interleaved); SumHeads();
} }
private: private:
@ -618,12 +611,11 @@ class VitAttention {
HWY_NOINLINE void ComputeQKV() { HWY_NOINLINE void ComputeQKV() {
PROFILER_ZONE("Gen.VitAttention.QKV"); PROFILER_ZONE("Gen.VitAttention.QKV");
auto& qkv = activations_.q; auto& qkv = activations_.q;
HWY_ASSERT(qkv.BatchSize() == num_tokens_); HWY_ASSERT(qkv.Rows() == num_tokens_);
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out), MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
RowPtrFromBatch(allocator_, qkv)); RowPtrFromMat(allocator_, qkv));
} }
// TODO(philculliton): transition fully to MatMul. // TODO(philculliton): transition fully to MatMul.
@ -635,52 +627,49 @@ class VitAttention {
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim)); const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents) // Shift Q, K, VT to MatStorageT.
RowVectorBatch<float> Q = MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
AllocateAlignedRows<float>(allocator_, Extents2D(num_tokens_, qkv_dim)); MatPadding::kPacked);
RowVectorBatch<float> K = MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim),
AllocateAlignedRows<float>(allocator_, Extents2D(seq_len, qkv_dim)); MatPadding::kPacked);
RowVectorBatch<float> C(allocator_, Extents2D(num_tokens_, seq_len)); MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
MatPadding::kPacked);
// Initialize att_out to zero prior to head loop. // Initialize att_out to zero prior to head loop.
hwy::ZeroBytes(activations_.att_out.All(), ZeroInit(activations_.att_out);
num_tokens_ * heads * qkv_dim * sizeof(float));
for (size_t head = 0; head < heads; ++head) { for (size_t head = 0; head < heads; ++head) {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t token = task; const size_t token = task;
float* HWY_RESTRICT q = float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim;
activations_.q.Batch(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working // TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim); MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float)); hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
}); });
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t seq_idx = task; const size_t seq_idx = task;
float* HWY_RESTRICT k = float* HWY_RESTRICT k =
activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim; activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float)); hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
}); });
// this produces C, a (num_tokens_, seq_len) matrix of dot products // this produces C, a (num_tokens_, seq_len) matrix of dot products
MatMul(ConstMatFromBatch(Q.BatchSize(), Q), MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(allocator_, C));
ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
RowPtrFromBatch(allocator_, C));
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Batch(task); float* HWY_RESTRICT c = C.Row(task);
Softmax(c, C.Cols()); Softmax(c, C.Cols());
}); });
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
size_t token = task; size_t token = task;
float* HWY_RESTRICT att_out = float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * qkv_dim; activations_.att_out.Row(token) + head * qkv_dim;
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = float* HWY_RESTRICT v =
activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim; activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim); MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
} }
}); });
} }
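In equation form, the head loop above is standard scaled dot-product attention computed one head at a time: C(t, i) = (query_scale * q_t) . k_i over all positions i, each row of C is then softmaxed, and att_out_t accumulates sum_i C(t, i) * v_i into that head's qkv_dim-wide slice.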
@ -701,24 +690,24 @@ class VitAttention {
const size_t token = task / layer_config_.heads; const size_t token = task / layer_config_.heads;
// Compute Q.K scores, which are "logits" stored in head_att. // Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q = float* HWY_RESTRICT q =
activations_.q.Batch(token) + head * 3 * qkv_dim; activations_.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim); MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att = float* HWY_RESTRICT head_att =
activations_.att.Batch(token) + head * activations_.seq_len; activations_.att.Row(token) + head * activations_.seq_len;
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT k = float* HWY_RESTRICT k =
activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim; activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
head_att[i] = Dot(q, k, qkv_dim); // score = q.k head_att[i] = Dot(q, k, qkv_dim); // score = q.k
} }
// SoftMax yields "probabilities" in head_att. // SoftMax yields "probabilities" in head_att.
Softmax(head_att, seq_len); Softmax(head_att, seq_len);
// Compute weighted sum of v into att_out. // Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out = float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * qkv_dim; activations_.att_out.Row(token) + head * qkv_dim;
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.q.Batch(i) + float* HWY_RESTRICT v =
head * 3 * qkv_dim + 2 * qkv_dim; activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
} }
}); });
@ -732,10 +721,9 @@ class VitAttention {
// att_weights and att_out are concatenated heads, each of length // att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads. // matmul output is the sum over heads.
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out); auto att_sums = RowPtrFromMat(allocator_, activations_.att_sums);
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w); MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
auto att_sums = RowPtrFromBatch(allocator_, activations_.att_sums); *activations_.env, att_sums);
MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
} }
public: public:
@ -771,7 +759,7 @@ class VitAttention {
template <typename T> template <typename T>
HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1, HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
T* HWY_RESTRICT c2, size_t count) { const T* HWY_RESTRICT c2, size_t count) {
PROFILER_ZONE("Gen.Activation"); PROFILER_ZONE("Gen.Activation");
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<T>; using DF = hn::ScalableTag<T>;
@ -787,12 +775,38 @@ HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
}); });
} }
// No C2 multiplier.
template <class Mat>
void ActivationBatched(ActivationType activation, Mat& c1) {
using T = typename Mat::T;
for (size_t i = 0; i < c1.Rows(); ++i) {
// Cast to correct type so type deduction works.
Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
c1.Cols());
}
}
template <class Mat>
void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) {
using T = typename Mat::T;
if (c2 && c2->HasPtr()) {
HWY_DASSERT(c1.SameShape(*c2));
for (size_t i = 0; i < c1.Rows(); ++i) {
Activation(activation, c1.Row(i), c2->Row(i), c1.Cols());
}
} else { // No multiplier
for (size_t i = 0; i < c1.Rows(); ++i) {
Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
c1.Cols());
}
}
}
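A usage sketch of the two overloads, mirroring the calls in FFWNoVit and FFWVit below (activations.C1 and C2 hold the two MatMul outputs):

  // Gated FFW: C1 = activation(C1) * C2, applied row by row.
  ActivationBatched(layer_weights->layer_config.activation, activations.C1,
                    &activations.C2);
  // Ungated (ViT MLP): C1 = activation(C1).
  ActivationBatched(layer_weights->layer_config.activation, activations.C1);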
template <typename T> template <typename T>
HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved, HWY_NOINLINE void FFWNoVit(Activations& activations,
const LayerWeightsPtrs<T>* layer_weights) { const LayerWeightsPtrs<T>* layer_weights) {
PROFILER_ZONE("Gen.FFW"); PROFILER_ZONE("Gen.FFW");
const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim; const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
const bool add_bias = layer_weights->layer_config.ff_biases; const bool add_bias = layer_weights->layer_config.ff_biases;
const float* bias1 = const float* bias1 =
@ -802,56 +816,48 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr; add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
// Define slightly more readable names for the weights and activations. // Define slightly more readable names for the weights and activations.
const auto x =
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
const Allocator& allocator = activations.env->ctx.allocator; const Allocator& allocator = activations.env->ctx.allocator;
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1); auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
auto multiplier = RowPtrFromBatch(allocator, activations.C2); auto multiplier = RowPtrFromMat(allocator, activations.C2);
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out); auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;
// gating_einsum_w holds two half-matrices. We plan to change the importer to // gating_einsum_w holds two half-matrices. We plan to change the importer to
// avoid this confusion by splitting into gating_einsum_w1 and // avoid this confusion by splitting into gating_einsum_w1 and
// gating_einsum_w2. // gating_einsum_w2. TODO: move into Reshape().
const bool split = layer_weights->gating_einsum_w.HasPtr(); const bool split = layer_weights->gating_einsum_w.HasPtr();
auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w) ConstMat<WeightT> w1(split ? layer_weights->gating_einsum_w
: ConstMatFromWeights(layer_weights->gating_einsum_w1); : layer_weights->gating_einsum_w1);
decltype(w1) w2; ConstMat<WeightT> w2(split ? layer_weights->gating_einsum_w
: layer_weights->gating_einsum_w2);
if (split) { if (split) {
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w);
w2.ofs = w2.Row(ffh_hidden_dim); w2.ofs = w2.Row(ffh_hidden_dim);
// Ensure that B.Extents().row matches C.Cols() because MatMul checks that. // Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
w1.ShrinkRows(ffh_hidden_dim); w1.ShrinkRows(ffh_hidden_dim);
w2.ShrinkRows(ffh_hidden_dim); w2.ShrinkRows(ffh_hidden_dim);
} else {
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w2);
} }
auto w_output = ConstMatFromWeights(layer_weights->linear_w);
// Compute the hidden layer activations. // Compute the hidden layer activations.
MatMul(x, w1, bias1, *activations.env, hidden_activations); MatMul(activations.pre_ffw_rms_out, w1, bias1, *activations.env,
MatMul(x, w2, bias2, *activations.env, multiplier); hidden_activations);
MatMul(activations.pre_ffw_rms_out, w2, bias2, *activations.env, multiplier);
// Activation (Gelu) and maybe multiply by gate. Store activations in act. // Activation (Gelu) and maybe multiply by gate. Store activations in act.
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0), ActivationBatched(layer_weights->layer_config.activation, activations.C1,
multiplier.Row(0), ffh_hidden_dim * num_interleaved); &activations.C2);
// Hidden layer -> output layer. // Hidden layer -> output layer.
auto activations_mat = MakeConstMat( MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim), ffw_out);
hidden_activations.Stride());
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
} }
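Putting the pieces together, FFWNoVit is the usual gated FFN; in shorthand (a restatement of the calls above, not new code):

  // C1      = MatMul(pre_ffw_rms_out, w1) + bias1
  // C2      = MatMul(pre_ffw_rms_out, w2) + bias2
  // C1      = activation(C1) * C2                  (ActivationBatched)
  // ffw_out = MatMul(C1, linear_w) + output_bias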
// Same as FFWNoVit, but with different layer_weights members and no second // Same as FFWNoVit, but with different layer_weights members and no second
// gating matrix. // gating matrix.
template <typename T> template <typename T>
HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, HWY_NOINLINE void FFWVit(Activations& activations,
const LayerWeightsPtrs<T>* layer_weights) { const LayerWeightsPtrs<T>* layer_weights) {
PROFILER_ZONE("Gen.FFW"); PROFILER_ZONE("Gen.FFW.ViT");
const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
const bool add_bias = layer_weights->layer_config.ff_biases; const bool add_bias = layer_weights->layer_config.ff_biases;
const float* bias1 = const float* bias1 =
@ -860,30 +866,21 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr; add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
// Define slightly more readable names for the weights and activations. // Define slightly more readable names for the weights and activations.
const auto x =
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
const Allocator& allocator = activations.env->ctx.allocator; const Allocator& allocator = activations.env->ctx.allocator;
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1); auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out); auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
// Compute the hidden layer activations. // Compute the hidden layer activations.
MatMul(x, w1, bias1, *activations.env, hidden_activations); MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
*activations.env, hidden_activations);
// Activation (Gelu), store in act. // Activation (Gelu), store in act.
RowPtrF multiplier = RowPtrF(allocator, nullptr, 0); RowPtrF multiplier = RowPtrF(allocator, nullptr, 0);
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0), ActivationBatched(layer_weights->layer_config.activation, activations.C1);
multiplier.Row(0), ff_hidden_dim * num_interleaved);
// Hidden layer -> output layer. // Hidden layer -> output layer.
auto activations_mat = MakeConstMat(hidden_activations.Row(0), MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
Extents2D(num_interleaved, ff_hidden_dim), *activations.env, ffw_out);
hidden_activations.Stride());
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
} }
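FFWVit is the ungated counterpart; in the same shorthand:

  // C1      = MatMul(pre_ffw_rms_out, vit.linear_0_w) + bias1
  // C1      = activation(C1)                        (no multiplier)
  // ffw_out = MatMul(C1, vit.linear_1_w) + output_bias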
// `batch_idx` indicates which row of `x` to write to. // `batch_idx` indicates which row of `x` to write to.
@ -898,23 +895,23 @@ template <typename T>
HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos, HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt, size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights, const ModelWeightsPtrs<T>& weights,
RowVectorBatch<float>& x, MatStorageT<float>& x,
const ImageTokens* image_tokens, const ImageTokens* image_tokens,
size_t& image_token_position) { size_t& image_token_position) {
// Image tokens just need to be copied. // Image tokens just need to be copied.
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM && if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 && image_tokens != nullptr && token == -2 &&
image_token_position < image_tokens->BatchSize()) { image_token_position < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Batch(image_token_position), hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0])); x.Cols() * x.ElementBytes());
image_token_position++; image_token_position++;
return; return;
} }
if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA && if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) { image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx), hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
x.Cols() * sizeof(x.Const()[0])); x.Cols() * x.ElementBytes());
return; return;
} }
@ -934,12 +931,12 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim); HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim);
const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0), const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0),
embedding_ofs + model_dim); embedding_ofs + model_dim);
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Batch(batch_idx), DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
model_dim); model_dim);
MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(), MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
x.Batch(batch_idx), model_dim); x.Row(batch_idx), model_dim);
if (weights.weights_config.absolute_pe) { if (weights.weights_config.absolute_pe) {
AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos); AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
} }
} }
@ -951,29 +948,28 @@ template <typename T>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt, size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights, const ModelWeightsPtrs<T>& weights,
RowVectorBatch<float>& x, MatStorageT<float>& x,
const ImageTokens* image_tokens) { const ImageTokens* image_tokens) {
size_t image_token_position = 0; size_t image_token_position = 0;
EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x, EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x,
image_tokens, image_token_position); image_tokens, image_token_position);
} }
template <typename Weights, typename T> template <typename T2, class LayerWeights>
HWY_NOINLINE void ResidualConnection( HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
size_t num_interleaved, const T* HWY_RESTRICT other, T* HWY_RESTRICT x, MatPtrT<float>& HWY_RESTRICT x,
const LayerWeightsPtrs<Weights>* layer_weights, bool is_attention) { const LayerWeights* layer_weights,
bool is_attention) {
// ResidualType::Add // ResidualType::Add
AddFromBatched(num_interleaved, other, x, AddFromBatched(other, x);
layer_weights->layer_config.model_dim);
} }
template <typename WeightT, typename InOutT> template <typename WeightT, typename InOutT>
void PostNorm(PostNormType post_norm, size_t num_interleaved, void PostNorm(PostNormType post_norm, const MatPtrT<WeightT>& weights,
const WeightT& weights, InOutT* inout) { MatPtrT<InOutT>& inout) {
HWY_DASSERT(weights.Rows() == 1); HWY_DASSERT(weights.Rows() == 1);
if (post_norm == PostNormType::Scale) { if (post_norm == PostNormType::Scale) {
RMSNormInplaceBatched(num_interleaved, weights.PackedScale1(), inout, RMSNormInplaceBatched(weights, inout);
weights.Cols());
} }
} }
@ -985,39 +981,33 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
Activations& activations, Activations& activations,
const hwy::Divisor& div_seq_len, const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches) { const KVCaches& kv_caches) {
const size_t model_dim = activations.weights_config.model_dim;
const size_t num_interleaved = num_tokens * queries_pos.size();
auto type = layer_weights->layer_config.type; auto type = layer_weights->layer_config.type;
RMSNormBatched(num_interleaved, activations.x.All(), RMSNormBatched(activations.x, layer_weights->pre_attention_norm_scale,
layer_weights->pre_attention_norm_scale.PackedScale1(), activations.pre_att_rms_out);
activations.pre_att_rms_out.All(), model_dim);
Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx, Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx,
activations, layer_weights, div_seq_len, kv_caches); activations, layer_weights, div_seq_len, kv_caches);
PostNorm(layer_weights->layer_config.post_norm, num_interleaved, PostNorm(layer_weights->layer_config.post_norm,
layer_weights->post_attention_norm_scale, layer_weights->post_attention_norm_scale, activations.att_sums);
activations.att_sums.All());
ResidualConnection(num_interleaved, activations.att_sums.All(), ResidualConnection(activations.att_sums, activations.x, layer_weights,
activations.x.All(), layer_weights, /*is_attention=*/true); /*is_attention=*/true);
RMSNormBatched(num_interleaved, activations.x.All(), RMSNormBatched(activations.x, layer_weights->pre_ffw_norm_scale,
layer_weights->pre_ffw_norm_scale.PackedScale1(), activations.pre_ffw_rms_out);
activations.bf_pre_ffw_rms_out.All(), model_dim);
if (layer_weights->layer_config.type == LayerAttentionType::kVit) { if (layer_weights->layer_config.type == LayerAttentionType::kVit) {
FFWVit(activations, num_interleaved, layer_weights); FFWVit(activations, layer_weights);
} else { } else {
FFWNoVit(activations, num_interleaved, layer_weights); FFWNoVit(activations, layer_weights);
} }
PostNorm(layer_weights->layer_config.post_norm, num_interleaved, PostNorm(layer_weights->layer_config.post_norm,
layer_weights->post_ffw_norm_scale, activations.ffw_out.All()); layer_weights->post_ffw_norm_scale, activations.ffw_out);
ResidualConnection(num_interleaved, activations.ffw_out.All(), ResidualConnection(activations.ffw_out, activations.x, layer_weights,
activations.x.All(), layer_weights,
/*is_attention=*/false); /*is_attention=*/false);
} }
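In compact form, TransformerLayer implements the pre-norm residual pattern (a restatement of the calls above):

  // pre_att_rms_out = RMSNorm(x)
  // att_sums        = Attention(pre_att_rms_out), then optional PostNorm
  // x              += att_sums
  // pre_ffw_rms_out = RMSNorm(x)
  // ffw_out         = FFW(pre_ffw_rms_out), then optional PostNorm
  // x              += ffw_out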
@ -1034,38 +1024,37 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
auto type = layer_weights->layer_config.type; auto type = layer_weights->layer_config.type;
HWY_DASSERT(type == LayerAttentionType::kVit); HWY_DASSERT(type == LayerAttentionType::kVit);
(void)type; (void)type;
(void)model_dim;
auto& x = activations.x; auto& x = activations.x;
HWY_DASSERT(x.BatchSize() == num_tokens); HWY_DASSERT(x.Rows() == num_tokens);
HWY_DASSERT(x.Cols() == model_dim); HWY_DASSERT(x.Cols() == model_dim);
// y = nn.LayerNorm()(x) // y = nn.LayerNorm()(x)
// y ~ pre_att_rms_out // y ~ pre_att_rms_out
LayerNormBatched(num_tokens, x.All(), LayerNormBatched(x, layer_weights->vit.layer_norm_0_scale,
layer_weights->vit.layer_norm_0_scale.PackedScale1(), layer_weights->vit.layer_norm_0_bias,
layer_weights->vit.layer_norm_0_bias.PackedScale1(), activations.pre_att_rms_out);
activations.pre_att_rms_out.All(), model_dim);
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums // y ~ att_sums
VitAttention<T>(num_tokens, layer, activations, layer_weights)(); VitAttention<T>(num_tokens, layer, activations, layer_weights)();
// x = out["+sa"] = x + y // x = out["+sa"] = x + y
AddFromBatched(num_tokens, activations.att_sums.All(), x.All(), model_dim); AddFromBatched(activations.att_sums, x);
// y = nn.LayerNorm()(x) // y = nn.LayerNorm()(x)
// y ~ bf_pre_ffw_rms_out // y ~ pre_ffw_rms_out
LayerNormBatched(num_tokens, x.All(), LayerNormBatched(x, layer_weights->vit.layer_norm_1_scale,
layer_weights->vit.layer_norm_1_scale.PackedScale1(), layer_weights->vit.layer_norm_1_bias,
layer_weights->vit.layer_norm_1_bias.PackedScale1(), activations.pre_ffw_rms_out);
activations.bf_pre_ffw_rms_out.All(), model_dim);
// y = out["mlp"] = MlpBlock(...)(y) // y = out["mlp"] = MlpBlock(...)(y)
// y ~ ffw_out // y ~ ffw_out
FFWVit(activations, num_tokens, layer_weights); FFWVit(activations, layer_weights);
// x = out["+mlp"] = x + y // x = out["+mlp"] = x + y
AddFromBatched(num_tokens, activations.ffw_out.All(), x.All(), model_dim); AddFromBatched(activations.ffw_out, x);
} }
// Prefill() and Transformer() increment positions in-place. // Prefill() and Transformer() increment positions in-place.
@ -1094,7 +1083,7 @@ HWY_NOINLINE void Prefill(
// intensity, and so we are eventually compute-limited. We could devote some // intensity, and so we are eventually compute-limited. We could devote some
// threads to parallelizing over queries, but for simplicity we assign them // threads to parallelizing over queries, but for simplicity we assign them
// all to MatMul. // all to MatMul.
const size_t max_tbatch_size = activations.x.BatchSize(); const size_t max_tbatch_size = activations.x.Rows();
// For each query. `qi` is within the batch, not the global query index. // For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) { for (size_t qi = 0; qi < num_queries; ++qi) {
@ -1131,6 +1120,7 @@ HWY_NOINLINE void Prefill(
tbatch_start += max_tbatch_size) { tbatch_start += max_tbatch_size) {
const size_t tbatch_size = const size_t tbatch_size =
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start); HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
activations.SetBatchSize(tbatch_size);
// Fill activations.x (much faster than TransformerLayer). // Fill activations.x (much faster than TransformerLayer).
size_t image_token_position = 0; size_t image_token_position = 0;
@ -1201,13 +1191,14 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3) // image_patches is (256, 14 * 14 * 3)
// This could be done as one MatMul like: // This could be done as one MatMul like:
// RowVectorBatch<float> image_patches(kSeqLen, kPatchSize); // MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
// kPatchSize), MatPadding::kPacked);
// [Get patches] // [Get patches]
// MatMul( // MatMul(
// MatFromBatch(kVitSeqLen, image_patches), // MatFromBatch(kVitSeqLen, image_patches),
// MatFromWeights(weights.vit_img_embedding_kernel), // MatFromWeights(weights.vit_img_embedding_kernel),
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env, // weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
// RowPtrF(activations.x.All(), kVitModelDim)); // RowPtrF(activations.x.Row(0), kVitModelDim));
// However, MatMul currently requires that // However, MatMul currently requires that
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0 // A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
// which is not the case here. We should relax that requirement on MatMul and // which is not the case here. We should relax that requirement on MatMul and
@ -1216,11 +1207,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size, MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
image_patches[i].get(), image_patches[i].get(),
weights.vit_img_embedding_bias.PackedScale1(), weights.vit_img_embedding_bias.PackedScale1(),
activations.x.Batch(i), activations.env->ctx.pools.Pool(0)); activations.x.Row(i), activations.env->ctx.pools.Pool(0));
} }
// Add position embeddings. // Add position embeddings.
AddFrom(weights.vit_img_pos_embedding.PackedScale1(), activations.x.All(), AddFromBatched(weights.vit_img_pos_embedding, activations.x);
seq_len * model_dim);
} }
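A quick arithmetic check of why the single-MatMul path sketched in the comment above is not yet usable: image_patches has 14 * 14 * 3 = 588 columns, and MatMul requires A.cols to be a multiple of 2 * Lanes. With 8-lane float vectors, for example, that multiple is 16 and 588 % 16 = 12, so the per-patch MatVecAdd loop is kept until the restriction is relaxed (the exact lane count depends on the build target and element type).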
// Prefills the image tokens with the ViT encoder. // Prefills the image tokens with the ViT encoder.
@ -1232,7 +1222,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
PROFILER_ZONE("Gen.PrefillVit"); PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = weights.weights_config.vit_config.seq_len; const size_t num_tokens = weights.weights_config.vit_config.seq_len;
const size_t vit_model_dim = weights.weights_config.vit_config.model_dim; const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.BatchSize()); HWY_ASSERT(num_tokens == activations.x.Rows());
// Embed the image patches. // Embed the image patches.
EmbedImagePatches(image, weights, activations); EmbedImagePatches(image, weights, activations);
// Go through all layers. // Go through all layers.
@ -1243,24 +1233,21 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
VitTransformerLayer(num_tokens, layer, layer_weights, activations); VitTransformerLayer(num_tokens, layer, layer_weights, activations);
} }
// Final Layernorm. // Final Layernorm.
LayerNormBatched(num_tokens, activations.x.All(), LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,
weights.vit_encoder_norm_scale.PackedScale1(), weights.vit_encoder_norm_bias, activations.x);
weights.vit_encoder_norm_bias.PackedScale1(),
activations.x.All(), vit_model_dim);
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) { if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
activations.x = AvgPool4x4(activations.x); activations.x = AvgPool4x4(activations.x);
// Apply soft embedding norm before input projection. // Apply soft embedding norm before input projection.
RMSNormInplace(weights.mm_embed_norm.PackedScale1(), activations.x.All(), RMSNormInplace(weights.mm_embed_norm.PackedScale1(), 0,
vit_model_dim); activations.x.Row(0), vit_model_dim);
} }
// Apply head embedding into image_tokens of size of the LLM kModelDim. // Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x), MatMul(activations.x, weights.vit_img_head_kernel,
ConstMatFromWeights(weights.vit_img_head_kernel),
weights.vit_img_head_bias.PackedScale1(), *activations.env, weights.vit_img_head_bias.PackedScale1(), *activations.env,
RowPtrFromBatch(activations.env->ctx.allocator, image_tokens)); RowPtrFromMat(activations.env->ctx.allocator, image_tokens));
} }
// Generates one token for each query. `queries_token` is the previous token // Generates one token for each query. `queries_token` is the previous token
@ -1272,7 +1259,6 @@ HWY_NOINLINE void Transformer(
Activations& activations, const hwy::Divisor& div_seq_len, Activations& activations, const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches, const LayersOutputFunc& layers_output, const KVCaches& kv_caches, const LayersOutputFunc& layers_output,
const ActivationsObserverFunc& activations_observer) { const ActivationsObserverFunc& activations_observer) {
const size_t model_dim = weights.weights_config.model_dim;
const size_t num_queries = queries_token.size(); const size_t num_queries = queries_token.size();
HWY_DASSERT(queries_pos.size() == num_queries); HWY_DASSERT(queries_pos.size() == num_queries);
HWY_DASSERT(queries_prefix_end.size() == num_queries); HWY_DASSERT(queries_prefix_end.size() == num_queries);
@ -1302,8 +1288,7 @@ HWY_NOINLINE void Transformer(
} }
} }
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.PackedScale1(), RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
activations.x.All(), model_dim);
if (activations_observer) { if (activations_observer) {
activations_observer(queries_pos, -1, activations); activations_observer(queries_pos, -1, activations);
@ -1395,18 +1380,18 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
runtime_config.activations_observer); runtime_config.activations_observer);
// queries_pos are incremented by Transformer. // queries_pos are incremented by Transformer.
HWY_DASSERT(num_queries == activations.x.Rows());
bool all_queries_eos = true; bool all_queries_eos = true;
{ {
PROFILER_ZONE("Gen.EmbeddingMatmul"); PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations. // Compute logits from last layer activations.
MatMul(ConstMatFromBatch(num_queries, activations.x), MatMul(activations.x, weights.embedder_input_embedding,
ConstMatFromWeights(weights.embedder_input_embedding),
/*add=*/nullptr, *activations.env, /*add=*/nullptr, *activations.env,
RowPtrFromBatch(activations.env->ctx.allocator, activations.logits)); RowPtrFromMat(activations.env->ctx.allocator, activations.logits));
} }
PROFILER_ZONE("Gen.Softcap+Sample+Stream"); PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size); MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
const TokenAndProb tp = sample_token(logits, vocab_size); const TokenAndProb tp = sample_token(logits, vocab_size);
timing_info.NotifyGenerated(); timing_info.NotifyGenerated();
@ -1460,7 +1445,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const size_t num_queries = queries_prompt.size(); const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096. HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
HWY_ASSERT(num_queries <= activations.x.BatchSize()); HWY_ASSERT(num_queries <= activations.x.Rows());
HWY_ASSERT(queries_pos_in.size() == num_queries); HWY_ASSERT(queries_pos_in.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries); HWY_ASSERT(kv_caches.size() == num_queries);
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len)); const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
@ -1475,12 +1460,11 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
// If tbatch is larger than the qbatch we already have in `activations`, then // If tbatch is larger than the qbatch we already have in `activations`, then
// allocate prefill_activations, otherwise reuse. // allocate prefill_activations, otherwise reuse.
const bool use_prefill_activations = const bool use_prefill_activations =
runtime_config.prefill_tbatch_size > activations.x.BatchSize(); runtime_config.prefill_tbatch_size > activations.x.Rows();
Activations prefill_activations(weights.weights_config); Activations prefill_activations(
if (use_prefill_activations) { weights.weights_config,
prefill_activations.Allocate(runtime_config.prefill_tbatch_size, use_prefill_activations ? runtime_config.prefill_tbatch_size : 0,
activations.env); activations.env);
}
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
query_idx_start, weights, query_idx_start, weights,
use_prefill_activations ? prefill_activations : activations, use_prefill_activations ? prefill_activations : activations,
@ -1534,8 +1518,7 @@ void GenerateSingleT(const ModelStore& model,
const size_t qbatch_start = 0; const size_t qbatch_start = 0;
// TODO: move into Gemma? // TODO: move into Gemma?
Activations activations(model.Config()); Activations activations(model.Config(), kNumQueries, env);
activations.Allocate(kNumQueries, env);
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
QueriesPos queries_pos(&pos, kNumQueries); QueriesPos queries_pos(&pos, kNumQueries);
@ -1558,7 +1541,7 @@ void GenerateBatchT(const ModelStore& model,
TimingInfo& timing_info) { TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size(); const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries); HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries); HWY_ASSERT(kv_caches.size() >= num_queries);
// Griffin does not support query batching. // Griffin does not support query batching.
size_t max_qbatch_size = runtime_config.decode_qbatch_size; size_t max_qbatch_size = runtime_config.decode_qbatch_size;
for (const LayerConfig& layer_config : model.Config().layer_configs) { for (const LayerConfig& layer_config : model.Config().layer_configs) {
@ -1568,14 +1551,14 @@ void GenerateBatchT(const ModelStore& model,
} }
} }
Activations activations(model.Config()); Activations activations(model.Config(), max_qbatch_size, env);
activations.Allocate(max_qbatch_size, env);
for (size_t qbatch_start = 0; qbatch_start < num_queries; for (size_t qbatch_start = 0; qbatch_start < num_queries;
qbatch_start += max_qbatch_size) { qbatch_start += max_qbatch_size) {
// Generate one batch of tokens from `qbatch_size` queries. // Generate one batch of tokens from `qbatch_size` queries.
const size_t qbatch_size = const size_t qbatch_size =
HWY_MIN(num_queries - qbatch_start, max_qbatch_size); HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
activations.SetBatchSize(qbatch_size);
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start], const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
qbatch_size); qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size); QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
@ -1601,8 +1584,7 @@ void GenerateImageTokensT(const ModelStore& model,
ModelConfig vit_config = GetVitConfig(model.Config()); ModelConfig vit_config = GetVitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size = prefill_runtime_config.prefill_tbatch_size =
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config); Activations prefill_activations(vit_config, vit_config.seq_len, env);
prefill_activations.Allocate(vit_config.seq_len, env);
// Weights are for the full PaliGemma model, not just the ViT part. // Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(weights, prefill_runtime_config, image, image_tokens, PrefillVit(weights, prefill_runtime_config, image, image_tokens,
prefill_activations); prefill_activations);
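The Activations lifecycle after this change, as used by the callers above (a sketch; batch sizes are illustrative): construct once with the maximum batch size, then shrink per batch via SetBatchSize.

  Activations activations(model.Config(), max_qbatch_size, env);
  // ... per batch:
  activations.SetBatchSize(qbatch_size);   // qbatch_size <= max_qbatch_size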


@ -32,7 +32,6 @@
#include "ops/matmul.h" // MatMulEnv #include "ops/matmul.h" // MatMulEnv
#include "paligemma/image.h" #include "paligemma/image.h"
#include "util/basics.h" // TokenAndProb #include "util/basics.h" // TokenAndProb
#include "util/mat.h" // RowVectorBatch
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/timer.h" #include "hwy/timer.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports


@ -25,10 +25,11 @@
#include <random> #include <random>
#include <string> #include <string>
#include "io/io.h" // Path #include "io/io.h" // Path
#include "ops/matmul.h" // MMStorage::kMax* #include "ops/matmul.h" // MMStorage::kMax*
#include "util/args.h" #include "util/args.h"
#include "util/basics.h" // Tristate #include "util/basics.h" // Tristate
#include "util/mat.h"
#include "hwy/aligned_allocator.h" // Span #include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ABORT #include "hwy/base.h" // HWY_ABORT
@ -74,9 +75,9 @@ using QueriesPromptTokens = hwy::Span<const PromptTokens>;
using QueriesToken = hwy::Span<const int>; using QueriesToken = hwy::Span<const int>;
using QueriesPos = hwy::Span<const size_t>; using QueriesPos = hwy::Span<const size_t>;
// ImageTokens are represented as a RowVectorBatch, where each "batch" index // ImageTokens are represented as a matrix, where each row corresponds
// corresponds to a token for an image patch as computed by the image encoder. // to a token for an image patch as computed by the image encoder.
using ImageTokens = RowVectorBatch<float>; using ImageTokens = MatStorageT<float>;
// StreamFunc is called with (token, probability). For prompt tokens, // StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return false to stop generation and // probability is 0.0f. StreamFunc should return false to stop generation and


@ -15,91 +15,69 @@
#include "gemma/kv_cache.h" #include "gemma/kv_cache.h"
#include <algorithm> #include <algorithm> // std::copy
#include "gemma/configs.h" #include "gemma/configs.h"
#include "util/mat.h" // ZeroInit
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ZeroBytes #include "hwy/base.h" // ZeroBytes
namespace gcpp { namespace gcpp {
void KVCache::ZeroGriffinCache() { void KVCache::ZeroGriffinCache() {
if (conv1d_cache_size != 0) { if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache);
hwy::ZeroBytes(conv1d_cache.get(), if (rglru_cache.HasPtr()) ZeroInit(rglru_cache);
conv1d_cache_size * sizeof(conv1d_cache[0])); }
}
if (rglru_cache_size != 0) { static size_t GriffinConv1dCols(const ModelConfig& config) {
hwy::ZeroBytes(rglru_cache.get(), size_t conv1d_width = 0;
rglru_cache_size * sizeof(rglru_cache[0])); for (const auto& layer_config : config.layer_configs) {
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
} }
return conv1d_width == 0 ? 0 : conv1d_width - 1;
} }
// prefill_tbatch_size is the maximum number of tokens from one query to // prefill_tbatch_size is the maximum number of tokens from one query to
// prefill at a time. // prefill at a time.
KVCache KVCache::Create(const ModelConfig& weights_config, KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
size_t prefill_tbatch_size) { : griffin_layers(
KVCache kv_cache = {}; config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
griffin_conv1d_cols(GriffinConv1dCols(config)),
const size_t size_cache_pos = weights_config.CachePosSize(); // TODO(patrickms): Add query batching support for Griffin.
conv1d_cache(
"conv1d_cache",
Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
MatPadding::kOdd),
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
MatPadding::kOdd) {
// TODO: move to MatStorageT.
const size_t size_cache_pos = config.CachePosSize();
if (size_cache_pos != 0) { if (size_cache_pos != 0) {
// Allocate more so that prefill can always access one batch, even if // Allocate more so that prefill can always access one batch, even if
// near the end of the sequence. // near the end of the sequence.
kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size; seq_len = config.seq_len + prefill_tbatch_size;
kv_cache.kv_cache = kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
} }
const size_t num_griffin_layers = weights_config.NumLayersOfType(
LayerAttentionType::kGriffinRecurrentBlock);
// TODO(patrickms): Add query batching support for Griffin.
if (num_griffin_layers > 0) {
uint32_t conv1d_width = 0;
for (const auto& layer_config : weights_config.layer_configs) {
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
}
const size_t conv1d_cache_size =
num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
weights_config.model_dim;
kv_cache.conv1d_cache_size = conv1d_cache_size;
if (conv1d_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
}
const size_t rglru_cache_size =
num_griffin_layers * weights_config.model_dim;
kv_cache.rglru_cache_size = rglru_cache_size;
if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
}
} // num_griffin_layers
return kv_cache;
} }
KVCache KVCache::Copy(const ModelConfig& weights_config, KVCache KVCache::Copy(const ModelConfig& weights_config,
size_t prefill_tbatch_size) { size_t prefill_tbatch_size) {
KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size); KVCache copy(weights_config, prefill_tbatch_size);
const size_t size_cache_pos = weights_config.CachePosSize(); const size_t size_cache_pos = weights_config.CachePosSize();
if (size_cache_pos != 0) { if (size_cache_pos != 0) {
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len, std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
kv_cache_copy.kv_cache.get()); copy.kv_cache.get());
} }
const size_t num_griffin_layers = weights_config.NumLayersOfType( if (conv1d_cache.HasPtr()) {
LayerAttentionType::kGriffinRecurrentBlock); CopyMat(conv1d_cache, copy.conv1d_cache);
if (num_griffin_layers > 0) {
if (conv1d_cache_size != 0) {
std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
kv_cache_copy.conv1d_cache.get());
}
if (rglru_cache_size != 0) {
std::copy(rglru_cache.get(),
rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]),
kv_cache_copy.rglru_cache.get());
}
} }
return kv_cache_copy; if (rglru_cache.HasPtr()) {
CopyMat(rglru_cache, copy.rglru_cache);
}
return copy;
} }
} // namespace gcpp } // namespace gcpp


@ -19,33 +19,31 @@
#include <stddef.h> #include <stddef.h>
#include "gemma/configs.h" // ModelConfig #include "gemma/configs.h" // ModelConfig
#include "util/mat.h"
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
namespace gcpp { namespace gcpp {
struct KVCache { struct KVCache {
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size KVCache() = default; // for std::vector.
KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2 // Returns a deep copy of the KVCache.
hwy::AlignedFreeUniquePtr<float[]> kv_cache; KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
size_t conv1d_cache_size = 0;
// kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
size_t rglru_cache_size = 0;
size_t griffin_layers = 0;
size_t griffin_conv1d_cols = 0;
// griffin_layers, griffin_conv1d_cols * config.model_dim
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
// and rglru_cache. // and rglru_cache.
void ZeroGriffinCache(); void ZeroGriffinCache();
static KVCache Create(const ModelConfig& weights_config, size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
size_t prefill_tbatch_size);
// Returns a deep copy of the KVCache. // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size); hwy::AlignedFreeUniquePtr<float[]> kv_cache;
}; };
} // namespace gcpp } // namespace gcpp
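A minimal usage sketch of the constructor-based API that replaces the removed Create factory (names follow the call sites elsewhere in this change):

  KVCache kv_cache(model.Config(), inference.prefill_tbatch_size);
  kv_cache.ZeroGriffinCache();   // zero-initializes the conv1d/rglru caches, if present
  KVCache copy = kv_cache.Copy(model.Config(), inference.prefill_tbatch_size);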


@ -95,24 +95,25 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
size_t abs_pos = 0; // across turns size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0; size_t prompt_size = 0;
const ModelConfig& config = gemma.GetModelConfig();
std::mt19937 gen; std::mt19937 gen;
InitGenerator(inference, gen); InitGenerator(inference, gen);
const bool have_image = !inference.image_file.path.empty(); const bool have_image = !inference.image_file.path.empty();
Image image; Image image;
ImageTokens image_tokens; const size_t pool_dim = config.vit_config.pool_dim;
ImageTokens image_tokens(
"image_tokens",
have_image ? Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim),
config.model_dim)
: Extents2D(0, 0),
MatPadding::kOdd);
if (have_image) { if (have_image) {
size_t pool_dim = gemma.GetModelConfig().vit_config.pool_dim; HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
image_tokens = config.wrapping == PromptWrapping::GEMMA_VLM);
ImageTokens(gemma.Env().ctx.allocator,
Extents2D(gemma.GetModelConfig().vit_config.seq_len /
(pool_dim * pool_dim),
gemma.GetModelConfig().model_dim));
HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA ||
gemma.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM);
HWY_ASSERT(image.ReadPPM(inference.image_file.path)); HWY_ASSERT(image.ReadPPM(inference.image_file.path));
const size_t image_size = gemma.GetModelConfig().vit_config.image_size; const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size); image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &gen, RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity, .verbosity = inference.verbosity,
@ -138,7 +139,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
std::cerr << "." << std::flush; std::cerr << "." << std::flush;
} }
return true; return true;
} else if (gemma.GetModelConfig().IsEOS(token)) { } else if (config.IsEOS(token)) {
if (inference.verbosity >= 2) { if (inference.verbosity >= 2) {
std::cout << "\n[ End ]\n"; std::cout << "\n[ End ]\n";
} }
@ -191,8 +192,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
size_t prefix_end = 0; size_t prefix_end = 0;
if (have_image) { if (have_image) {
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(), prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
gemma.GetModelConfig().wrapping, abs_pos, config.wrapping, abs_pos, prompt_string,
prompt_string, image_tokens.BatchSize()); image_tokens.Rows());
runtime_config.image_tokens = &image_tokens; runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size(); prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma. // The end of the prefix for prefix-LM style attention in Paligemma.
@ -203,8 +204,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
// runtime_config.prefill_tbatch_size = prompt_size; // runtime_config.prefill_tbatch_size = prompt_size;
} else { } else {
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(), prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
gemma.GetModelConfig().wrapping, abs_pos, config.wrapping, abs_pos, prompt_string);
prompt_string);
prompt_size = prompt.size(); prompt_size = prompt.size();
} }
@ -228,8 +228,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
} }
// Prepare for the next turn. Works only for PaliGemma. // Prepare for the next turn. Works only for PaliGemma.
if (!inference.multiturn || if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) {
gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0. abs_pos = 0; // Start a new turn at position 0.
InitGenerator(inference, gen); InitGenerator(inference, gen);
} else { } else {
@ -254,8 +253,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
MatMulEnv env(MakeMatMulEnv(threading)); MatMulEnv env(MakeMatMulEnv(threading));
if (inference.verbosity >= 2) env.print_best = true; if (inference.verbosity >= 2) env.print_best = true;
const Gemma gemma(loader, env); const Gemma gemma(loader, env);
KVCache kv_cache = KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
KVCache::Create(gemma.GetModelConfig(), inference.prefill_tbatch_size);
if (inference.verbosity >= 1) { if (inference.verbosity >= 1) {
std::string instructions = std::string instructions =


@ -92,9 +92,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
const Extents2D B_extents(N, K); // already transposed const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N); const Extents2D C_extents(M, N);
RowVectorBatch<TC> c_slow_batch = MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
AllocateAlignedRows<TC>(allocator, C_extents); MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked); MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
if (add) { if (add) {
@ -104,11 +103,9 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
MatStorageT<TA> a = GenerateMat<TA>(A_extents, pool); MatStorageT<TA> a = GenerateMat<TA>(A_extents, pool);
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool); MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
const auto A = ConstMatFromWeights(a);
const auto B = ConstMatFromWeights(b_trans);
const float* add_row = add ? add_storage.PackedScale1() : nullptr; const float* add_row = add ? add_storage.PackedScale1() : nullptr;
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch); const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
// Fewer reps for large batch sizes, which take longer. // Fewer reps for large batch sizes, which take longer.
const size_t num_samples = M < 32 ? 20 : 12; const size_t num_samples = M < 32 ? 20 : 12;
@ -118,7 +115,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// Ensure usage conditions are set before autotuning. Both binding and // Ensure usage conditions are set before autotuning. Both binding and
// spinning may materially affect the choice of config. No harm in calling // spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op. // BindB/C if there is a single package: they will be a no-op.
BindB(allocator, B_extents.rows, sizeof(TC), B, env.parallel); BindB(allocator, B_extents.rows, sizeof(TC), ConstMat<TB>(b_trans),
env.parallel);
BindC(allocator, A_extents.rows, C, env.parallel); BindC(allocator, A_extents.rows, C, env.parallel);
Tristate use_spinning = Tristate::kDefault; Tristate use_spinning = Tristate::kDefault;
@ -133,7 +131,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// Until enough samples collected *after* autotuning finished: // Until enough samples collected *after* autotuning finished:
while (times.size() < num_samples) { while (times.size() < num_samples) {
const double t0 = hwy::platform::Now(); const double t0 = hwy::platform::Now();
per_key = MatMul(A, B, add_row, env, C); per_key = MatMul(a, b_trans, add_row, env, C);
const double t1 = hwy::platform::Now(); const double t1 = hwy::platform::Now();
double elapsed = t1 - t0; double elapsed = t1 - t0;
keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]); keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]);


@ -26,6 +26,7 @@
#include <cmath> #include <cmath>
#include <random> #include <random>
#include "compression/compress.h"
#include "compression/shared.h" #include "compression/shared.h"
#include "util/allocator.h" #include "util/allocator.h"
#include "util/test_util.h" #include "util/test_util.h"
@ -999,7 +1000,6 @@ struct TestShortDotsT {
const size_t N = hn::Lanes(d); const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df; // for CallDot const hn::ScalableTag<float> df; // for CallDot
const Allocator& allocator = gcpp::ThreadingContext::Get().allocator;
CompressWorkingSet work; CompressWorkingSet work;
std::mt19937 rng; std::mt19937 rng;
rng.seed(12345); rng.seed(12345);
@ -1010,22 +1010,22 @@ struct TestShortDotsT {
// GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`, // GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`,
// hence they require padding to one vector. // hence they require padding to one vector.
const size_t padded_num = hwy::RoundUpTo(num, N); const size_t padded_num = hwy::RoundUpTo(num, N);
const size_t packed_num = CompressedArrayElements<Packed>(num); MatStorageT<float> raw_w("raw_w", padded_num);
RowVectorBatch<float> raw_w(allocator, Extents2D(1, padded_num)); MatStorageT<float> raw_v("raw_v", padded_num);
RowVectorBatch<float> raw_v(allocator, Extents2D(1, padded_num)); MatStorageT<Packed> weights("weights", padded_num);
RowVectorBatch<Packed> weights(allocator, Extents2D(1, packed_num)); const PackedSpan<Packed> w = weights.Span();
const PackedSpan<Packed> w(weights.Batch(0), packed_num); MatStorageT<T> vectors("vectors", padded_num);
RowVectorBatch<T> vectors(allocator, Extents2D(1, num)); const PackedSpan<T> v = vectors.Span();
const PackedSpan<T> v(vectors.Batch(0), num);
RowVectorBatch<double> bufs(allocator, Extents2D(1, num)); MatStorageT<double> bufs("bufs", num);
double* HWY_RESTRICT buf = bufs.Batch(0); double* HWY_RESTRICT buf = bufs.Packed();
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) { for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
GenerateWellConditionedInputs(num, raw_w.All(), rng, w, work); GenerateWellConditionedInputs(num, raw_w.Packed(), rng, w, work);
GenerateWellConditionedInputs(num, raw_v.All(), rng, v, work); GenerateWellConditionedInputs(num, raw_v.Packed(), rng, v, work);
const float dot_exact = ExactDot(raw_w.All(), raw_v.All(), num, buf); const float dot_exact =
ExactDot(raw_w.Packed(), raw_v.Packed(), num, buf);
float dots[kVariants]; float dots[kVariants];
for (size_t variant = 0; variant < kVariants; ++variant) { for (size_t variant = 0; variant < kVariants; ++variant) {
// Here Packed is not always float, so we must not call kDouble. // Here Packed is not always float, so we must not call kDouble.
@ -1106,7 +1106,6 @@ void TestAllDot() {
threading_args.max_lps = kMaxWorkers - 1; threading_args.max_lps = kMaxWorkers - 1;
ThreadingContext::SetArgs(threading_args); ThreadingContext::SetArgs(threading_args);
ThreadingContext& ctx = ThreadingContext::Get(); ThreadingContext& ctx = ThreadingContext::Get();
const Allocator& allocator = ctx.allocator;
{ // ensure no profiler zones are active { // ensure no profiler zones are active
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
@ -1118,16 +1117,17 @@ void TestAllDot() {
constexpr size_t kReps = hn::AdjustedReps(40); constexpr size_t kReps = hn::AdjustedReps(40);
const size_t num = 24 * 1024; const size_t num = 24 * 1024;
RowVectorBatch<float> a(allocator, Extents2D(kMaxWorkers, num)); MatStorageT<float> a("a", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
RowVectorBatch<float> b(allocator, Extents2D(kMaxWorkers, num)); MatStorageT<float> b("b", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
RowVectorBatch<double> bufs(allocator, Extents2D(kMaxWorkers, num)); MatStorageT<double> bufs("bufs", Extents2D(kMaxWorkers, num),
MatPadding::kOdd);
std::array<DotStats, kMaxWorkers> all_stats; std::array<DotStats, kMaxWorkers> all_stats;
ctx.pools.Cluster(0, 0).Run( ctx.pools.Cluster(0, 0).Run(
0, kReps, [&](const uint32_t rep, size_t thread) { 0, kReps, [&](const uint32_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Batch(thread); float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Batch(thread); float* HWY_RESTRICT pb = b.Row(thread);
double* HWY_RESTRICT buf = bufs.Batch(thread); double* HWY_RESTRICT buf = bufs.Row(thread);
const PackedSpan<const float> a_span(pa, num); const PackedSpan<const float> a_span(pa, num);
DotStats& stats = all_stats[thread]; DotStats& stats = all_stats[thread];
const double cond = const double cond =


@ -693,7 +693,6 @@ class MMScaleDemoteAdd {
// We manually unroll 2x for higher IPC in batch=1. // We manually unroll 2x for higher IPC in batch=1.
size_t col_c = range_nc.begin(); size_t col_c = range_nc.begin();
if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) { if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) {
HWY_UNROLL(1)
for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) { for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) {
VD a0, a1; // unused if !kAdd VD a0, a1; // unused if !kAdd
if constexpr (kAdd) { if constexpr (kAdd) {
@ -860,9 +859,8 @@ class MMScaleDemoteAdd {
class MMPerPackage { class MMPerPackage {
public: public:
template <typename TA> template <typename TA>
MMPerPackage(const ConstMat<TA>& A, const MMArgs& args, MMPerPackage(const MatPtrT<TA>& A, const MMArgs& args, const MMConfig& config,
const MMConfig& config, size_t pkg_idx, size_t pkg_idx, const IndexRange& range_np)
const IndexRange& range_np)
: args_(args), : args_(args),
pkg_idx_(pkg_idx), pkg_idx_(pkg_idx),
// May be overwritten with a view of A, if already BF16. // May be overwritten with a view of A, if already BF16.
@ -1114,12 +1112,12 @@ class MMPerPackage {
}); });
} }
// Decompresses all `M x K` from `A` into `pkg_A`. Assumes `TA` is a seekable // Decompresses all `M x K` from `A` into `A_`. Assumes `TA` is a seekable
// type (i.e., not NUQ) so we can use pointer arithmetic. // type (i.e., not NUQ) so we can use pointer arithmetic.
template <typename TA> template <typename TA>
HWY_NOINLINE void DoDecompressA(const ConstMat<TA>& A, MMParA par_a) const { HWY_NOINLINE void DoDecompressA(const MatPtrT<TA>& A, MMParA par_a) const {
const IndexRange all_M(0, A.extents.rows); const IndexRange all_M(0, A.Rows());
const IndexRange all_K(0, A.extents.cols); const IndexRange all_K(0, A.Cols());
HWY_DASSERT(all_K.Num() == A_.Cols()); HWY_DASSERT(all_K.Num() == A_.Cols());
const hn::ScalableTag<BF16> dbf; const hn::ScalableTag<BF16> dbf;
@ -1131,10 +1129,9 @@ class MMPerPackage {
const size_t col0 = range_K.begin(); const size_t col0 = range_K.begin();
const size_t cols = range_K.Num(); const size_t cols = range_K.Num();
// otherwise, padding overwrites neighbors // otherwise, padding overwrites neighbors
HWY_DASSERT(cols % NBF == 0 || cols == A.extents.cols); HWY_DASSERT(cols % NBF == 0 || cols == A.Cols());
for (size_t row_a : range_M) { for (size_t row_a : range_M) {
const PackedSpan<const TA> from = const PackedSpan<const TA> from = MakeSpan(A.Row(row_a) + col0, cols);
MakeSpan(A.ptr + A.Row(row_a) + col0, cols);
BF16* HWY_RESTRICT to = A_.Row(row_a) + col0; BF16* HWY_RESTRICT to = A_.Row(row_a) + col0;
DecompressAndZeroPad(dbf, from, 0, to, cols); DecompressAndZeroPad(dbf, from, 0, to, cols);
// Verify that we zero-padded. // Verify that we zero-padded.
@ -1174,18 +1171,14 @@ class MMPerPackage {
// Autotuning wrapper for `DoDecompressA`. // Autotuning wrapper for `DoDecompressA`.
template <typename TA> template <typename TA>
HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const { HWY_INLINE RowPtrBF DecompressA(const MatPtrT<TA>& A) const {
const Allocator& allocator = args_.env->ctx.allocator; const Allocator& allocator = args_.env->ctx.allocator;
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_]; MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
// If already BF16, maybe return a view: // If already BF16, maybe return a view:
if constexpr (hwy::IsSame<TA, BF16>()) { if constexpr (hwy::IsSame<TA, BF16>()) {
// Only if no zero-padding required. // Only if no zero-padding required.
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>()); const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
if (HWY_LIKELY(A.extents.cols % NBF == 0)) { if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(allocator, A);
const BF16* pos = A.ptr + A.Row(0);
return RowPtrBF(allocator, const_cast<BF16*>(pos), A.extents.cols,
A.Stride());
}
} }
if (HWY_LIKELY(autotune.Best())) { if (HWY_LIKELY(autotune.Best())) {
@ -1196,7 +1189,7 @@ class MMPerPackage {
// First call: generate candidates. // First call: generate candidates.
if (HWY_UNLIKELY(!autotune.HasCandidates())) { if (HWY_UNLIKELY(!autotune.HasCandidates())) {
std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4}; std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4};
if (A.extents.rows == 1) { if (A.Rows() == 1) {
candidates.push_back(MMParA::kNone); candidates.push_back(MMParA::kNone);
} else { } else {
candidates.push_back(MMParA::kM); candidates.push_back(MMParA::kM);
@ -1247,7 +1240,7 @@ class MMPerPackage {
const MMArgs args_; // copy for locality const MMArgs args_; // copy for locality
const size_t pkg_idx_; const size_t pkg_idx_;
RowPtrBF A_; // points into A or storage. RowPtrBF A_; // points into A or pkg_A.
const IndexRange range_np_; const IndexRange range_np_;
// From MMConfig: // From MMConfig:
@ -1276,9 +1269,8 @@ struct MMImpl {
// Called from `MatMul` from two places: either with the next autotune config, // Called from `MatMul` from two places: either with the next autotune config,
// or with the best config. // or with the best config.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC>
static HWY_NOINLINE void DoMatMul(const ConstMat<TA>& A, static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
const ConstMat<TB>& B, const RowPtr<TC>& C, const RowPtr<TC>& C, const MMArgs& args,
const MMArgs& args,
const MMConfig& config) { const MMConfig& config) {
MMZone matmul_zone; MMZone matmul_zone;
matmul_zone.MaybeEnter("MM.DoMatMul", args); matmul_zone.MaybeEnter("MM.DoMatMul", args);
@ -1313,7 +1305,7 @@ struct MMImpl {
// //
// Uses considerable stack space: at least 40 KiB per thread. // Uses considerable stack space: at least 40 KiB per thread.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B, HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env, const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<TC>& C) { const RowPtr<TC>& C) {
const Allocator& allocator = env.ctx.allocator; const Allocator& allocator = env.ctx.allocator;
@ -1340,7 +1332,7 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
MMPerKey& per_key = env.per_key[index]; MMPerKey& per_key = env.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune; MMAutoTune<MMConfig>& tuner = per_key.autotune;
const MMArgs args(env, per_key, static_cast<double>(A.scale) * B.scale, add, const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.scale, add,
env.storage.Partial()); env.storage.Partial());
if (HWY_LIKELY(tuner.Best())) { if (HWY_LIKELY(tuner.Best())) {
MMImpl::DoMatMul(A, B, C, args, *tuner.Best()); MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
@ -1396,6 +1388,13 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
return &per_key; return &per_key;
} }
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<TC>& C) {
return MatMul(A, ConstMat<TB>(B), add, env, C);
}
// NOLINTNEXTLINE(google-readability-namespace-comments) // NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
} // namespace gcpp } // namespace gcpp

View File

@ -21,6 +21,7 @@
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <memory> // std::unique_ptr
#include <vector> #include <vector>
// IWYU pragma: begin_exports // IWYU pragma: begin_exports
@ -207,24 +208,28 @@ class MMStorage {
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
static constexpr size_t kMaxKC = 8 * 1024; static constexpr size_t kMaxKC = 8 * 1024;
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`).
MMStorage(const Allocator& allocator, MMParallel& parallel) MMStorage(const Allocator& allocator, MMParallel& parallel)
// Per-worker copies of `partial` would be wasteful. We instead allocate // Per-worker copies of `partial` would be wasteful. We instead allocate
// one instance of the maximum matrix extents because threads write at // one instance of the maximum matrix extents because threads write at
// false-sharing-free granularity. // false-sharing-free granularity.
: partial_storage_( : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
AllocateAlignedRows<double>(allocator, Extents2D(kMaxM, kMaxN))), MatPadding::kOdd),
// Same stride independent of the actual C.Cols() so we can pre-bind. // Same stride independent of the actual C.Cols() so we can pre-bind.
partial_(allocator, partial_storage_.All(), kMaxN, partial_(allocator, partial_storage_.Row(0), kMaxN,
StrideForCyclicOffsets(kMaxN, allocator.Quantum<double>())) { partial_storage_.Stride()) {
// Per-package allocation so each can decompress A into its own copy. // Per-package allocation so each can decompress A into its own copy.
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) { parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
pkg_A_[pkg_idx] = pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
AllocateAlignedRows<BF16>(allocator, Extents2D(kMaxM, kMaxK)); "pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd));
if (allocator.ShouldBind()) { if (allocator.ShouldBind()) {
const size_t node = parallel.Node(pkg_idx); const size_t node = parallel.Node(pkg_idx);
if (!allocator.BindMemory(pkg_A_[pkg_idx].All(), size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
pkg_A_[pkg_idx].NumBytes(), node)) { pkg_A_[pkg_idx]->ElementBytes();
bytes = hwy::RoundDownTo(bytes, allocator.QuantumBytes());
if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) {
HWY_WARN("Failed to bind memory for package %zu", pkg_idx); HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
} }
} }
@ -234,22 +239,20 @@ class MMStorage {
BindC(allocator, kMaxM, partial_, parallel); BindC(allocator, kMaxM, partial_, parallel);
} }
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is // Returns per-package matrix view.
// non-const, because `RowPtr` requires a non-const pointer.
RowPtrBF A(const Allocator& allocator, size_t pkg_idx, RowPtrBF A(const Allocator& allocator, size_t pkg_idx,
const Extents2D& extents) { const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK); HWY_DASSERT(extents.cols <= kMaxK);
const size_t stride = return RowPtrBF(allocator, const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
StrideForCyclicOffsets(extents.cols, allocator.Quantum<BF16>()); extents.cols, pkg_A_[pkg_idx]->Stride());
return RowPtrBF(allocator, pkg_A_[pkg_idx].All(), extents.cols, stride);
} }
RowPtrD Partial() const { return partial_; } RowPtrD Partial() const { return partial_; }
private: private:
RowVectorBatch<BF16> pkg_A_[MMParallel::kMaxPackages]; std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
RowVectorBatch<double> partial_storage_; MatStorageT<double> partial_storage_;
RowPtrD partial_; RowPtrD partial_;
}; };
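In the MMStorage constructor above, each package's copy of A is NUMA-bound via BindMemory, with the byte count rounded down to the allocator quantum first. A standalone sketch of just that size computation; the rows/stride/element-size values are made up here rather than taken from the real MatStorageT accessors:

#include <cassert>
#include <cstddef>
#include <cstdio>

// Round x down to a multiple of `quantum` (assumed to be a power of two),
// mirroring the hwy::RoundDownTo call used above.
static size_t RoundDownTo(size_t x, size_t quantum) {
  return x & ~(quantum - 1);
}

int main() {
  // Hypothetical buffer shape; the real values come from pkg_A_[pkg_idx].
  const size_t rows = 512, stride = 4104, element_bytes = 2;  // BF16
  const size_t quantum_bytes = 4096;                          // page-sized quantum

  size_t bytes = rows * stride * element_bytes;
  bytes = RoundDownTo(bytes, quantum_bytes);
  assert(bytes % quantum_bytes == 0);
  // The rounded-down size is what would be handed to BindMemory.
  std::printf("bind %zu bytes\n", bytes);
  return 0;
}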
@ -608,6 +611,8 @@ struct MMPerKey {
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
// `MatMulEnv`. // `MatMulEnv`.
struct MatMulEnv { struct MatMulEnv {
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext`.
explicit MatMulEnv(ThreadingContext& ctx); explicit MatMulEnv(ThreadingContext& ctx);
ThreadingContext& ctx; ThreadingContext& ctx;
@ -679,8 +684,8 @@ struct MMZone {
#endif // PROFILER_ENABLED #endif // PROFILER_ENABLED
// Used for the A and B arguments of `MatMul`, which are always const. // Used for the A and B arguments of `MatMul`, which are always const.
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the // This differs from `RowPtr` in supporting the `ofs` required for compressed T.
// `ofs` required for compressed T. // TODO: remove after splitting W1/W2 and updating QDotK to RowPtr.
template <typename T> template <typename T>
struct ConstMat { struct ConstMat {
ConstMat() = default; ConstMat() = default;
@ -689,6 +694,12 @@ struct ConstMat {
HWY_DASSERT(ptr != nullptr); HWY_DASSERT(ptr != nullptr);
HWY_DASSERT(stride >= extents.cols); HWY_DASSERT(stride >= extents.cols);
} }
// Non-explicit so that we can pass `MatPtr` directly to MatMul.
ConstMat(const MatPtrT<T>& m)
: ConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride()) {
scale = m.Scale();
}
size_t Row(size_t r) const { size_t Row(size_t r) const {
if constexpr (HWY_IS_DEBUG_BUILD) { if constexpr (HWY_IS_DEBUG_BUILD) {
if (r >= extents.rows) { if (r >= extents.rows) {
@ -727,31 +738,6 @@ struct ConstMat {
size_t ofs; size_t ofs;
}; };
// For deducing T.
template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
size_t stride) {
return ConstMat<T>(ptr, extents, stride);
}
// For A argument to MatMul (activations).
template <typename T>
ConstMat<T> ConstMatFromBatch(size_t batch_size,
const RowVectorBatch<T>& row_vectors) {
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
Extents2D(batch_size, row_vectors.Cols()),
row_vectors.Stride());
}
template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
ConstMat<T> mat =
MakeConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride());
mat.scale = m.Scale();
return mat;
}
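The non-explicit ConstMat(const MatPtrT<T>&) constructor added above is what makes the MakeConstMat / ConstMatFromBatch / ConstMatFromWeights helpers removable: a MatPtrT can now be passed directly wherever a ConstMat is expected. A toy standalone illustration of that pattern; ToyMat and ToyConstView are illustrative stand-ins, not the real gemma.cpp types:

#include <cstddef>
#include <vector>

struct ToyMat {                 // stand-in for MatPtrT<T>/MatStorageT<T>
  size_t rows, cols, stride;
  float scale = 1.0f;
  std::vector<float> data;
  const float* Row(size_t r) const { return data.data() + r * stride; }
};

struct ToyConstView {           // stand-in for ConstMat<T>
  const float* ptr;
  size_t rows, cols, stride;
  float scale;
  // Non-explicit, so a ToyMat can be passed wherever a view is expected.
  ToyConstView(const ToyMat& m)
      : ptr(m.Row(0)), rows(m.rows), cols(m.cols), stride(m.stride),
        scale(m.scale) {}
};

float SumFirstRow(const ToyConstView& v) {
  float sum = 0.0f;
  for (size_t c = 0; c < v.cols; ++c) sum += v.ptr[c];
  return sum;
}

int main() {
  ToyMat a{2, 3, 4, 1.0f, std::vector<float>(8, 1.0f)};
  // Implicit conversion: no MakeConstMat-style factory function needed.
  return SumFirstRow(a) == 3.0f ? 0 : 1;
}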
template <typename TB> template <typename TB>
void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC, void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
const ConstMat<TB>& B, MMParallel& parallel) { const ConstMat<TB>& B, MMParallel& parallel) {

View File

@ -57,10 +57,10 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
// Returns 1-norm, used for estimating tolerable numerical differences. // Returns 1-norm, used for estimating tolerable numerical differences.
double MaxRowAbsSum(const RowVectorBatch<float>& a) { double MaxRowAbsSum(const MatStorageT<float>& a) {
double max_row_abs_sum = 0.0; double max_row_abs_sum = 0.0;
for (size_t r = 0; r < a.BatchSize(); r++) { for (size_t r = 0; r < a.Rows(); r++) {
const float* row = a.Batch(r); const float* row = a.Row(r);
double row_abs_sum = 0.0; double row_abs_sum = 0.0;
for (size_t c = 0; c < a.Cols(); c++) { for (size_t c = 0; c < a.Cols(); c++) {
row_abs_sum += hwy::ScalarAbs(row[c]); row_abs_sum += hwy::ScalarAbs(row[c]);
@ -71,11 +71,11 @@ double MaxRowAbsSum(const RowVectorBatch<float>& a) {
} }
// Returns the maximum absolute value of `a`. // Returns the maximum absolute value of `a`.
float MaxAbs(const RowVectorBatch<float>& a) { float MaxAbs(const MatStorageT<float>& a) {
float max_abs = 0.0f; float max_abs = 0.0f;
for (size_t c = 0; c < a.Cols(); c++) { for (size_t c = 0; c < a.Cols(); c++) {
for (size_t r = 0; r < a.BatchSize(); r++) { for (size_t r = 0; r < a.Rows(); r++) {
const float* row = a.Batch(r); const float* row = a.Row(r);
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c])); max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
} }
} }
@ -84,33 +84,29 @@ float MaxAbs(const RowVectorBatch<float>& a) {
// B is already transposed. // B is already transposed.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B, void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) { const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
const Allocator& allocator = ThreadingContext::Get().allocator;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
const size_t cols = A.extents.cols; const size_t cols = A.Cols();
const size_t B_rows = B.extents.rows; const size_t B_rows = B.Rows();
// Round up for DecompressAndZeroPad. // Round up for DecompressAndZeroPad.
RowVectorBatch<float> a_batch = MatStorageT<float> a_batch("a_batch", A.Extents(), MatPadding::kOdd);
AllocateAlignedRows<float>(allocator, A.extents); MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(),
RowVectorBatch<float> b_trans_batch = MatPadding::kOdd);
AllocateAlignedRows<float>(allocator, B.extents); MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
RowVectorBatch<float> c_batch = MatPadding::kOdd);
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows)); MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
RowVectorBatch<float> c_slow_batch = MatPadding::kOdd);
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows)); for (size_t m = 0; m < A.Rows(); ++m) {
HWY_ASSERT(A.ofs == 0 && B.ofs == 0); DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
for (size_t m = 0; m < A.extents.rows; ++m) { DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
a_batch.Batch(m), cols);
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m),
B_rows); B_rows);
DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0, DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
c_slow_batch.Batch(m), B_rows); c_slow_batch.Row(m), B_rows);
} }
for (size_t n = 0; n < B_rows; ++n) { for (size_t n = 0; n < B_rows; ++n) {
DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0, DecompressAndZeroPad(df, MakeSpan(B.Row(n), cols), 0, b_trans_batch.Row(n),
b_trans_batch.Batch(n), cols); cols);
} }
// MatMul rounds inputs to BF16, so error is proportional to the max input // MatMul rounds inputs to BF16, so error is proportional to the max input
@ -130,10 +126,10 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
} }
const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>()); const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
for (size_t r = 0; r < A.extents.rows; r++) { for (size_t r = 0; r < A.Rows(); r++) {
const float* expected_row = c_slow_batch.Batch(r); const float* expected_row = c_slow_batch.Row(r);
const float* actual_row = c_batch.Batch(r); const float* actual_row = c_batch.Row(r);
for (size_t c = 0; c < B.extents.rows; c++) { for (size_t c = 0; c < B.Rows(); c++) {
const double expected_value = static_cast<double>(expected_row[c]); const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]); const double actual_value = static_cast<double>(actual_row[c]);
const bool in_range = expected_value - tolerance <= actual_value && const bool in_range = expected_value - tolerance <= actual_value &&
@ -157,18 +153,17 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
// B is already transposed. // B is already transposed.
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC>
HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B, HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const float* HWY_RESTRICT add_row, MatMulEnv& env, const float* HWY_RESTRICT add_row, MatMulEnv& env,
const RowPtr<TC>& C) { const RowPtr<TC>& C) {
// TA can be any Packed except NuqStream because it uses pointer // TA can be any Packed except NuqStream because it uses pointer
// arithmetic, because it is the second argument to Dot, which does not // arithmetic, because it is the second argument to Dot, which does not
// support a v_ofs. // support a v_ofs.
static_assert(sizeof(TA) >= sizeof(BF16), "A matrix must be BF16/f32"); static_assert(sizeof(TA) >= sizeof(BF16), "A matrix must be BF16/f32");
const float scale = A.scale * B.scale; const float scale = A.Scale() * B.Scale();
const hn::ScalableTag<float> df; // lane type is ignored const hn::ScalableTag<float> df; // lane type is ignored
const PackedSpan<const TB> b_span = const PackedSpan<const TB> b_span = B.Span();
MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows);
const IndexRange all_rows_c(0, A.Extents().rows); const IndexRange all_rows_c(0, A.Extents().rows);
const IndexRange all_cols_c(0, C.Cols()); const IndexRange all_cols_c(0, C.Cols());
@ -191,8 +186,8 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
for (size_t c : cols_c) { for (size_t c : cols_c) {
const float add = add_row ? add_row[c] : 0.0f; const float add = add_row ? add_row[c] : 0.0f;
C_row[c] = hwy::ConvertScalarTo<TC>( C_row[c] = hwy::ConvertScalarTo<TC>(
add + scale * Dot(df, b_span, c * B.Stride(), add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r),
A.ptr + A.Row(r), A.extents.cols)); A.Cols()));
} }
} }
}); });
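MatMulSlow above is the reference path: B is stored transposed, so each output element is the dot product of a row of A with a row of B, scaled and offset by the optional add row. A standalone scalar sketch of the same computation with plain float vectors in place of PackedSpan/RowPtr:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// C[r][c] = add[c] + scale * dot(A row r, B row c), where B is already
// transposed so that "row c of B" is column c of the logical B matrix.
void MatMulSlowRef(const std::vector<float>& A, const std::vector<float>& Bt,
                   const float* add, float scale, size_t M, size_t K, size_t N,
                   std::vector<float>& C) {
  for (size_t r = 0; r < M; ++r) {
    for (size_t c = 0; c < N; ++c) {
      float dot = 0.0f;
      for (size_t k = 0; k < K; ++k) dot += A[r * K + k] * Bt[c * K + k];
      C[r * N + c] = (add ? add[c] : 0.0f) + scale * dot;
    }
  }
}

int main() {
  const size_t M = 2, K = 3, N = 2;
  const std::vector<float> A = {1, 2, 3, 4, 5, 6};   // 2x3
  const std::vector<float> Bt = {1, 0, 1, 0, 1, 0};  // 2x3 (transposed B)
  const float add[2] = {10.0f, 20.0f};
  std::vector<float> C(M * N, 0.0f);
  MatMulSlowRef(A, Bt, add, 1.0f, M, K, N, C);
  assert(std::fabs(C[0] - 14.0f) < 1e-6f);  // 10 + (1 + 0 + 3)
  assert(std::fabs(C[1] - 22.0f) < 1e-6f);  // 20 + (0 + 2 + 0)
  return 0;
}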
@ -225,26 +220,23 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool)); MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool));
MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool)); MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool));
RowVectorBatch<TC> c_slow_batch = MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
AllocateAlignedRows<TC>(allocator, C_extents); MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
MatStorageT<float> add_storage = MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), pool) add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
: MatStorageT<float>("add", Extents2D(), MatPadding::kPacked); : MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
add_storage.SetScale(1.0f); add_storage.SetScale(1.0f);
const auto A = ConstMatFromWeights(a);
const auto B = ConstMatFromWeights(b_trans);
const float* add_row = add ? add_storage.PackedScale1() : nullptr; const float* add_row = add ? add_storage.PackedScale1() : nullptr;
const RowPtr<TC> C_slow = RowPtrFromBatch(allocator, c_slow_batch); const RowPtr<TC> C_slow = RowPtrFromMat(allocator, c_slow_batch);
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch); const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
MatMulSlow(A, B, add_row, env, C_slow); MatMulSlow(a, b_trans, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths. // A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) { for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMul(A, B, add_row, env, C); MMPerKey* per_key = MatMul(a, b_trans, add_row, env, C);
AssertClose(A, B, C_slow, C, line); AssertClose(a, b_trans, C_slow, C, line);
if (per_key->autotune.Best()) break; if (per_key->autotune.Best()) break;
} }
} }

View File

@ -189,10 +189,11 @@ float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
} // namespace detail } // namespace detail
template <typename VecT, typename WeightT, typename OutT> // `x_ofs` is the offset within `x`, required for NuqStream.
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x, template <typename XT, typename WT, typename OT>
const WeightT* HWY_RESTRICT weight, HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
OutT* HWY_RESTRICT out, const WT* HWY_RESTRICT weight,
size_t w_ofs, OT* HWY_RESTRICT out,
const size_t size) { const size_t size) {
PROFILER_FUNC; PROFILER_FUNC;
@ -203,17 +204,17 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
const VF mul = hn::Set(df, detail::RMSNormMul(x, size)); const VF mul = hn::Set(df, detail::RMSNormMul(x, size));
const auto packed_w = MakeSpan(weight, size); const auto packed_x = MakeSpan(x, size);
const auto packed_v = MakeSpan(x, size); const auto packed_w = MakeSpan(weight, w_ofs + size);
const auto packed_out = MakeSpan(out, size); const auto packed_out = MakeSpan(out, size);
HWY_DASSERT(size % (2 * MaxLanes(df)) == 0); HWY_DASSERT(size % (2 * NF) == 0);
for (size_t i = 0; i < size; i += 2 * NF) { for (size_t i = 0; i < size; i += 2 * NF) {
VF v0, v1, w0, w1; VF x0, x1, w0, w1;
Decompress2(df, packed_v, i, v0, v1); Decompress2(df, packed_x, i, x0, x1);
Decompress2(df, packed_w, i, w0, w1); Decompress2(df, packed_w, w_ofs + i, w0, w1);
const VF m0 = hn::Mul(mul, v0); const VF m0 = hn::Mul(mul, x0);
const VF m1 = hn::Mul(mul, v1); const VF m1 = hn::Mul(mul, x1);
// (1+weight) * m = m + weight*m = one FMA. // (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0); const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1); const VF out1 = hn::MulAdd(m1, w1, m1);
@ -222,10 +223,11 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
} }
// Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer. // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer.
template <typename WeightT, typename VecT> template <typename WT, typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
const WeightT* HWY_RESTRICT weight, VecT* HWY_RESTRICT inout, size_t w_ofs,
const size_t size) { XT* HWY_RESTRICT inout,
const size_t size) {
PROFILER_FUNC; PROFILER_FUNC;
namespace hn = hwy::HWY_NAMESPACE; namespace hn = hwy::HWY_NAMESPACE;
@ -235,72 +237,112 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size)); const VF mul = hn::Set(df, detail::RMSNormMul(inout, size));
const auto packed_w = MakeSpan(weight, size); const auto packed_w = MakeSpan(weight, w_ofs + size);
const auto packed_v = MakeSpan(inout, size); const auto packed_x = MakeSpan(inout, size);
HWY_DASSERT(size % (2 * MaxLanes(df)) == 0); HWY_DASSERT(size % (2 * NF) == 0);
for (size_t i = 0; i < size; i += 2 * NF) { for (size_t i = 0; i < size; i += 2 * NF) {
VF v0, v1, w0, w1; VF x0, x1, w0, w1;
Decompress2(df, MakeConst(packed_v), i, v0, v1); Decompress2(df, packed_x, i, x0, x1);
Decompress2(df, packed_w, i, w0, w1); Decompress2(df, packed_w, w_ofs + i, w0, w1);
const VF m0 = hn::Mul(mul, v0); const VF m0 = hn::Mul(mul, x0);
const VF m1 = hn::Mul(mul, v1); const VF m1 = hn::Mul(mul, x1);
// (1+weight) * m = m + weight*m = one FMA. // (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0); const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1); const VF out1 = hn::MulAdd(m1, w1, m1);
Compress2(df, out0, out1, packed_v, i); Compress2(df, out0, out1, packed_x, i);
} }
} }
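Both RMSNorm kernels above apply the weight via the identity (1 + w) * m = m + w * m, i.e. a single FMA per element. A standalone scalar sketch of the same math (matching the ScalarRMSNorm reference used in the tests; the w_ofs parameter is omitted for brevity):

#include <cmath>
#include <cstddef>

// out[i] = (1 + weight[i]) * x[i] * rsqrt(mean(x^2) + eps).
// The vectorized code computes (1 + w) * m as fma(m, w, m).
void ScalarRMSNormSketch(const float* x, const float* weight, float* out,
                         size_t size) {
  const float kEps = 1e-6f;
  float ss = 0.0f;
  for (size_t i = 0; i < size; ++i) ss += x[i] * x[i];
  const float mul = 1.0f / std::sqrt(ss / static_cast<float>(size) + kEps);
  for (size_t i = 0; i < size; ++i) {
    const float m = mul * x[i];
    out[i] = m + weight[i] * m;  // == (1 + weight[i]) * m, one FMA
  }
}

int main() {
  const float x[4] = {1.0f, -1.0f, 1.0f, -1.0f};
  const float w[4] = {0.0f, 0.0f, 1.0f, 1.0f};
  float out[4];
  ScalarRMSNormSketch(x, w, out, 4);
  // mean(x^2) == 1, so mul ~= 1: out ~= {1, -1, 2, -2}.
  return (out[2] > 1.99f && out[2] < 2.01f) ? 0 : 1;
}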
// Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm. // Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm.
template <typename T> template <typename XT>
HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, T& mu, HWY_NOINLINE void ComputeMoments(const XT* HWY_RESTRICT x, size_t size,
T& mu2) { double& mu, double& mu2) {
HWY_ASSERT(size > 0); HWY_ASSERT(size > 0);
double sum = 0.0; const hn::ScalableTag<float> df;
double sum2 = 0.0;
for (size_t i = 0; i < size; ++i) { // Use the existing Sum and Dot kernels for simplicity. The second pass
const float f = hwy::ConvertScalarTo<float>(a[i]); // is likely not too expensive because it will be in L1.
sum += f; const double sum = Sum(df, x, size);
sum2 += f * f; // We only have one array, so calling `DecompressAndCall` instead of `Dot``
} // avoids loading the 'second' vector again.
} // We only have one array, so calling `DecompressAndCall` instead of `Dot`
mu = sum / size; const double sum2 =
mu2 = sum2 / size; DecompressAndCall(df, MakeSpan(x, size), DotKernelDouble());
const double inv_size = 1.0 / static_cast<double>(size);
mu = sum * inv_size;
mu2 = sum2 * inv_size;
} }
// Compare py/flax/linen/normalization.py. // Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias // out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename VecT, typename WeightT, typename OutT> // x and out may be the same.
HWY_NOINLINE void ScalarLayerNorm(const VecT* x, template <typename XT, typename WT, typename OT>
const WeightT* HWY_RESTRICT scale, HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
const WeightT* HWY_RESTRICT bias, const WT* HWY_RESTRICT bias, OT* out, size_t size) {
OutT* out,
size_t size) {
constexpr float kEps = 1e-6f;
VecT mu, mu2;
ScalarMus(x, size, mu, mu2);
VecT var = mu2 - mu * mu;
VecT zero = 0.0f;
var = HWY_MAX(var, zero);
var = 1.0f / sqrtf(var + kEps);
for (size_t j = 0; j < size; j++) {
const float v = hwy::ConvertScalarTo<float>(x[j]);
const float s = hwy::ConvertScalarTo<float>(scale[j]);
const float b = hwy::ConvertScalarTo<float>(bias[j]);
out[j] = hwy::ConvertScalarTo<OutT>((v - mu) * s * var + b);
}
}
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE HWY_MAYBE_UNUSED void LayerNorm(const VecT* x,
const WeightT* HWY_RESTRICT weight,
const WeightT* HWY_RESTRICT bias,
OutT* out,
const size_t size) {
PROFILER_FUNC; PROFILER_FUNC;
// For now we only delegate to the scalar version.
// TODO: implement vectorized version. namespace hn = hwy::HWY_NAMESPACE;
ScalarLayerNorm(x, weight, bias, out, size); const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
double mu, mu2;
ComputeMoments(x, size, mu, mu2);
double var = mu2 - mu * mu;
var = HWY_MAX(var, 0.0);
var = 1.0 / sqrt(var + 1E-6);
const VF vmu = hn::Set(df, static_cast<float>(mu));
const VF vvar = hn::Set(df, static_cast<float>(var));
const VF* HWY_RESTRICT pmu = &vmu;
const VF* HWY_RESTRICT pvar = &vvar;
const auto packed_x = MakeSpan(x, size);
const auto packed_scale = MakeSpan(scale, size);
const auto packed_bias = MakeSpan(bias, size);
const auto packed_out = MakeSpan(out, size);
// Loop body for one vector, called from main loop and remainder loop.
const auto norm = [pmu, pvar](VF x, VF s, VF add) HWY_ATTR -> VF {
const VF centered = hn::Sub(x, *pmu);
const VF mul = hn::Mul(s, *pvar);
return hn::MulAdd(centered, mul, add);
};
size_t i = 0;
if (size >= 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
VF x0, x1, s0, s1, add0, add1;
Decompress2(df, packed_x, i, x0, x1);
Decompress2(df, packed_scale, i, s0, s1);
Decompress2(df, packed_bias, i, add0, add1);
const VF n0 = norm(x0, s0, add0);
const VF n1 = norm(x1, s1, add1);
Compress2(df, n0, n1, packed_out, i);
}
}
const size_t remaining = size - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)];
HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)];
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining);
DecompressAndZeroPad(df, packed_bias, i, buf_bias, remaining);
const VF x0 = hn::Load(df, buf_x);
const VF x1 = hn::Load(df, buf_x + NF);
const VF s0 = hn::Load(df, buf_scale);
const VF s1 = hn::Load(df, buf_scale + NF);
const VF add0 = hn::Load(df, buf_bias);
const VF add1 = hn::Load(df, buf_bias + NF);
const VF n0 = norm(x0, s0, add0);
const VF n1 = norm(x1, s1, add1);
Compress2(df, n0, n1, MakeSpan(buf_out, 2 * NF), 0);
hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
}
} }
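The tail of the vectorized LayerNorm decompresses the final partial block into zero-padded stack buffers, reuses the same two-vector loop body, and copies back only the valid elements. Below is a scalar model of that pattern; kNF is a made-up stand-in for the SIMD lane count hn::Lanes(df), so this is a sketch of the structure, not the actual Highway code:

#include <cstddef>
#include <cstring>
#include <vector>

// Process `size` floats in blocks of 2*kNF, handling the tail by copying it
// into a zero-padded buffer, running the same block body, and writing back
// only `remaining` results -- mirroring DecompressAndZeroPad + CopyBytes.
constexpr size_t kNF = 4;  // stand-in for the SIMD lane count

void ScaleBlocks(const float* x, float* out, size_t size, float factor) {
  auto block = [&](const float* in, float* dst) {
    for (size_t j = 0; j < 2 * kNF; ++j) dst[j] = in[j] * factor;
  };
  size_t i = 0;
  if (size >= 2 * kNF) {
    for (; i <= size - 2 * kNF; i += 2 * kNF) block(x + i, out + i);
  }
  const size_t remaining = size - i;
  if (remaining != 0) {
    float buf_in[2 * kNF] = {0};   // zero-padded, like DecompressAndZeroPad
    float buf_out[2 * kNF];
    std::memcpy(buf_in, x + i, remaining * sizeof(float));
    block(buf_in, buf_out);
    std::memcpy(out + i, buf_out, remaining * sizeof(float));
  }
}

int main() {
  std::vector<float> x(11, 1.0f), out(11, 0.0f);
  ScaleBlocks(x.data(), out.data(), x.size(), 3.0f);
  return (out[7] == 3.0f && out[10] == 3.0f) ? 0 : 1;
}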
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
@ -447,39 +489,56 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
} }
// Simple loops unless/until batch sizes are large enough to parallelize. // Simple loops unless/until batch sizes are large enough to parallelize.
template <typename WeightT, typename OutT> template <typename XT, typename OT>
void RMSNormBatched(size_t num_tokens, const float* activations, void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
const WeightT* weights, OutT* out, const size_t model_dim) { MatPtrT<OT>& out) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { HWY_DASSERT(weights.Rows() == 1);
RMSNorm(activations + token_idx * model_dim, weights, HWY_DASSERT(weights.Cols() == activations.Cols());
out + token_idx * model_dim, model_dim); HWY_DASSERT(activations.SameShape(out));
}
CallUpcasted(&weights, [&](const auto* weights_t) {
for (size_t token_idx = 0; token_idx < activations.Rows(); ++token_idx) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
out.Row(token_idx), activations.Cols());
}
});
} }
// TODO: pass RowVectorBatch argument. template <typename XT>
template <typename WeightT, typename InOutT> void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout) {
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights, HWY_DASSERT(weights.Rows() == 1);
InOutT* inout, const size_t model_dim) { HWY_DASSERT(weights.Cols() == inout.Cols());
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim); CallUpcasted(&weights, [&](const auto* weights_t) {
} for (size_t token_idx = 0; token_idx < inout.Rows(); ++token_idx) {
RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
inout.Cols());
}
});
} }
template <typename VecT, typename WeightT, typename OutT> // x and out may be the same.
void LayerNormBatched(size_t num_tokens, const VecT* x, template <typename XT, typename OT>
const WeightT* HWY_RESTRICT weight, void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
const WeightT* HWY_RESTRICT bias, OutT* out, const MatPtr& bias, MatPtrT<OT>& out) {
const size_t size) { HWY_DASSERT(weight.Cols() == bias.Cols());
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { HWY_DASSERT(weight.Cols() == x.Cols());
LayerNorm(x + token_idx * size, weight, bias, out + token_idx * size, size); HWY_DASSERT(x.SameShape(out));
}
CallUpcastedSame(
&weight, &bias, [&](const auto* weight_t, const auto* bias_t) {
for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
LayerNorm(x.Row(token_idx), weight_t->PackedScale1(),
bias_t->PackedScale1(), out.Row(token_idx), x.Cols());
}
});
} }
static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other, static HWY_INLINE void AddFromBatched(const MatPtrT<float>& other,
float* x, const size_t model_dim) { MatPtrT<float>& x) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { HWY_DASSERT(x.SameShape(other));
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim, for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
model_dim); AddFrom(other.Row(token_idx), x.Row(token_idx), x.Cols());
} }
} }
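The batched wrappers now take whole matrix views and derive the token count and row width from Rows()/Cols() instead of separate num_tokens/model_dim arguments. A toy standalone version of the same row-wise pattern; RowsView is an illustrative stand-in for MatPtrT<float>:

#include <cassert>
#include <cstddef>
#include <vector>

struct RowsView {                    // stand-in for MatPtrT<float>
  size_t rows, cols, stride;
  float* data;
  float* Row(size_t r) const { return data + r * stride; }
};

// Adds `other` to `x`, row by row, like AddFromBatched above.
void AddFromBatchedSketch(const RowsView& other, const RowsView& x) {
  assert(other.rows == x.rows && other.cols == x.cols);  // SameShape
  for (size_t r = 0; r < x.rows; ++r) {
    for (size_t c = 0; c < x.cols; ++c) x.Row(r)[c] += other.Row(r)[c];
  }
}

int main() {
  std::vector<float> a(2 * 4, 1.0f), b(2 * 4, 2.0f);
  RowsView va{2, 3, 4, a.data()}, vb{2, 3, 4, b.data()};
  AddFromBatchedSketch(va, vb);
  return vb.Row(1)[2] == 3.0f ? 0 : 1;
}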
@ -743,8 +802,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
HWY_ASSERT(k != 0); HWY_ASSERT(k != 0);
HWY_ASSERT(k <= vocab_size); HWY_ASSERT(k <= vocab_size);
std::vector<double> packed_token_probs; std::vector<double> packed_token_probs;
for (int32_t i = 0; i < vocab_size; ++i) { for (int32_t i = 0; i < static_cast<int32_t>(vocab_size); ++i) {
if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) { if (accept_token && !accept_token(i, probabilities[i])) {
continue; continue;
} }
packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i])); packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
@ -756,7 +815,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
std::vector<TokenAndProb> token_probs; std::vector<TokenAndProb> token_probs;
token_probs.reserve(k); token_probs.reserve(k);
for (int32_t i = 0; i < k; ++i) { for (int32_t i = 0; i < static_cast<int32_t>(k); ++i) {
token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i])); token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
} }
return token_probs; return token_probs;
@ -770,7 +829,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
TopK(probabilities, vocab_size, k, accept_token); TopK(probabilities, vocab_size, k, accept_token);
std::vector<int> topk_indices(k); std::vector<int> topk_indices(k);
std::vector<float> topk_probs(k); std::vector<float> topk_probs(k);
for (int i = 0; i < k; ++i) { for (size_t i = 0; i < k; ++i) {
topk_indices[i] = token_probs[i].token; topk_indices[i] = token_probs[i].token;
topk_probs[i] = token_probs[i].prob; topk_probs[i] = token_probs[i].prob;
} }
@ -788,7 +847,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
TopK(logits, vocab_size, k, accept_token); TopK(logits, vocab_size, k, accept_token);
std::vector<int> topk_indices(k); std::vector<int> topk_indices(k);
std::vector<float> topk_logits(k); std::vector<float> topk_logits(k);
for (int i = 0; i < token_logits.size(); ++i) { for (size_t i = 0; i < token_logits.size(); ++i) {
topk_indices[i] = token_logits[i].token; topk_indices[i] = token_logits[i].token;
topk_logits[i] = token_logits[i].prob; topk_logits[i] = token_logits[i].prob;
} }
@ -807,20 +866,20 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
// Input has 4096 (64*64) rows, output has 256 (16*16) rows // Input has 4096 (64*64) rows, output has 256 (16*16) rows
// Each output row is the average of a 4x4 block of input rows // Each output row is the average of a 4x4 block of input rows
template <typename T> template <typename T>
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) { MatStorageT<T> AvgPool4x4(MatStorageT<T>& input) {
const Allocator& allocator = ThreadingContext::Get().allocator;
const Extents2D extents = input.Extents(); const Extents2D extents = input.Extents();
// Input validation // Input validation
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows
// Create output with 256 rows and same number of columns // Create output with 256 rows and same number of columns
const size_t out_rows = 256; // 16 * 16 = 256 output rows const size_t out_rows = 256; // 16 * 16 = 256 output rows
RowVectorBatch<T> result(allocator, Extents2D(out_rows, extents.cols)); MatStorageT<T> result("pool4x4", Extents2D(out_rows, extents.cols),
MatPadding::kOdd);
const size_t input_dim = 64; // Input is 64×64 const size_t input_dim = 64; // Input is 64×64
const size_t output_dim = 16; // Output is 16×16 const size_t output_dim = 16; // Output is 16×16
for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) { for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) { for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) {
size_t out_idx = out_row_idx * output_dim + out_col_idx; size_t out_idx = out_row_idx * output_dim + out_col_idx;
T* output_row = result.Batch(out_idx); T* output_row = result.Row(out_idx);
// Initialize output row to zeros // Initialize output row to zeros
std::fill(output_row, output_row + extents.cols, 0); std::fill(output_row, output_row + extents.cols, 0);
// Average 16 row vectors from a 4x4 block // Average 16 row vectors from a 4x4 block
@ -829,9 +888,9 @@ RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
size_t in_row_idx = out_row_idx * 4 + i; size_t in_row_idx = out_row_idx * 4 + i;
size_t in_col_idx = out_col_idx * 4 + j; size_t in_col_idx = out_col_idx * 4 + j;
size_t in_idx = in_row_idx * input_dim + in_col_idx; size_t in_idx = in_row_idx * input_dim + in_col_idx;
const T* input_row = input.Batch(in_idx); const T* input_row = input.Row(in_idx);
// Add each input row to the output // Add each input row to the output
// TODO(philculliton): use AddFrom in ops-inl for a vectorized loop. // TODO(philculliton): use AddFrom in `ops-inl` for a vectorized loop.
for (size_t col = 0; col < extents.cols; ++col) { for (size_t col = 0; col < extents.cols; ++col) {
output_row[col] += input_row[col]; output_row[col] += input_row[col];
} }
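AvgPool4x4 above treats the 4096 input rows as a 64x64 grid and averages each 4x4 block of rows into one of 256 output rows, i.e. in_idx = (out_row*4 + i) * input_dim + (out_col*4 + j). A scaled-down standalone sketch of the same index math, pooling an 8x8 grid of rows down to 2x2:

#include <cstddef>
#include <vector>

int main() {
  const size_t input_dim = 8, output_dim = 2, cols = 3;
  std::vector<float> input(input_dim * input_dim * cols);
  for (size_t r = 0; r < input_dim * input_dim; ++r)
    for (size_t c = 0; c < cols; ++c) input[r * cols + c] = static_cast<float>(r);

  std::vector<float> output(output_dim * output_dim * cols, 0.0f);
  for (size_t out_row = 0; out_row < output_dim; ++out_row) {
    for (size_t out_col = 0; out_col < output_dim; ++out_col) {
      const size_t out_idx = out_row * output_dim + out_col;
      // Accumulate the 4x4 block of input rows, then divide by 16.
      for (size_t i = 0; i < 4; ++i) {
        for (size_t j = 0; j < 4; ++j) {
          const size_t in_idx = (out_row * 4 + i) * input_dim + (out_col * 4 + j);
          for (size_t c = 0; c < cols; ++c)
            output[out_idx * cols + c] += input[in_idx * cols + c];
        }
      }
      for (size_t c = 0; c < cols; ++c) output[out_idx * cols + c] /= 16.0f;
    }
  }
  // Output row 0 averages input rows {0..3, 8..11, 16..19, 24..27} = 13.5.
  return output[0] == 13.5f ? 0 : 1;
}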

View File

@ -20,23 +20,22 @@
#include <cmath> #include <cmath>
#include "util/allocator.h"
#include "util/mat.h" #include "util/mat.h"
#include "hwy/base.h" #include "hwy/base.h"
namespace gcpp { namespace gcpp {
static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale( static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
const Allocator& allocator, size_t qkv_dim, bool half_rope, const Allocator& allocator, size_t qkv_dim, bool half_rope,
double base_frequency = 10000.0) { double base_frequency = 10000.0) {
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim; const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2)); MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2);
for (size_t dim = 0; dim < rope_dim / 2; ++dim) { for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
const double freq_exponents = const double freq_exponents =
static_cast<double>(2 * dim) / static_cast<double>(rope_dim); static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
// Replacing with expf(ln(1E4) * freq_exponents) changes results // Replacing with expf(ln(1E4) * freq_exponents) changes results
// noticeably. // noticeably.
inv_timescale.Batch(0)[dim] = inv_timescale.Packed()[dim] =
static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents)); static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
} }
return inv_timescale; return inv_timescale;
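CreateInvTimescale fills entry d with 1 / base_frequency^(2d / rope_dim), the usual RoPE inverse frequencies. A standalone sketch computing the same values into a plain std::vector:

#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// inv_timescale[d] = base^(-2d / rope_dim), for d in [0, rope_dim/2).
std::vector<float> InvTimescaleSketch(size_t qkv_dim, bool half_rope,
                                      double base_frequency = 10000.0) {
  const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
  std::vector<float> inv(rope_dim / 2);
  for (size_t d = 0; d < rope_dim / 2; ++d) {
    const double exponent =
        static_cast<double>(2 * d) / static_cast<double>(rope_dim);
    inv[d] = static_cast<float>(1.0 / std::pow(base_frequency, exponent));
  }
  return inv;
}

int main() {
  const std::vector<float> inv =
      InvTimescaleSketch(/*qkv_dim=*/256, /*half_rope=*/false);
  // First frequency is always 1; later entries decay towards 1/base.
  std::printf("inv[0]=%f inv[last]=%f\n", inv[0], inv.back());
  return 0;
}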

View File

@ -34,7 +34,7 @@
#include "gemma/common.h" // ChooseQueryScale #include "gemma/common.h" // ChooseQueryScale
#include "util/allocator.h" #include "util/allocator.h"
#include "util/basics.h" // BF16 #include "util/basics.h" // BF16
#include "util/mat.h" // RowVectorBatch #include "util/mat.h" // MatStorageT
#include "util/test_util.h" #include "util/test_util.h"
#include "util/threading_context.h" #include "util/threading_context.h"
#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/hwy_gtest.h"
@ -391,7 +391,7 @@ void TestRopeAndMulBy() {
ModelConfig config(Model::GEMMA2_9B, Type::kSFP, ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B)); ChooseWrapping(Model::GEMMA2_9B));
int dim_qkv = config.layer_configs[0].qkv_dim; int dim_qkv = config.layer_configs[0].qkv_dim;
RowVectorBatch<float> x(allocator, Extents2D(1, dim_qkv)); MatStorageT<float> x("x", dim_qkv);
std::mt19937 gen; std::mt19937 gen;
gen.seed(0x12345678); gen.seed(0x12345678);
@ -399,43 +399,43 @@ void TestRopeAndMulBy() {
auto random_float = [&r, &gen] { return r(gen); }; auto random_float = [&r, &gen] { return r(gen); };
for (int i = 0; i < dim_qkv; ++i) { for (int i = 0; i < dim_qkv; ++i) {
x.All()[i] = random_float(); x.Packed()[i] = random_float();
} }
const float qmul = ChooseQueryScale(config); const float qmul = ChooseQueryScale(config);
const float kmul = 1.0; const float kmul = 1.0;
std::vector<float> qexpected(dim_qkv); MatStorageT<float> qexpected("qexpected", dim_qkv);
std::vector<float> qactual(dim_qkv); MatStorageT<float> qactual("qactual", dim_qkv);
std::vector<float> kexpected(dim_qkv); MatStorageT<float> kexpected("kexpected", dim_qkv);
std::vector<float> kactual(dim_qkv); MatStorageT<float> kactual("kactual", dim_qkv);
RowVectorBatch<float> inv_timescale = CreateInvTimescale( MatStorageT<float> inv_timescale = CreateInvTimescale(
allocator, config.layer_configs[0].qkv_dim, allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope); config.layer_configs[0].post_qk == PostQKType::HalfRope);
// Assert VectorizedRope computation is same as regular rope at different pos. // Assert VectorizedRope computation is same as regular rope at different pos.
for (int pos = 1; pos < 500; pos++) { for (int pos = 1; pos < 500; pos++) {
// Rope'd Q embeddings // Rope'd Q embeddings
hwy::CopyBytes(x.Const(), qactual.data(), dim_qkv); CopyMat(x, qactual);
hwy::CopyBytes(x.Const(), qexpected.data(), dim_qkv); CopyMat(x, qexpected);
ScalarRopeAndMulBy(qmul, qexpected.data(), dim_qkv, inv_timescale.Const(), ScalarRopeAndMulBy(qmul, qexpected.Packed(), dim_qkv,
pos); inv_timescale.Packed(), pos);
RopeAndMulBy(qmul, qactual.data(), dim_qkv, inv_timescale.Const(), pos); RopeAndMulBy(qmul, qactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
for (int i = 0; i < dim_qkv; ++i) { for (int i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qactual[i], qexpected[i], 1e-4) EXPECT_NEAR(qactual.Packed()[i], qexpected.Packed()[i], 1e-4)
<< "qIndex:" << i << "qInput:" << qactual[i]; << "qIndex:" << i << "qInput:" << qactual.Packed()[i];
} }
// Rope'd K embeddings // Rope'd K embeddings
hwy::CopyBytes(x.Const(), kactual.data(), dim_qkv); CopyMat(x, kactual);
hwy::CopyBytes(x.Const(), kexpected.data(), dim_qkv); CopyMat(x, kexpected);
ScalarRopeAndMulBy(kmul, kexpected.data(), dim_qkv, inv_timescale.Const(), ScalarRopeAndMulBy(kmul, kexpected.Packed(), dim_qkv,
pos); inv_timescale.Packed(), pos);
RopeAndMulBy(kmul, kactual.data(), dim_qkv, inv_timescale.Const(), pos); RopeAndMulBy(kmul, kactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
for (int i = 0; i < dim_qkv; ++i) { for (int i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kactual[i], kexpected[i], 1e-4) EXPECT_NEAR(kactual.Packed()[i], kexpected.Packed()[i], 1e-4)
<< "kIndex:" << i << "kInput:" << kactual[i]; << "kIndex:" << i << "kInput:" << kactual.Packed()[i];
} }
} }
} }
@ -451,10 +451,9 @@ HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) {
} }
// Supports bf16 and f32 inputs/outputs, which can be in-place. // Supports bf16 and f32 inputs/outputs, which can be in-place.
template <typename VecT, typename WeightT, typename OutT> template <typename XT, typename WT, typename OT>
HWY_NOINLINE void ScalarRMSNorm(const VecT* x, HWY_NOINLINE void ScalarRMSNorm(const XT* x, const WT* HWY_RESTRICT weight,
const WeightT* HWY_RESTRICT weight, OutT* out, OT* out, size_t size) {
size_t size) {
constexpr float kEps = 1e-6f; constexpr float kEps = 1e-6f;
float ss = ScalarSquaredL2(x, size); float ss = ScalarSquaredL2(x, size);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps); ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
@ -462,32 +461,32 @@ HWY_NOINLINE void ScalarRMSNorm(const VecT* x,
const float v = hwy::ConvertScalarTo<float>(x[j]); const float v = hwy::ConvertScalarTo<float>(x[j]);
const float w = hwy::ConvertScalarTo<float>(weight[j]); const float w = hwy::ConvertScalarTo<float>(weight[j]);
// Note 1.0f centering here // Note 1.0f centering here
out[j] = hwy::ConvertScalarTo<OutT>((1.0f + w) * (ss * v)); out[j] = hwy::ConvertScalarTo<OT>((1.0f + w) * (ss * v));
} }
} }
template <typename VecT, typename WeightT, typename OutT> template <typename XT, typename WT, typename OT>
void TestRMSNorm(hwy::RandomState& rng) { void TestRMSNorm(hwy::RandomState& rng) {
constexpr size_t kSize = 128; constexpr size_t kSize = 128;
HWY_ALIGN VecT vec[kSize]; HWY_ALIGN XT vec[kSize];
HWY_ALIGN WeightT weight[kSize]; HWY_ALIGN WT weight[kSize];
HWY_ALIGN OutT expected[kSize]; HWY_ALIGN OT expected[kSize];
HWY_ALIGN OutT actual[kSize]; HWY_ALIGN OT actual[kSize];
for (size_t i = 0; i < kSize; ++i) { for (size_t i = 0; i < kSize; ++i) {
vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng)); vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng)); weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
} }
ScalarRMSNorm(vec, weight, expected, kSize); ScalarRMSNorm(vec, weight, expected, kSize);
RMSNorm(vec, weight, actual, kSize); RMSNorm(vec, weight, 0, actual, kSize);
for (size_t i = 0; i < kSize; i++) { for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]); const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]); const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) { if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(), HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
TypeName<WeightT>(), TypeName<OutT>(), i, e, a); TypeName<WT>(), TypeName<OT>(), i, e, a);
} }
} }
} }
@ -526,24 +525,64 @@ void TestLayerNormSimple() {
} }
} }
// Note: there is no vectorized implementation of LayerNorm yet. So this test // Computes mean mu and mean of squares mu2 of a vector. Used in
// currently only checks that the scalar version can be called for the below // ScalarLayerNorm.
// combinations of float/BF16 inputs and outputs. template <typename T>
template <typename VecT, typename WeightT, typename OutT> HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, double& mu,
double& mu2) {
HWY_ASSERT(size > 0);
double sum = 0.0;
double sum2 = 0.0;
for (size_t i = 0; i < size; ++i) {
const float f = hwy::ConvertScalarTo<float>(a[i]);
sum += f;
sum2 += f * f;
}
mu = sum / size;
mu2 = sum2 / size;
}
// Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename XT, typename WT, typename OT>
HWY_NOINLINE void ScalarLayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
const WT* HWY_RESTRICT bias, OT* out,
size_t size) {
constexpr double kEps = 1e-6;
double mu, mu2;
ScalarMus(x, size, mu, mu2);
double var = mu2 - mu * mu;
constexpr double kZero = 0.0;
var = HWY_MAX(var, kZero);
var = 1.0 / sqrt(var + kEps);
for (size_t j = 0; j < size; j++) {
const float v = hwy::ConvertScalarTo<float>(x[j]);
const float s = hwy::ConvertScalarTo<float>(scale[j]);
const float b = hwy::ConvertScalarTo<float>(bias[j]);
out[j] = hwy::ConvertScalarTo<OT>((v - mu) * s * var + b);
}
}
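As a quick numeric check of the formula used by ScalarLayerNorm and LayerNorm above, out = (x - mu) * scale * rsqrt(var + eps) + bias, the following standalone program normalizes a four-element vector with unit scale and zero bias and prints hand-checkable results:

#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  const double x[4] = {1.0, 2.0, 3.0, 4.0};
  const size_t n = 4;
  double mu = 0.0, mu2 = 0.0;
  for (size_t i = 0; i < n; ++i) { mu += x[i]; mu2 += x[i] * x[i]; }
  mu /= n;                               // 2.5
  mu2 /= n;                              // 7.5
  double var = mu2 - mu * mu;            // 1.25
  var = 1.0 / std::sqrt(var + 1e-6);     // ~0.894427
  for (size_t i = 0; i < n; ++i) {
    const double out = (x[i] - mu) * 1.0 * var + 0.0;  // scale=1, bias=0
    std::printf("%f ", out);  // approx -1.3416 -0.4472 0.4472 1.3416
  }
  std::printf("\n");
  return 0;
}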
template <typename XT, typename WT, typename OT>
void TestLayerNorm(hwy::RandomState& rng) { void TestLayerNorm(hwy::RandomState& rng) {
constexpr size_t kSize = 128; constexpr size_t kSize = 128;
VecT vec[kSize]; XT vec[kSize];
WeightT weight[kSize]; WT weight[kSize];
WeightT bias[kSize]; WT bias[kSize];
OutT expected[kSize]; OT expected[kSize];
OutT actual[kSize]; OT actual[kSize];
for (size_t i = 0; i < kSize; ++i) { for (size_t i = 0; i < kSize; ++i) {
vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng)); vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng)); weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
bias[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng)); bias[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
} }
double expected_mu, expected_mu2;
ScalarMus(vec, kSize, expected_mu, expected_mu2);
double actual_mu, actual_mu2;
ComputeMoments(vec, kSize, actual_mu, actual_mu2);
ScalarLayerNorm(vec, weight, bias, expected, kSize); ScalarLayerNorm(vec, weight, bias, expected, kSize);
LayerNorm(vec, weight, bias, actual, kSize); LayerNorm(vec, weight, bias, actual, kSize);
@ -551,8 +590,8 @@ void TestLayerNorm(hwy::RandomState& rng) {
const float e = hwy::ConvertScalarTo<float>(expected[i]); const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]); const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) { if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(), HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
TypeName<WeightT>(), TypeName<OutT>(), i, e, a); TypeName<WT>(), TypeName<OT>(), i, e, a);
} }
} }
} }

View File

@ -55,8 +55,7 @@ PYBIND11_MODULE(configs, py_module) {
.value("kSFP", Type::kSFP) .value("kSFP", Type::kSFP)
.value("kNUQ", Type::kNUQ) .value("kNUQ", Type::kNUQ)
.value("kF64", Type::kF64) .value("kF64", Type::kF64)
.value("kC64", Type::kC64) .value("kC64", Type::kC64);
.value("kU128", Type::kU128);
enum_<LayerAttentionType>(py_module, "LayerAttentionType") enum_<LayerAttentionType>(py_module, "LayerAttentionType")
.value("kGemma", LayerAttentionType::kGemma) .value("kGemma", LayerAttentionType::kGemma)

View File

@ -168,9 +168,9 @@ class GemmaModel {
void SetImage(const py::array_t<float, py::array::c_style | void SetImage(const py::array_t<float, py::array::c_style |
py::array::forcecast>& image) { py::array::forcecast>& image) {
const gcpp::Gemma& gemma = *gemma_.GetGemma(); const gcpp::Gemma& gemma = *gemma_.GetGemma();
const gcpp::Allocator& allocator = gemma_.Env().ctx.allocator; const gcpp::ModelConfig& config = gemma.GetModelConfig();
if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA && if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) { config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
throw std::invalid_argument("Not a PaliGemma model."); throw std::invalid_argument("Not a PaliGemma model.");
} }
py::buffer_info buffer = image.request(); py::buffer_info buffer = image.request();
@ -182,14 +182,15 @@ class GemmaModel {
float* ptr = static_cast<float*>(buffer.ptr); float* ptr = static_cast<float*>(buffer.ptr);
gcpp::Image c_image; gcpp::Image c_image;
c_image.Set(height, width, ptr); c_image.Set(height, width, ptr);
const size_t image_size = gemma.GetModelConfig().vit_config.image_size; const size_t image_size = config.vit_config.image_size;
c_image.Resize(image_size, image_size); c_image.Resize(image_size, image_size);
image_tokens_ = gcpp::ImageTokens( image_tokens_.reset(new gcpp::ImageTokens(
allocator, gcpp::Extents2D(gemma.GetModelConfig().vit_config.seq_len, "image_tokens",
gemma.GetModelConfig().model_dim)); gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
gcpp::MatPadding::kOdd));
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(), gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
.verbosity = 0}; .verbosity = 0};
gemma.GenerateImageTokens(runtime_config, c_image, image_tokens_); gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_);
} }
// Generates a response to the given prompt, using the last set image. // Generates a response to the given prompt, using the last set image.
@ -197,9 +198,7 @@ class GemmaModel {
std::pair<std::string, std::vector<int>> GenerateWithImage( std::pair<std::string, std::vector<int>> GenerateWithImage(
std::string prompt, size_t max_generated_tokens, float temperature, std::string prompt, size_t max_generated_tokens, float temperature,
float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) { float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
if (image_tokens_.Cols() == 0) { if (!image_tokens_) throw std::invalid_argument("No image set.");
throw std::invalid_argument("No image set.");
}
const gcpp::Gemma& model = *gemma_.GetGemma(); const gcpp::Gemma& model = *gemma_.GetGemma();
gemma_.MutableGen().seed(seed); gemma_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = gemma_.MutableConfig(); gcpp::RuntimeConfig& config = gemma_.MutableConfig();
@ -207,7 +206,7 @@ class GemmaModel {
config.temperature = temperature; config.temperature = temperature;
config.verbosity = 0; config.verbosity = 0;
config.accept_token = accept; config.accept_token = accept;
config.image_tokens = &image_tokens_; config.image_tokens = image_tokens_.get();
std::vector<int> tokens; std::vector<int> tokens;
if (!prompt_tokens.empty()) { if (!prompt_tokens.empty()) {
if (!prompt.empty()) { if (!prompt.empty()) {
@ -219,7 +218,7 @@ class GemmaModel {
} else { } else {
tokens = gemma_.WrapAndTokenize(prompt); tokens = gemma_.WrapAndTokenize(prompt);
} }
tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0); tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
size_t num_tokens = tokens.size(); size_t num_tokens = tokens.size();
size_t prefix_end = num_tokens; size_t prefix_end = num_tokens;
config.prefill_tbatch_size = num_tokens; config.prefill_tbatch_size = num_tokens;
@ -252,7 +251,7 @@ class GemmaModel {
private: private:
gcpp::GemmaEnv gemma_; gcpp::GemmaEnv gemma_;
gcpp::ImageTokens image_tokens_; std::unique_ptr<gcpp::ImageTokens> image_tokens_;
float last_prob_; float last_prob_;
}; };
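Holding the image tokens in a std::unique_ptr lets the binding allocate them only when SetImage runs and test for presence with a plain null check instead of inspecting Cols(). A minimal standalone sketch of that pattern; Tokens and Model are placeholder names, not the real gcpp/pybind types:

#include <cstddef>
#include <memory>
#include <stdexcept>
#include <vector>

struct Tokens {                    // placeholder for gcpp::ImageTokens
  size_t rows, cols;
  std::vector<float> data;
  Tokens(size_t r, size_t c) : rows(r), cols(c), data(r * c) {}
};

class Model {
 public:
  void SetImage(size_t seq_len, size_t model_dim) {
    tokens_.reset(new Tokens(seq_len, model_dim));  // allocated on demand
  }
  size_t GenerateWithImage() const {
    if (!tokens_) throw std::invalid_argument("No image set.");
    return tokens_->rows;  // e.g. how many placeholder tokens to insert
  }

 private:
  std::unique_ptr<Tokens> tokens_;  // null until SetImage is called
};

int main() {
  Model m;
  m.SetImage(/*seq_len=*/16, /*model_dim=*/8);
  return m.GenerateWithImage() == 16 ? 0 : 1;
}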

View File

@ -117,11 +117,11 @@ static size_t Stride(const Allocator& allocator, const MatPtr& mat,
} }
} }
void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) { void MatOwner::AllocateFor(MatPtr& mat, const MatPadding padding) {
if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked; const bool is_nuq = mat.GetType() == Type::kNUQ;
const Allocator& allocator = ThreadingContext::Get().allocator; const Allocator& allocator = ThreadingContext::Get().allocator;
const size_t stride = Stride(allocator, mat, padding); const size_t stride = is_nuq ? mat.Cols() : Stride(allocator, mat, padding);
const size_t num = mat.Rows() * stride; const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
// `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding` // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
// might not be enough, hence add extra. `MatT` is at least one byte, which // might not be enough, hence add extra. `MatT` is at least one byte, which
// is half of BF16, hence adding `VectorBytes` *elements* is enough. // is half of BF16, hence adding `VectorBytes` *elements* is enough.

View File

@ -28,7 +28,7 @@
#include "compression/shared.h" // Type #include "compression/shared.h" // Type
#include "gemma/tensor_info.h" #include "gemma/tensor_info.h"
#include "io/fields.h" #include "io/fields.h"
#include "util/allocator.h" #include "util/allocator.h" // AlignedPtr2
#include "util/basics.h" // Extents2D #include "util/basics.h" // Extents2D
// IWYU pragma: end_exports // IWYU pragma: end_exports
#include "hwy/base.h" #include "hwy/base.h"
@ -47,7 +47,7 @@ class MatPtr : public IFields {
// `name`: see `SetName`. Note that `stride` is initially `cols` and only // `name`: see `SetName`. Note that `stride` is initially `cols` and only
// differs after deserializing, or calling `SetPtr`. // differs after deserializing, or calling `SetPtr`.
MatPtr(const char* name, Type type, Extents2D extents) MatPtr(const char* name, Type type, Extents2D extents)
: rows_(static_cast<uint32_t>(extents.rows)), : private_rows_(static_cast<uint32_t>(extents.rows)),
cols_(static_cast<uint32_t>(extents.cols)) { cols_(static_cast<uint32_t>(extents.cols)) {
SetName(name); SetName(name);
SetType(type); SetType(type);
@ -74,7 +74,7 @@ class MatPtr : public IFields {
bool HasPtr() const { return ptr_ != nullptr; } bool HasPtr() const { return ptr_ != nullptr; }
// A single row counts as packed because there is no padding between rows. // A single row counts as packed because there is no padding between rows.
bool IsPacked() const { return (stride_ == cols_) || (rows_ == 1); } bool IsPacked() const { return (stride_ == cols_) || (Rows() == 1); }
const void* Packed() const { const void* Packed() const {
HWY_DASSERT_M(IsPacked(), name_.c_str()); HWY_DASSERT_M(IsPacked(), name_.c_str());
@ -96,17 +96,17 @@ class MatPtr : public IFields {
// Works for any kind of padding. // Works for any kind of padding.
template <typename T> template <typename T>
T* MutableRowT(size_t row) const { T* MutableRowT(size_t row) const {
HWY_DASSERT(row < rows_); HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_; return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
} }
template <typename T> template <typename T>
T* RowT(size_t row) { T* RowT(size_t row) {
HWY_DASSERT(row < rows_); HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_; return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
} }
template <typename T> template <typename T>
const T* RowT(size_t row) const { const T* RowT(size_t row) const {
HWY_DASSERT(row < rows_); HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_; return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
} }
@ -118,10 +118,22 @@ class MatPtr : public IFields {
HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16); HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16);
} }
bool IsEmpty() const { return rows_ == 0 || cols_ == 0; } size_t Rows() const {
size_t Rows() const { return rows_; } return override_rows_ == 0 ? private_rows_ : override_rows_;
}
size_t Cols() const { return cols_; } size_t Cols() const { return cols_; }
Extents2D Extents() const { return Extents2D(rows_, cols_); } Extents2D Extents() const { return Extents2D(Rows(), cols_); }
bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
bool SameShape(const MatPtr& other) const {
return Rows() == other.Rows() && cols_ == other.cols_;
}
// Future calls to `Rows()` during this class' lifetime (not serialized)
// will return this value. Used to set the actual number of rows for
// activations preallocated according to the batch size.
void OverrideRows(size_t rows) {
HWY_ASSERT(rows <= private_rows_);
override_rows_ = static_cast<uint32_t>(rows);
}
// Offset by which to advance pointers to the next row. // Offset by which to advance pointers to the next row.
size_t Stride() const { return stride_; } size_t Stride() const { return stride_; }
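
A minimal sketch of how `OverrideRows` pairs with batch-sized activations; the commit description mentions `SetBatchSize`, but the struct below is illustrative and assumes the gemma headers:

// Illustrative: buffers are allocated once for the maximum batch size, then
// OverrideRows() shrinks the logical Rows() per call without reallocating.
struct ActivationsSketch {
  ActivationsSketch(size_t max_batch, size_t model_dim)
      : x("x", Extents2D(max_batch, model_dim), MatPadding::kOdd) {}

  // Subsequent x.Rows() calls return batch_size; OverrideRows asserts it
  // does not exceed the preallocated row count.
  void SetBatchSize(size_t batch_size) { x.OverrideRows(batch_size); }

  MatStorageT<float> x;
};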
@ -150,7 +162,7 @@ class MatPtr : public IFields {
visitor(type_); visitor(type_);
visitor(element_bytes_); visitor(element_bytes_);
visitor(num_elements_); visitor(num_elements_);
visitor(rows_); visitor(private_rows_);
visitor(cols_); visitor(cols_);
visitor(scale_); visitor(scale_);
visitor(stride_); visitor(stride_);
@ -164,11 +176,11 @@ class MatPtr : public IFields {
// padding, which is anyway not supported for NUQ because `compress-inl.h` // padding, which is anyway not supported for NUQ because `compress-inl.h`
// assumes a contiguous stream for its group indexing. // assumes a contiguous stream for its group indexing.
static size_t ComputeNumElements(Type type, Extents2D extents) { static size_t ComputeNumElements(Type type, Extents2D extents) {
const size_t num_elements = extents.Area(); size_t num_elements = extents.Area();
if (type == Type::kNUQ) { if (type == Type::kNUQ) {
// `CompressedArrayElements` is a wrapper function that has the same // `CompressedArrayElements` is a wrapper function that has the same
// effect, but that requires a template argument, not `type`. // effect, but that requires a template argument, not `type`.
return NuqStream::PackedEnd(num_elements); num_elements = NuqStream::PackedEnd(num_elements);
} }
return num_elements; return num_elements;
} }
@ -184,9 +196,10 @@ class MatPtr : public IFields {
// Number of elements to store (including NUQ tables but not padding). // Number of elements to store (including NUQ tables but not padding).
// This a function of `type_` and `Extents()` and stored for compatibility. // This a function of `type_` and `Extents()` and stored for compatibility.
uint32_t num_elements_ = 0; uint32_t num_elements_ = 0;
uint32_t rows_ = 0; uint32_t private_rows_ = 0; // Only access via Rows()! See OverrideRows().
uint32_t cols_ = 0; uint32_t cols_ = 0;
float scale_ = 1.0f; // multiplier for each value, for MatMul.
uint32_t override_rows_ = 0; // not serialized
// Non-owning pointer, must not be freed. The underlying memory must outlive // Non-owning pointer, must not be freed. The underlying memory must outlive
// this object. // this object.
@ -194,6 +207,8 @@ class MatPtr : public IFields {
// Offset by which to advance pointers to the next row, >= `cols_`. // Offset by which to advance pointers to the next row, >= `cols_`.
uint32_t stride_; uint32_t stride_;
float scale_ = 1.0f; // multiplier for each value, for MatMul.
}; };
// Non-type erased version of `MatPtr`. Although `MatPtr` also provides // Non-type erased version of `MatPtr`. Although `MatPtr` also provides
@ -202,6 +217,8 @@ class MatPtr : public IFields {
template <typename MatT> template <typename MatT>
class MatPtrT : public MatPtr { class MatPtrT : public MatPtr {
public: public:
using T = MatT;
// Called by `MatStorageT`. // Called by `MatStorageT`.
MatPtrT(const char* name, Extents2D extents) MatPtrT(const char* name, Extents2D extents)
: MatPtr(name, TypeEnum<MatT>(), extents) {} : MatPtr(name, TypeEnum<MatT>(), extents) {}
@ -253,26 +270,67 @@ class MatPtrT : public MatPtr {
}; };
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the // Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`. Currently unused but may be used after we move toward // optional `args`. This supports all types used as weights, which excludes
// type-erased `WeightsPtrs`. // `kC64` and `kF64` (used only in `backprop/`).
template <class Func, typename... Args> template <class Func, typename... Args>
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func, decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
Args&&... args) { Args&&... args) {
HWY_ASSERT(base != nullptr); if (base->GetType() == Type::kF32) {
if (type == Type::kF32) { return func(dynamic_cast<const MatPtrT<float>*>(base),
return func(dynamic_cast<MatPtrT<float>*>(base),
std::forward<Args>(args)...); std::forward<Args>(args)...);
} else if (type == Type::kBF16) { } else if (base->GetType() == Type::kBF16) {
return func(dynamic_cast<MatPtrT<BF16>*>(base), return func(dynamic_cast<const MatPtrT<BF16>*>(base),
std::forward<Args>(args)...); std::forward<Args>(args)...);
} else if (type == Type::kSFP) { } else if (base->GetType() == Type::kSFP) {
return func(dynamic_cast<MatPtrT<SfpStream>*>(base), return func(dynamic_cast<const MatPtrT<SfpStream>*>(base),
std::forward<Args>(args)...); std::forward<Args>(args)...);
} else if (type == Type::kNUQ) { } else if (base->GetType() == Type::kNUQ) {
return func(dynamic_cast<MatPtrT<NuqStream>*>(base), return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
std::forward<Args>(args)...); std::forward<Args>(args)...);
} else { } else {
HWY_ABORT("Type %d unknown.", static_cast<int>(type)); HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
}
// Calls `func(base1, base2, args...)`.
template <class Func, typename... Args>
decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
const Func& func, Args&&... args) {
HWY_ASSERT(base1->GetType() == base2->GetType());
if (base1->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base1),
dynamic_cast<const MatPtrT<float>*>(base2),
std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kBF16) {
return func(dynamic_cast<const MatPtrT<BF16>*>(base1),
dynamic_cast<const MatPtrT<BF16>*>(base2),
std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kSFP) {
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base1),
dynamic_cast<const MatPtrT<SfpStream>*>(base2),
std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
dynamic_cast<const MatPtrT<NuqStream>*>(base2),
std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
}
}
// Like CallUpcasted, but only for activation types: kBF16 and kF32.
template <class Func, typename... Args>
decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
Args&&... args) {
HWY_ASSERT(base != nullptr);
if (base->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base),
std::forward<Args>(args)...);
} else if (base->GetType() == Type::kBF16) {
return func(dynamic_cast<const MatPtrT<BF16>*>(base),
std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
} }
} }
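
As a usage sketch, a type-erased caller can dispatch into a templated kernel; the `ProcessWeights` helper is hypothetical, and only `CallUpcasted` and `MatPtrT<T>` come from the header above:

// Hypothetical templated routine, standing in for e.g. a decompress or matmul
// kernel that needs the concrete element type.
template <typename T>
void ProcessWeights(const MatPtrT<T>& mat);

// Type-erased call site: the generic lambda receives a const MatPtrT<T>*
// whose T matches weights.GetType() (kF32, kBF16, kSFP or kNUQ).
void Process(const MatPtr& weights) {
  CallUpcasted(&weights, [](const auto* mat) { ProcessWeights(*mat); });
}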
@ -362,8 +420,11 @@ class MatStorageT : public MatPtrT<MatT> {
public: public:
MatStorageT(const char* name, Extents2D extents, MatPadding padding) MatStorageT(const char* name, Extents2D extents, MatPadding padding)
: MatPtrT<MatT>(name, extents) { : MatPtrT<MatT>(name, extents) {
owner_.AllocateFor(*this, padding); if (extents.Area() != 0) owner_.AllocateFor(*this, padding);
} }
// Shorthand for 1D tensors: padding does not help for a single row, hence `kPacked`.
MatStorageT(const char* name, size_t cols)
: MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {}
~MatStorageT() = default; ~MatStorageT() = default;
// Allow move for backprop/activations. // Allow move for backprop/activations.
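
A brief usage sketch of the new 1D shorthand (names and sizes are illustrative):

// Illustrative: a 1 x model_dim vector. A single row cannot benefit from row
// padding, so the delegating ctor uses Extents2D(1, cols) with kPacked.
const size_t model_dim = 2048;  // example value
MatStorageT<float> bias("bias", model_dim);
HWY_DASSERT(bias.Rows() == 1 && bias.Cols() == model_dim);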
@ -467,81 +528,14 @@ using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>; using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>; using RowPtrD = RowPtr<double>;
// Owns dynamically-allocated aligned memory for a batch of row vectors. // TODO: remove allocator arg once kCyclic is removed.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory. Unlike `MatPtr`, this lacks metadata.
// TODO: replace with `MatStorageT`.
template <typename T> template <typename T>
class RowVectorBatch { RowPtr<T> RowPtrFromMat(const Allocator& allocator,
public: const MatPtrT<T>& row_vectors) {
// Default ctor for Activations ctor. // RowPtr is non-const for MatMul C, but is also used for A which is const.
RowVectorBatch() = default; // Callers are responsible for checking their usage of RowPtr.
// Main ctor, called from Activations::Allocate. If `stride` = 0, the default, return RowPtr<T>(allocator, const_cast<T*>(row_vectors.Row(0)),
// we default to tightly packed rows (`stride = cols`). row_vectors.Cols(), row_vectors.Stride());
// WARNING: not all call sites support `stride` != cols.
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
RowVectorBatch(const Allocator& allocator, Extents2D extents,
size_t stride = 0)
: extents_(extents) {
if (stride == 0) {
stride_ = extents_.cols;
} else {
HWY_ASSERT(stride >= extents_.cols);
stride_ = stride;
}
// Allow binding the entire matrix.
const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
allocator.QuantumBytes() / sizeof(T));
mem_ = allocator.Alloc<T>(padded);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return extents_.rows; }
size_t Cols() const { return extents_.cols; }
size_t Stride() const { return stride_; }
Extents2D Extents() const { return extents_; }
// Returns the given row vector of length `Cols()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
// For MatMul or other operations that process the entire batch at once.
// TODO: remove once we only use Mat.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
private:
AlignedPtr2<T[]> mem_;
Extents2D extents_;
size_t stride_;
};
template <typename T>
RowPtr<T> RowPtrFromBatch(const Allocator& allocator,
RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
row_vectors.Stride());
}
template <typename T>
RowVectorBatch<T> AllocateAlignedRows(const Allocator& allocator,
Extents2D extents) {
return RowVectorBatch<T>(
allocator, extents,
StrideForCyclicOffsets(extents.cols,
allocator.QuantumBytes() / sizeof(T)));
} }
} // namespace gcpp } // namespace gcpp
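
Finally, a usage sketch for the replacement helper; only `RowPtrFromMat`, `MatStorageT` and the accessors above are taken from this change, the buffer name and extents are illustrative:

// Illustrative: wrap preallocated activations in a RowPtr view for callers
// that still take RowPtr. The const_cast inside RowPtrFromMat means callers
// must not write through views created from genuinely const matrices.
const Allocator& allocator = ThreadingContext::Get().allocator;
MatStorageT<float> activations("activations", Extents2D(4, 2048),
                               MatPadding::kOdd);
RowPtrF c_rows = RowPtrFromMat(allocator, activations);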