Replace RowVectorBatch with MatStorageT

KVCache: add ctor required for MatStorageT, remove Create; bf_pre_ffw_rms_out -> pre_ffw_rms_out
optimize_test: larger vocab_size requires more steps
shared.h: Remove unused u128 type
Activations: correctly set matrix rows via SetBatchSize, avoiding passing them as args
ops: pass Mat instead of pointers/sizes; vectorize LayerNorm; support any weight type
mat: add OverrideRows, used by SetBatchSize
PiperOrigin-RevId: 757790736
Jan Wassenberg 2025-05-12 09:15:03 -07:00 committed by Copybara-Service
parent cf7dd80c17
commit 45ad847a41
39 changed files with 949 additions and 917 deletions
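
For orientation, a rough before/after sketch of the call-site changes, assembled from the hunks below. The surrounding variables (allocator, batch_size, config, w, out) are placeholders, not part of this commit:

  // Before: RowVectorBatch plus factory function; row counts and pointers passed around.
  //   KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
  //   RowVectorBatch<float> x(allocator, Extents2D(batch_size, config.model_dim));
  //   float* row = x.Batch(i);          // row count via x.BatchSize()
  //   RMSNormBatched(num_tokens, x.All(), w.PackedScale1(), out.All(), model_dim);

  // After: plain constructors and MatStorageT, which carries its own extents.
  KVCache kv_cache(config, /*prefill_tbatch_size=*/16);
  MatStorageT<float> x("x", Extents2D(batch_size, config.model_dim),
                       MatPadding::kOdd);
  float* row = x.Row(i);                // row count via x.Rows()
  RMSNormBatched(x, w, out);            // ops now take Mat, not pointer + size
  x.OverrideRows(smaller_batch_size);   // override logical rows; used by SetBatchSize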

View File

@ -415,6 +415,7 @@ cc_library(
hdrs = ["gemma/kv_cache.h"],
deps = [
":configs",
":mat",
"@highway//:hwy",
],
)
@ -425,6 +426,7 @@ cc_library(
deps = [
":args",
":basics",
":mat",
":ops", # matmul.h
"//io",
"@highway//:hwy",

View File

@ -38,8 +38,8 @@ struct ForwardLayer {
att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)),
attention_out(
MakePacked<T>("attention_out", seq_len, config.model_dim)),
bf_pre_ffw_rms_out(
MakePacked<T>("bf_preFF_rms_out", seq_len, config.model_dim)),
pre_ffw_rms_out(
MakePacked<T>("preFF_rms_out", seq_len, config.model_dim)),
ffw_hidden(
MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)),
ffw_hidden_gated(
@ -53,7 +53,7 @@ struct ForwardLayer {
MatStorageT<T> att_out;
MatStorageT<T> att_post1;
MatStorageT<T> attention_out;
MatStorageT<T> bf_pre_ffw_rms_out;
MatStorageT<T> pre_ffw_rms_out;
MatStorageT<T> ffw_hidden;
MatStorageT<T> ffw_hidden_gated;
const LayerConfig& layer_config;

View File

@ -170,8 +170,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
const ForwardLayer<float>& forward,
const float* HWY_RESTRICT next_layer_grad, size_t num_tokens,
LayerWeightsPtrs<T>& grad, ForwardLayer<float>& backward,
const RowVectorBatch<float>& inv_timescale,
hwy::ThreadPool& pool) {
const MatStorageT<float>& inv_timescale, hwy::ThreadPool& pool) {
const LayerConfig& config = weights.layer_config;
const size_t model_dim = config.model_dim;
const size_t qkv_dim = config.qkv_dim;
@ -207,15 +206,14 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
}
}
MatMulVJP(weights.gating_einsum_w.Packed(),
forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
model_dim, ff_hidden_dim * 2, num_tokens,
grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(),
pool);
RMSNormVJP(
weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens,
grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool);
MatMulVJP(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
backward.ffw_hidden.Packed(), model_dim, ff_hidden_dim * 2,
num_tokens, grad.gating_einsum_w.Packed(),
backward.pre_ffw_rms_out.Packed(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.Packed(),
forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
model_dim, num_tokens, grad.pre_ffw_norm_scale.Packed(),
backward.attention_out.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(next_layer_grad + pos * model_dim,
@ -275,7 +273,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv =
backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim;
Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
Rope(b_kv, qkv_dim, inv_timescale.PackedScale1(), -pos);
}
for (size_t head = 0; head < heads; ++head) {
@ -283,7 +281,7 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
float* HWY_RESTRICT b_q =
backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim;
MulByConst(query_scale, b_q, qkv_dim);
Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
Rope(b_q, qkv_dim, inv_timescale.PackedScale1(), -pos);
}
}
@ -342,7 +340,7 @@ void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
const ForwardPass<float>& forward,
ModelWeightsPtrs<T>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
const ModelConfig& config = weights.weights_config;
const size_t kVocabSize = config.vocab_size;

View File

@ -42,7 +42,7 @@ void CrossEntropyLossBackwardPassT(const Prompt& prompt,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward,
inv_timescale, pool);
@ -62,7 +62,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)(
prompt, weights, forward, grad, backward, inv_timescale, pool);

View File

@ -29,7 +29,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
const ForwardPass<float>& forward,
ModelWeightsPtrs<float>& grad,
ForwardPass<float>& backward,
RowVectorBatch<float>& inv_timescale,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool);
} // namespace gcpp

View File

@ -218,16 +218,15 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(),
backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens);
MatMulVJPT(weights.gating_einsum_w.Packed(),
forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
grad.gating_einsum_w.Packed(),
backward.bf_pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
MatMulVJPT(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(),
backward.ffw_hidden.Packed(), grad.gating_einsum_w.Packed(),
backward.pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
num_tokens);
RMSNormVJPT(
weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
backward.bf_pre_ffw_rms_out.Packed(), grad.pre_ffw_norm_scale.Packed(),
backward.attention_out.Packed(), model_dim, num_tokens);
RMSNormVJPT(weights.pre_ffw_norm_scale.Packed(),
forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(),
grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(),
model_dim, num_tokens);
AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim);

View File

@ -202,7 +202,7 @@ void TestEndToEnd() {
ReverseSequenceSampler training_task({0, 0, 1, 1});
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
MatStorageT<float> inv_timescale = CreateInvTimescale(
ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
for (const Prompt& prompt : batch) {

View File

@ -74,7 +74,7 @@ void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x,
hwy::ThreadPool& pool) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t offset = pos * model_dim;
RMSNorm(x + offset, weights, output + offset, model_dim);
RMSNorm(x + offset, weights, 0, output + offset, model_dim);
}
}
@ -100,7 +100,7 @@ template <typename T>
void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
ForwardLayer<float>& activations, size_t num_tokens,
float* HWY_RESTRICT output,
const RowVectorBatch<float>& inv_timescale,
const MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
const LayerConfig& config = weights.layer_config;
const size_t model_dim = config.model_dim;
@ -125,14 +125,14 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT k =
activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(k, kQKVDim, inv_timescale.Const(), pos);
Rope(k, kQKVDim, inv_timescale.PackedScale1(), pos);
}
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT q =
activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, inv_timescale.Const(), pos);
Rope(q, kQKVDim, inv_timescale.PackedScale1(), pos);
MulByConst(query_scale, q, kQKVDim);
});
@ -194,11 +194,11 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(),
activations.attention_out.Packed(), model_dim, num_tokens,
activations.bf_pre_ffw_rms_out.Packed(), pool);
activations.pre_ffw_rms_out.Packed(), pool);
const size_t kFFHiddenDim = config.ff_hidden_dim;
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
activations.bf_pre_ffw_rms_out.Packed() + pos * model_dim,
activations.pre_ffw_rms_out.Packed() + pos * model_dim,
activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
size_t context_size,
const ModelWeightsPtrs<T>& weights,
ForwardPass<float>& forward,
const RowVectorBatch<float>& inv_timescale,
const MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
const ModelConfig& config = weights.weights_config;
const size_t vocab_size = config.vocab_size;

View File

@ -38,7 +38,7 @@ namespace HWY_NAMESPACE {
float CrossEntropyLossForwardPassT(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size,
weights, forward, inv_timescale, pool);
@ -56,7 +56,7 @@ HWY_EXPORT(CrossEntropyLossForwardPassT);
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool) {
return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)(
prompt, weights, forward, inv_timescale, pool);

View File

@ -19,7 +19,7 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
@ -27,7 +27,7 @@ namespace gcpp {
float CrossEntropyLossForwardPass(const Prompt& prompt,
const ModelWeightsPtrs<float>& weights,
ForwardPass<float>& forward,
RowVectorBatch<float>& inv_timescale,
MatStorageT<float>& inv_timescale,
hwy::ThreadPool& pool);
} // namespace gcpp

View File

@ -219,12 +219,11 @@ void ApplyLayer(const LayerWeightsPtrs<T>& weights,
RMSNormT(weights.pre_ffw_norm_scale.Packed(),
activations.attention_out.Packed(),
activations.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens);
activations.pre_ffw_rms_out.Packed(), model_dim, num_tokens);
MatMulT(weights.gating_einsum_w.Packed(),
activations.bf_pre_ffw_rms_out.Packed(),
activations.ffw_hidden.Packed(), ff_hidden_dim * 2, model_dim,
num_tokens);
activations.pre_ffw_rms_out.Packed(), activations.ffw_hidden.Packed(),
ff_hidden_dim * 2, model_dim, num_tokens);
GatedGelu(activations.ffw_hidden.Packed(),
activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens);

View File

@ -62,9 +62,9 @@ TEST(OptimizeTest, GradientDescent) {
grad_m.ZeroInit();
grad_v.ZeroInit();
ForwardPass<float> forward(config), backward(config);
KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
KVCache kv_cache(config, /*prefill_tbatch_size=*/16);
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
MatStorageT<float> inv_timescale = CreateInvTimescale(
allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
@ -147,7 +147,7 @@ TEST(OptimizeTest, GradientDescent) {
printf("Num steps: %zu\n", steps);
printf("Final weights:\n");
gemma.MutableWeights().LogWeightStatsF32();
EXPECT_LT(steps, 50);
EXPECT_LT(steps, 80);
EXPECT_EQ(num_ok, kBatchSize);
}

View File

@ -23,6 +23,7 @@
#include <stdio.h>
#include <algorithm> // std::shuffle
#include <array>
#include <random>
#include "compression/distortion.h"
@ -104,7 +105,7 @@ struct TestPlateaus {
HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f);
}
std::random_device rd;
std::random_device rd; // NOLINT
std::mt19937 rng(rd());
std::shuffle(in.get(), in.get() + kGroupSize, rng);
@ -151,7 +152,7 @@ struct TestRamp {
HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f);
}
std::random_device rd;
std::random_device rd; // NOLINT
std::mt19937 rng(rd());
std::shuffle(in.get(), in.get() + kGroupSize, rng);
@ -246,7 +247,8 @@ struct TestOffset {
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
HWY_ASSERT(in && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total);
@ -296,7 +298,8 @@ struct TestUnalignedOffset {
auto in = hwy::AllocateAligned<float>(total); // Enc() requires f32
auto dec1 = hwy::AllocateAligned<T>(total);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
auto dec2 = hwy::AllocateAligned<T>(num_decompressed);
HWY_ASSERT(in && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total);
@ -347,7 +350,8 @@ struct TestDec2 {
auto dec0 = hwy::AllocateAligned<T>(total);
auto dec1 = hwy::AllocateAligned<T>(total);
auto dec2 = hwy::AllocateAligned<T>(kMidLen);
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(total));
auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes()));
HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq);
const auto nuq_span = MakeSpan(nuq.get(), total);
@ -449,7 +453,8 @@ struct TestEncDec {
const size_t num = 4 * kGroupSize;
auto in = hwy::AllocateAligned<float>(num); // Enc() requires f32
auto out = hwy::AllocateAligned<T>(num); // already padded
auto nuq = hwy::AllocateAligned<NuqStream>(NuqStream::PackedEnd(num));
auto nuq = hwy::AllocateAligned<NuqStream>(
hwy::RoundUpTo(NuqStream::PackedEnd(num), hwy::VectorBytes()));
HWY_ASSERT(in && out && nuq);
const auto nuq_span = MakeSpan(nuq.get(), num);

View File

@ -164,11 +164,11 @@ constexpr bool IsNuqStream() {
// weights for a model, but can be used for other purposes, such as types for
// `WeightsPtrs`. When adding a new type that is supported, also
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64 };
// These are used in `ModelConfig.Specifier`, hence the strings will not
// change, though new ones may be added.
static constexpr const char* kTypeStrings[] = {
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"};
static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "c64"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {0,
@ -177,8 +177,7 @@ static constexpr size_t kTypeBits[] = {0,
8 * sizeof(SfpStream),
4 /* NuqStream, actually 4.5 */,
8 * sizeof(double),
8 * sizeof(std::complex<double>),
8 * sizeof(hwy::uint128_t)};
8 * sizeof(std::complex<double>)};
static inline bool EnumValid(Type type) {
return static_cast<size_t>(type) < kNumTypes;
@ -200,8 +199,6 @@ Type TypeEnum() {
return Type::kF64;
} else if constexpr (hwy::IsSame<Packed, std::complex<double>>()) {
return Type::kC64;
} else if constexpr (hwy::IsSame<Packed, hwy::uint128_t>()) {
return Type::kU128;
} else {
HWY_DASSERT(false);
return Type::kUnknown;

View File

@ -73,8 +73,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
KVCache kv_cache = KVCache::Create(env.GetGemma()->GetModelConfig(),
env.MutableConfig().prefill_tbatch_size);
KVCache kv_cache(env.GetGemma()->GetModelConfig(),
env.MutableConfig().prefill_tbatch_size);
float entropy = ComputeCrossEntropy(
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy;

View File

@ -52,9 +52,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader,
const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading_args)), gemma_(loader, env_) {
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1);
kv_caches_[0] =
KVCache::Create(gemma_.GetModelConfig(), inference.prefill_tbatch_size);
kv_caches_.push_back(
KVCache(gemma_.GetModelConfig(), inference.prefill_tbatch_size));
InitGenerator(inference, gen_);
@ -131,15 +130,10 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
runtime_config_.decode_qbatch_size);
}
// Ensure we have one KVCache per query.
if (kv_caches_.size() < num_queries) {
kv_caches_.resize(num_queries);
}
for (size_t i = 1; i < num_queries; ++i) {
if (kv_caches_[i].seq_len == 0) {
kv_caches_[i] = KVCache::Create(gemma_.GetModelConfig(),
runtime_config_.prefill_tbatch_size);
}
// Ensure we have at least one KVCache per query.
while (kv_caches_.size() < num_queries) {
kv_caches_.push_back(
KVCache(gemma_.GetModelConfig(), runtime_config_.prefill_tbatch_size));
}
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};

View File

@ -53,8 +53,7 @@ int main(int argc, char** argv) {
// Instantiate model and KV Cache
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::Gemma gemma(loader, env);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(gemma.GetModelConfig(),
inference.prefill_tbatch_size);
gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
size_t generated = 0;
// Initialize random number generator

View File

@ -39,11 +39,8 @@ class SimplifiedGemma {
threading_(threading),
inference_(inference),
env_(MakeMatMulEnv(threading_)),
gemma_(loader_, env_) {
// Instantiate model and KV Cache
kv_cache_ = gcpp::KVCache::Create(gemma_.GetModelConfig(),
inference_.prefill_tbatch_size);
gemma_(loader_, env_),
kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) {
// Initialize random number generator
std::random_device rd;
gen_.seed(rd());

View File

@ -23,106 +23,127 @@
#include "ops/ops.h" // CreateInvTimescale
#include "util/allocator.h" // Allocator
#include "util/basics.h" // BF16
#include "util/mat.h" // RowVectorBatch
#include "util/mat.h" // MatStorageT
namespace gcpp {
struct Activations {
explicit Activations(const ModelConfig& config)
Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env)
: weights_config(config),
layer_config(config.layer_configs[0]),
seq_len(config.seq_len),
cache_pos_size(config.CachePosSize()) {}
cache_pos_size(config.CachePosSize()),
is_griffin(layer_config.type ==
LayerAttentionType::kGriffinRecurrentBlock),
RowVectorBatch<float> x; // input
RowVectorBatch<float> q; // query, also KV if MHA.
RowVectorBatch<float> logits;
x("x", Extents2D(batch_size, config.model_dim), pad_),
q("q",
Extents2D(batch_size, layer_config.heads * layer_config.QStride()),
pad_),
logits("logits", Extents2D(batch_size, config.vocab_size), pad_),
// Attention
RowVectorBatch<float> pre_att_rms_out;
RowVectorBatch<float> att; // attention vector
RowVectorBatch<float> att_out; // attention output
// Accumulation of attention outputs over heads
RowVectorBatch<float> att_sums;
pre_att_rms_out("pre_att_rms_out",
Extents2D(batch_size, config.model_dim), pad_),
att("att", Extents2D(batch_size, layer_config.heads * config.seq_len),
pad_),
att_out(
"att_out",
Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim),
pad_),
att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad_),
// Gated FFW
RowVectorBatch<BF16> bf_pre_ffw_rms_out;
RowVectorBatch<float> C1;
RowVectorBatch<float> C2;
RowVectorBatch<float> ffw_out;
pre_ffw_rms_out("pre_ffw_rms_out",
Extents2D(batch_size, config.model_dim), pad_),
C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_),
ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_),
// Griffin
RowVectorBatch<float> griffin_x;
RowVectorBatch<float> griffin_y;
RowVectorBatch<float> griffin_gate_x;
RowVectorBatch<float> griffin_multiplier;
// No padding for Griffin because it does not always use Row().
griffin_x("griffin_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
griffin_y("griffin_y",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
griffin_gate_x(
"griffin_gate_x",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
griffin_multiplier(
"griffin_mul",
is_griffin ? Extents2D(batch_size, config.model_dim) : none_,
MatPadding::kPacked),
// Rope
RowVectorBatch<float> inv_timescale;
RowVectorBatch<float> inv_timescale_global;
inv_timescale(
CreateInvTimescale(env->ctx.allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope)),
inv_timescale_global(CreateInvTimescale(
env->ctx.allocator, layer_config.qkv_dim,
layer_config.post_qk == PostQKType::HalfRope, 1000000.0)),
// Dynamic because no default ctor and only initialized in `Allocate`.
MatMulEnv* env;
env(env) {
HWY_ASSERT(batch_size != 0);
}
void SetBatchSize(size_t batch_size) {
x.OverrideRows(batch_size);
q.OverrideRows(batch_size);
logits.OverrideRows(batch_size);
pre_att_rms_out.OverrideRows(batch_size);
att.OverrideRows(batch_size);
att_out.OverrideRows(batch_size);
att_sums.OverrideRows(batch_size);
pre_ffw_rms_out.OverrideRows(batch_size);
C1.OverrideRows(batch_size);
C2.OverrideRows(batch_size);
ffw_out.OverrideRows(batch_size);
if (is_griffin) {
griffin_x.OverrideRows(batch_size);
griffin_y.OverrideRows(batch_size);
griffin_gate_x.OverrideRows(batch_size);
griffin_multiplier.OverrideRows(batch_size);
}
}
PostQKType post_qk = PostQKType::Rope;
// And the config.
const ModelConfig& weights_config;
const LayerConfig& layer_config;
size_t seq_len;
size_t cache_pos_size = 0;
size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT.
bool is_griffin = false;
const Extents2D none_ = Extents2D();
const MatPadding pad_ = MatPadding::kOdd;
void Allocate(size_t batch_size, MatMulEnv* env) {
const Allocator& allocator = env->ctx.allocator;
MatStorageT<float> x; // input
MatStorageT<float> q; // query, also KV if MHA.
MatStorageT<float> logits;
post_qk = layer_config.post_qk;
const size_t model_dim = weights_config.model_dim;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
const size_t vocab_size = weights_config.vocab_size;
const size_t qkv_dim = layer_config.qkv_dim;
const size_t heads = layer_config.heads;
// Attention
MatStorageT<float> pre_att_rms_out;
MatStorageT<float> att; // attention vector
MatStorageT<float> att_out; // attention output
// Accumulation of attention outputs over heads
MatStorageT<float> att_sums;
x = RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
q = RowVectorBatch<float>(
allocator, Extents2D(batch_size, heads * layer_config.QStride()));
if (vocab_size > 0) {
logits =
RowVectorBatch<float>(allocator, Extents2D(batch_size, vocab_size));
}
// Gated FFW
MatStorageT<BF16> pre_ffw_rms_out;
MatStorageT<float> C1;
MatStorageT<float> C2;
MatStorageT<float> ffw_out;
pre_att_rms_out =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
att = RowVectorBatch<float>(
allocator, Extents2D(batch_size, heads * weights_config.seq_len));
att_out = RowVectorBatch<float>(allocator,
Extents2D(batch_size, heads * qkv_dim));
att_sums =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
// Griffin
MatStorageT<float> griffin_x;
MatStorageT<float> griffin_y;
MatStorageT<float> griffin_gate_x;
MatStorageT<float> griffin_multiplier;
bf_pre_ffw_rms_out =
RowVectorBatch<BF16>(allocator, Extents2D(batch_size, model_dim));
C1 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
C2 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
ffw_out =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
// Rope
MatStorageT<float> inv_timescale;
MatStorageT<float> inv_timescale_global;
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
griffin_x =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_y =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_gate_x =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_multiplier =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
}
inv_timescale = CreateInvTimescale(allocator, layer_config.qkv_dim,
post_qk == PostQKType::HalfRope);
inv_timescale_global = CreateInvTimescale(
allocator, qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
this->env = env;
}
MatMulEnv* env;
};
} // namespace gcpp
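
A rough usage sketch of the new constructor-based allocation; config, env, and the token counts are placeholders, and the loop mirrors the Prefill() hunk further below:

  // Allocate once for the largest token batch, then shrink per batch.
  Activations activations(config, /*batch_size=*/max_tbatch_size, env);
  for (size_t tbatch_start = 0; tbatch_start < num_tokens;
       tbatch_start += max_tbatch_size) {
    const size_t tbatch_size =
        HWY_MIN(max_tbatch_size, num_tokens - tbatch_start);
    // Calls OverrideRows() on each activation matrix; no reallocation expected.
    activations.SetBatchSize(tbatch_size);
    // ... EmbedToken / TransformerLayer consume activations.x etc. ...
  }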

View File

@ -46,8 +46,7 @@ ConversationData::ConversationData(const ModelConfig& model_config,
size_t prefill_tbatch_size)
: model_config_ref_(model_config),
prefill_tbatch_size_(prefill_tbatch_size),
kv_cache(std::make_unique<KVCache>(
KVCache::Create(model_config, prefill_tbatch_size))),
kv_cache(std::make_unique<KVCache>(model_config, prefill_tbatch_size)),
abs_pos(0) {}
// ConversationData copy constructor implementation
@ -184,25 +183,28 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
inference_args.CopyTo(runtime_config);
size_t prefix_end = 0;
const ModelConfig& model_config = model.GetModelConfig();
// generate
std::vector<int> prompt;
ImageTokens image_tokens;
const size_t pool_dim = model_config.vit_config.pool_dim;
ImageTokens image_tokens(
"image_tokens",
image_data
? Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim),
model_config.model_dim)
: Extents2D(0, 0),
MatPadding::kOdd);
if (image_data != nullptr) {
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
image_tokens =
ImageTokens(model.Env().ctx.allocator,
Extents2D(model.GetModelConfig().vit_config.seq_len /
(pool_dim * pool_dim),
model.GetModelConfig().model_dim));
HWY_ASSERT(model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA ||
model.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM);
HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA ||
model_config.wrapping == PromptWrapping::GEMMA_VLM);
Image image;
image.Set(image_width, image_height, static_cast<const float*>(image_data));
// We may need to resize the supplied image depending on whether we're using
// PaliGemma or Gemma 3.
const size_t image_size = model.GetModelConfig().vit_config.image_size;
const size_t image_size = model_config.vit_config.image_size;
image.Resize(image_size, image_size);
// Use the existing runtime_config defined earlier in the function.
@ -217,10 +219,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
ss << static_cast<int>(image_tokens_duration * 1000) << " ms\n",
LogDebug(ss.str().c_str());
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
model.GetModelConfig().wrapping,
active_conversation->abs_pos, prompt_string,
image_tokens.BatchSize());
prompt = WrapAndTokenize(
model.Tokenizer(), model.ChatTemplate(), model_config.wrapping,
active_conversation->abs_pos, prompt_string, image_tokens.Rows());
runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
@ -230,7 +231,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// Text-only case (original logic)
// Use abs_pos from the active conversation
prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(),
model.GetModelConfig().wrapping,
model_config.wrapping,
active_conversation->abs_pos, prompt_string);
prompt_size = prompt.size();
}
@ -251,7 +252,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string,
// prepare for next turn
if (!inference_args.multiturn ||
model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) {
model_config.wrapping == PromptWrapping::PALIGEMMA) {
// If not multiturn, or Paligemma (which handles turns differently),
// reset the *active* conversation's position.
active_conversation->abs_pos = 0;

View File

@ -188,8 +188,8 @@ class GemmaContext {
// rewind to initial state.
active_conversation->abs_pos = 0;
// Replace the cache within the current ConversationData object
active_conversation->kv_cache = std::make_unique<KVCache>(KVCache::Create(
model.GetModelConfig(), inference_args.prefill_tbatch_size));
active_conversation->kv_cache = std::make_unique<KVCache>(
model.GetModelConfig(), inference_args.prefill_tbatch_size);
LogDebug((log_prefix + "Successfully rewound to initial state.").c_str());
} else {

View File

@ -89,11 +89,11 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
// X / Y linear layers.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
activations.pre_att_rms_out.Batch(batch_idx),
activations.pre_att_rms_out.Row(batch_idx),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
@ -103,17 +103,16 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
// Conv1D.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
HWY_FULL(float) df;
HWY_DASSERT(model_dim % hn::Lanes(df) == 0);
const size_t layer_offset = layer * model_dim * (conv_1d_width - 1);
// cache[i] = input at time t-i.
float* HWY_RESTRICT cache[kMaxConv1DWidth];
cache[0] = x;
for (size_t i = 1; i < conv_1d_width; i++) {
cache[i] =
kv_cache.conv1d_cache.get() + layer_offset +
kv_cache.conv1d_cache.Row(layer) +
((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim;
}
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
@ -140,12 +139,11 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
// RGLRU
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Batch(batch_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Batch(batch_idx);
float* HWY_RESTRICT rnn_state =
kv_cache.rglru_cache.get() + layer * model_dim;
float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx);
float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx);
float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer);
pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
const size_t kHeadDim = model_dim / heads;
@ -193,8 +191,8 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
// Final linear layer.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
float* out_ptr = activations.att_sums.Batch(batch_idx);
float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx);
float* out_ptr = activations.att_sums.Row(batch_idx);
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
pool);
@ -217,7 +215,7 @@ class GemmaAttention {
const float mul) {
// qk is either q or k, so qkv_dim is the length we operate on.
const size_t qkv_dim = layer_config_.qkv_dim;
const float* inv_timescale = activations_.inv_timescale.Const();
const float* inv_timescale = activations_.inv_timescale.Packed();
bool is_global_layer =
activations_.weights_config.attention_window_sizes[layer] ==
activations_.seq_len;
@ -227,7 +225,7 @@ class GemmaAttention {
activations_.weights_config.model == Model::GEMMA3_12B ||
activations_.weights_config.model == Model::GEMMA3_27B ||
activations_.weights_config.model == Model::GEMMA3_1B)) {
inv_timescale = activations_.inv_timescale_global.Const();
inv_timescale = activations_.inv_timescale_global.Packed();
}
// PostQKType::Rope
(void)layer;
@ -249,11 +247,10 @@ class GemmaAttention {
const size_t heads = layer_config_.heads;
const size_t kv_heads = layer_config_.kv_heads;
const auto pre_att_rms_out =
ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
auto w_q1 = layer_weights_.qkv_einsum_w.HasPtr()
? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
: ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
using WeightT = typename decltype(layer_weights_.qkv_einsum_w)::T;
ConstMat<WeightT> w_q1(layer_weights_.qkv_einsum_w.HasPtr()
? layer_weights_.qkv_einsum_w
: layer_weights_.qkv_einsum_w1);
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
// model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows.
// We must shrink to the actual size because MatMul verifies
@ -262,20 +259,19 @@ class GemmaAttention {
// computed in the second MatMul.
const size_t w1_rows = heads * layer_config_.QStride();
w_q1.ShrinkRows(w1_rows);
MatMul(pre_att_rms_out, w_q1,
MatMul(activations_.pre_att_rms_out, w_q1,
/*add=*/nullptr, *activations_.env,
RowPtrFromBatch(allocator_, activations_.q));
RowPtrFromMat(allocator_, activations_.q));
if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
} else {
decltype(w_q1) w_q2;
decltype(w_q1) w_q2(layer_weights_.qkv_einsum_w.HasPtr()
? layer_weights_.qkv_einsum_w
: layer_weights_.qkv_einsum_w2);
if (layer_weights_.qkv_einsum_w.HasPtr()) {
w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w);
// Skip first half of the matrix.
w_q2.ofs = w_q2.Row(w1_rows);
} else {
w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
}
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
@ -290,13 +286,13 @@ class GemmaAttention {
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_);
MatMul(pre_att_rms_out, w_q2,
MatMul(activations_.pre_att_rms_out, w_q2,
/*add=*/nullptr, *activations_.env, kv_rows);
} else {
// Proceed row by row because there will be wraparound.
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx);
const float* x = activations_.pre_att_rms_out.Row(interleaved_idx);
const size_t query_idx = interleaved_idx % num_queries_;
const size_t batch_idx = interleaved_idx / num_queries_;
KVCache& kv_cache = kv_caches_[query_idx];
@ -327,15 +323,15 @@ class GemmaAttention {
// If MHA, copy computed K and V into KVCache.
if (is_mha_) {
const float* HWY_RESTRICT mha_kv =
activations_.q.Batch(interleaved_idx) + head * q_stride_ +
activations_.q.Row(interleaved_idx) + head * q_stride_ +
qkv_dim;
hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv));
}
// Apply further processing to K.
if (layer_weights_.key_norm_scale.HasPtr()) {
RMSNormInplace(layer_weights_.key_norm_scale.Row(0), kv,
qkv_dim);
RMSNormInplace(layer_weights_.key_norm_scale.PackedScale1(),
0, kv, qkv_dim);
}
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
});
@ -402,7 +398,8 @@ class GemmaAttention {
// Apply rope and scaling to Q.
if (layer_weights_.query_norm_scale.HasPtr()) {
RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim);
RMSNormInplace(layer_weights_.query_norm_scale.PackedScale1(), 0, q,
qkv_dim);
}
PositionalEncodingQK(q, pos, layer_, query_scale);
@ -435,13 +432,12 @@ class GemmaAttention {
const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2;
float* HWY_RESTRICT q =
activations_.q.Batch(interleaved_idx) + head * q_stride_;
activations_.q.Row(interleaved_idx) + head * q_stride_;
float* HWY_RESTRICT att =
activations_.att.Batch(interleaved_idx) +
activations_.att.Row(interleaved_idx) +
head * activations_.seq_len;
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(interleaved_idx) +
head * qkv_dim;
activations_.att_out.Row(interleaved_idx) + head * qkv_dim;
// Make strided views into the kv cache entries for the current
// query and head.
@ -476,28 +472,25 @@ class GemmaAttention {
private:
// Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and
// head_dim (`qkv_dim`) into output (`layer_out`).
HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.Attention.SumHeads");
// att_weights and att_out are concatenated heads, each of length
// layer_config_.qkv_dim. Thus the [num_interleaved,
// layer_config_.model_dim] matmul output is the sum over heads. Compare
// gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD',
// encoded)
HWY_DASSERT(layer_config_.model_dim > 0);
HWY_DASSERT(layer_config_.heads > 0);
HWY_DASSERT(layer_config_.qkv_dim > 0);
HWY_DASSERT(layer_config_.model_dim != 0 && layer_config_.heads != 0 &&
layer_config_.qkv_dim != 0);
HWY_DASSERT(layer_weights_.att_weights.HasPtr());
HWY_DASSERT(activations_.att_out.All() != nullptr);
HWY_DASSERT(activations_.att_sums.All() != nullptr);
HWY_DASSERT(activations_.att_out.HasPtr());
HWY_DASSERT(activations_.att_sums.HasPtr());
const float* add =
layer_weights_.layer_config.softmax_attn_output_biases
? layer_weights_.attention_output_biases.PackedScale1()
: nullptr;
MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
ConstMatFromWeights(layer_weights_.att_weights), add,
*activations_.env,
RowPtrFromBatch(allocator_, activations_.att_sums));
MatMul(activations_.att_out, layer_weights_.att_weights, add,
*activations_.env, RowPtrFromMat(allocator_, activations_.att_sums));
}
public:
@ -524,7 +517,7 @@ class GemmaAttention {
const size_t num_interleaved = num_tokens_ * num_queries_;
ComputeQKV(num_interleaved);
DotSoftmaxWeightedSum(num_interleaved);
SumHeads(num_interleaved);
SumHeads();
}
private:
@ -618,12 +611,11 @@ class VitAttention {
HWY_NOINLINE void ComputeQKV() {
PROFILER_ZONE("Gen.VitAttention.QKV");
auto& qkv = activations_.q;
HWY_ASSERT(qkv.BatchSize() == num_tokens_);
HWY_ASSERT(qkv.Rows() == num_tokens_);
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w,
layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
RowPtrFromBatch(allocator_, qkv));
RowPtrFromMat(allocator_, qkv));
}
// TODO(philculliton): transition fully to MatMul.
@ -635,52 +627,49 @@ class VitAttention {
const float query_scale = 1.0f / sqrtf(static_cast<float>(qkv_dim));
PROFILER_ZONE("Gen.VitAttention.DotSoftmax");
// Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents)
RowVectorBatch<float> Q =
AllocateAlignedRows<float>(allocator_, Extents2D(num_tokens_, qkv_dim));
RowVectorBatch<float> K =
AllocateAlignedRows<float>(allocator_, Extents2D(seq_len, qkv_dim));
RowVectorBatch<float> C(allocator_, Extents2D(num_tokens_, seq_len));
// Shift Q, K, VT to MatStorageT.
MatStorageT<float> Q("Q2", Extents2D(num_tokens_, qkv_dim),
MatPadding::kPacked);
MatStorageT<float> K("K2", Extents2D(seq_len, qkv_dim),
MatPadding::kPacked);
MatStorageT<float> C("C2", Extents2D(num_tokens_, seq_len),
MatPadding::kPacked);
// Initialize att_out to zero prior to head loop.
hwy::ZeroBytes(activations_.att_out.All(),
num_tokens_ * heads * qkv_dim * sizeof(float));
ZeroInit(activations_.att_out);
for (size_t head = 0; head < heads; ++head) {
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t token = task;
float* HWY_RESTRICT q =
activations_.q.Batch(token) + head * 3 * qkv_dim;
float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim;
// TODO: shift to MatMul with A.scale once MatMul is confirmed working
MulByConst(query_scale, q, qkv_dim);
hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float));
hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float));
});
pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
const size_t seq_idx = task;
float* HWY_RESTRICT k =
activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float));
activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim;
hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float));
});
// this produces C, a (num_tokens_, seq_len) matrix of dot products
MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
RowPtrFromBatch(allocator_, C));
MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(allocator_, C));
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Batch(task);
float* HWY_RESTRICT c = C.Row(task);
Softmax(c, C.Cols());
});
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
size_t token = task;
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * qkv_dim;
activations_.att_out.Row(token) + head * qkv_dim;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v =
activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim);
activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim);
}
});
}
@ -701,24 +690,24 @@ class VitAttention {
const size_t token = task / layer_config_.heads;
// Compute Q.K scores, which are "logits" stored in head_att.
float* HWY_RESTRICT q =
activations_.q.Batch(token) + head * 3 * qkv_dim;
activations_.q.Row(token) + head * 3 * qkv_dim;
MulByConst(query_scale, q, qkv_dim);
float* HWY_RESTRICT head_att =
activations_.att.Batch(token) + head * activations_.seq_len;
activations_.att.Row(token) + head * activations_.seq_len;
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT k =
activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim;
activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim;
head_att[i] = Dot(q, k, qkv_dim); // score = q.k
}
// SoftMax yields "probabilities" in head_att.
Softmax(head_att, seq_len);
// Compute weighted sum of v into att_out.
float* HWY_RESTRICT att_out =
activations_.att_out.Batch(token) + head * qkv_dim;
activations_.att_out.Row(token) + head * qkv_dim;
hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out));
for (size_t i = 0; i < seq_len; ++i) {
float* HWY_RESTRICT v = activations_.q.Batch(i) +
head * 3 * qkv_dim + 2 * qkv_dim;
float* HWY_RESTRICT v =
activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim;
MulByConstAndAdd(head_att[i], v, att_out, qkv_dim);
}
});
@ -732,10 +721,9 @@ class VitAttention {
// att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads.
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
auto att_sums = RowPtrFromBatch(allocator_, activations_.att_sums);
MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
auto att_sums = RowPtrFromMat(allocator_, activations_.att_sums);
MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias,
*activations_.env, att_sums);
}
public:
@ -771,7 +759,7 @@ class VitAttention {
template <typename T>
HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
T* HWY_RESTRICT c2, size_t count) {
const T* HWY_RESTRICT c2, size_t count) {
PROFILER_ZONE("Gen.Activation");
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<T>;
@ -787,12 +775,38 @@ HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1,
});
}
// No C2 multiplier.
template <class Mat>
void ActivationBatched(ActivationType activation, Mat& c1) {
using T = typename Mat::T;
for (size_t i = 0; i < c1.Rows(); ++i) {
// Cast to correct type so type deduction works.
Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
c1.Cols());
}
}
template <class Mat>
void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) {
using T = typename Mat::T;
HWY_DASSERT(c1.SameShape(*c2));
if (c2 && c2->HasPtr()) {
for (size_t i = 0; i < c1.Rows(); ++i) {
Activation(activation, c1.Row(i), c2->Row(i), c1.Cols());
}
} else { // No multiplier
for (size_t i = 0; i < c1.Rows(); ++i) {
Activation(activation, c1.Row(i), static_cast<const T*>(nullptr),
c1.Cols());
}
}
}
template <typename T>
HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
HWY_NOINLINE void FFWNoVit(Activations& activations,
const LayerWeightsPtrs<T>* layer_weights) {
PROFILER_ZONE("Gen.FFW");
const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
const bool add_bias = layer_weights->layer_config.ff_biases;
const float* bias1 =
@ -802,56 +816,48 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
// Define slightly more readable names for the weights and activations.
const auto x =
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
const Allocator& allocator = activations.env->ctx.allocator;
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
auto multiplier = RowPtrFromBatch(allocator, activations.C2);
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
auto multiplier = RowPtrFromMat(allocator, activations.C2);
auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
using WeightT = typename decltype(layer_weights->gating_einsum_w)::T;
// gating_einsum_w holds two half-matrices. We plan to change the importer to
// avoid this confusion by splitting into gating_einsum_w1 and
// gating_einsum_w2.
// gating_einsum_w2. TODO: move into Reshape().
const bool split = layer_weights->gating_einsum_w.HasPtr();
auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
: ConstMatFromWeights(layer_weights->gating_einsum_w1);
decltype(w1) w2;
ConstMat<WeightT> w1(split ? layer_weights->gating_einsum_w
: layer_weights->gating_einsum_w1);
ConstMat<WeightT> w2(split ? layer_weights->gating_einsum_w
: layer_weights->gating_einsum_w2);
if (split) {
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w);
w2.ofs = w2.Row(ffh_hidden_dim);
// Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
w1.ShrinkRows(ffh_hidden_dim);
w2.ShrinkRows(ffh_hidden_dim);
} else {
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w2);
}
auto w_output = ConstMatFromWeights(layer_weights->linear_w);
// Compute the hidden layer activations.
MatMul(x, w1, bias1, *activations.env, hidden_activations);
MatMul(x, w2, bias2, *activations.env, multiplier);
MatMul(activations.pre_ffw_rms_out, w1, bias1, *activations.env,
hidden_activations);
MatMul(activations.pre_ffw_rms_out, w2, bias2, *activations.env, multiplier);
// Activation (Gelu) and maybe multiply by gate. Store activations in act.
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
multiplier.Row(0), ffh_hidden_dim * num_interleaved);
ActivationBatched(layer_weights->layer_config.activation, activations.C1,
&activations.C2);
// Hidden layer -> output layer.
auto activations_mat = MakeConstMat(
hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim),
hidden_activations.Stride());
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env,
ffw_out);
}
// Same as FFWNoVit, but with different layer_weights members and no second
// gating matrix.
template <typename T>
HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
HWY_NOINLINE void FFWVit(Activations& activations,
const LayerWeightsPtrs<T>* layer_weights) {
PROFILER_ZONE("Gen.FFW");
const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim;
HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize());
PROFILER_ZONE("Gen.FFW.ViT");
const bool add_bias = layer_weights->layer_config.ff_biases;
const float* bias1 =
@ -860,30 +866,21 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
// Define slightly more readable names for the weights and activations.
const auto x =
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
const Allocator& allocator = activations.env->ctx.allocator;
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
auto hidden_activations = RowPtrFromMat(allocator, activations.C1);
auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out);
// Compute the hidden layer activations.
MatMul(x, w1, bias1, *activations.env, hidden_activations);
MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1,
*activations.env, hidden_activations);
// Activation (Gelu), store in act.
RowPtrF multiplier = RowPtrF(allocator, nullptr, 0);
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
multiplier.Row(0), ff_hidden_dim * num_interleaved);
ActivationBatched(layer_weights->layer_config.activation, activations.C1);
// Hidden layer -> output layer.
auto activations_mat = MakeConstMat(hidden_activations.Row(0),
Extents2D(num_interleaved, ff_hidden_dim),
hidden_activations.Stride());
MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias,
*activations.env, ffw_out);
}
// `batch_idx` indicates which row of `x` to write to.
@ -898,23 +895,23 @@ template <typename T>
HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights,
RowVectorBatch<float>& x,
MatStorageT<float>& x,
const ImageTokens* image_tokens,
size_t& image_token_position) {
// Image tokens just need to be copied.
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM &&
image_tokens != nullptr && token == -2 &&
image_token_position < image_tokens->BatchSize()) {
hwy::CopyBytes(image_tokens->Batch(image_token_position),
x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0]));
image_token_position < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx),
x.Cols() * x.ElementBytes());
image_token_position++;
return;
}
if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA &&
image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) {
hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx),
x.Cols() * sizeof(x.Const()[0]));
image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) {
hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx),
x.Cols() * x.ElementBytes());
return;
}
@ -934,12 +931,12 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim);
const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0),
embedding_ofs + model_dim);
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Batch(batch_idx),
DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx),
model_dim);
MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
x.Batch(batch_idx), model_dim);
x.Row(batch_idx), model_dim);
if (weights.weights_config.absolute_pe) {
AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos);
AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos);
}
}
@ -951,29 +948,28 @@ template <typename T>
HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
size_t pos_in_prompt,
const ModelWeightsPtrs<T>& weights,
RowVectorBatch<float>& x,
MatStorageT<float>& x,
const ImageTokens* image_tokens) {
size_t image_token_position = 0;
EmbedMMToken<T>(token, batch_idx, pos, pos_in_prompt, weights, x,
image_tokens, image_token_position);
}
template <typename Weights, typename T>
HWY_NOINLINE void ResidualConnection(
size_t num_interleaved, const T* HWY_RESTRICT other, T* HWY_RESTRICT x,
const LayerWeightsPtrs<Weights>* layer_weights, bool is_attention) {
template <typename T2, class LayerWeights>
HWY_NOINLINE void ResidualConnection(const MatPtrT<T2>& other,
MatPtrT<float>& HWY_RESTRICT x,
const LayerWeights* layer_weights,
bool is_attention) {
// ResidualType::Add
AddFromBatched(num_interleaved, other, x,
layer_weights->layer_config.model_dim);
AddFromBatched(other, x);
}
template <typename WeightT, typename InOutT>
void PostNorm(PostNormType post_norm, size_t num_interleaved,
const WeightT& weights, InOutT* inout) {
void PostNorm(PostNormType post_norm, const MatPtrT<WeightT>& weights,
MatPtrT<InOutT>& inout) {
HWY_DASSERT(weights.Rows() == 1);
if (post_norm == PostNormType::Scale) {
RMSNormInplaceBatched(num_interleaved, weights.PackedScale1(), inout,
weights.Cols());
RMSNormInplaceBatched(weights, inout);
}
}
@ -985,39 +981,33 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
Activations& activations,
const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches) {
const size_t model_dim = activations.weights_config.model_dim;
const size_t num_interleaved = num_tokens * queries_pos.size();
auto type = layer_weights->layer_config.type;
RMSNormBatched(num_interleaved, activations.x.All(),
layer_weights->pre_attention_norm_scale.PackedScale1(),
activations.pre_att_rms_out.All(), model_dim);
RMSNormBatched(activations.x, layer_weights->pre_attention_norm_scale,
activations.pre_att_rms_out);
Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx,
activations, layer_weights, div_seq_len, kv_caches);
PostNorm(layer_weights->layer_config.post_norm, num_interleaved,
layer_weights->post_attention_norm_scale,
activations.att_sums.All());
PostNorm(layer_weights->layer_config.post_norm,
layer_weights->post_attention_norm_scale, activations.att_sums);
ResidualConnection(num_interleaved, activations.att_sums.All(),
activations.x.All(), layer_weights, /*is_attention=*/true);
ResidualConnection(activations.att_sums, activations.x, layer_weights,
/*is_attention=*/true);
RMSNormBatched(num_interleaved, activations.x.All(),
layer_weights->pre_ffw_norm_scale.PackedScale1(),
activations.bf_pre_ffw_rms_out.All(), model_dim);
RMSNormBatched(activations.x, layer_weights->pre_ffw_norm_scale,
activations.pre_ffw_rms_out);
if (layer_weights->layer_config.type == LayerAttentionType::kVit) {
FFWVit(activations, num_interleaved, layer_weights);
FFWVit(activations, layer_weights);
} else {
FFWNoVit(activations, num_interleaved, layer_weights);
FFWNoVit(activations, layer_weights);
}
PostNorm(layer_weights->layer_config.post_norm, num_interleaved,
layer_weights->post_ffw_norm_scale, activations.ffw_out.All());
PostNorm(layer_weights->layer_config.post_norm,
layer_weights->post_ffw_norm_scale, activations.ffw_out);
ResidualConnection(num_interleaved, activations.ffw_out.All(),
activations.x.All(), layer_weights,
ResidualConnection(activations.ffw_out, activations.x, layer_weights,
/*is_attention=*/false);
}
@ -1034,38 +1024,37 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
auto type = layer_weights->layer_config.type;
HWY_DASSERT(type == LayerAttentionType::kVit);
(void)type;
(void)model_dim;
auto& x = activations.x;
HWY_DASSERT(x.BatchSize() == num_tokens);
HWY_DASSERT(x.Rows() == num_tokens);
HWY_DASSERT(x.Cols() == model_dim);
// y = nn.LayerNorm()(x)
// y ~ pre_att_rms_out
LayerNormBatched(num_tokens, x.All(),
layer_weights->vit.layer_norm_0_scale.PackedScale1(),
layer_weights->vit.layer_norm_0_bias.PackedScale1(),
activations.pre_att_rms_out.All(), model_dim);
LayerNormBatched(x, layer_weights->vit.layer_norm_0_scale,
layer_weights->vit.layer_norm_0_bias,
activations.pre_att_rms_out);
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
// y ~ att_sums
VitAttention<T>(num_tokens, layer, activations, layer_weights)();
// x = out["+sa"] = x + y
AddFromBatched(num_tokens, activations.att_sums.All(), x.All(), model_dim);
AddFromBatched(activations.att_sums, x);
// y = nn.LayerNorm()(x)
// y ~ bf_pre_ffw_rms_out
LayerNormBatched(num_tokens, x.All(),
layer_weights->vit.layer_norm_1_scale.PackedScale1(),
layer_weights->vit.layer_norm_1_bias.PackedScale1(),
activations.bf_pre_ffw_rms_out.All(), model_dim);
// y ~ pre_ffw_rms_out
LayerNormBatched(x, layer_weights->vit.layer_norm_1_scale,
layer_weights->vit.layer_norm_1_bias,
activations.pre_ffw_rms_out);
// y = out["mlp"] = MlpBlock(...)(y)
// y ~ ffw_out
FFWVit(activations, num_tokens, layer_weights);
FFWVit(activations, layer_weights);
// x = out["+mlp"] = x + y
AddFromBatched(num_tokens, activations.ffw_out.All(), x.All(), model_dim);
AddFromBatched(activations.ffw_out, x);
}
// Prefill() and Transformer() increment positions in-place.
@ -1094,7 +1083,7 @@ HWY_NOINLINE void Prefill(
// intensity, and so we are eventually compute-limited. We could devote some
// threads to parallelizing over queries, but for simplicity we assign them
// all to MatMul.
const size_t max_tbatch_size = activations.x.BatchSize();
const size_t max_tbatch_size = activations.x.Rows();
// For each query. `qi` is within the batch, not the global query index.
for (size_t qi = 0; qi < num_queries; ++qi) {
@ -1131,6 +1120,7 @@ HWY_NOINLINE void Prefill(
tbatch_start += max_tbatch_size) {
const size_t tbatch_size =
HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start);
activations.SetBatchSize(tbatch_size);
// Fill activations.x (much faster than TransformerLayer).
size_t image_token_position = 0;
@ -1201,13 +1191,14 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3)
// image_patches is (256, 14 * 14 * 3)
// This could be done as one MatMul like:
// RowVectorBatch<float> image_patches(kSeqLen, kPatchSize);
// MatStorageT<float> image_patches("patches", Extents2D(kSeqLen,
// kPatchSize), MatPadding::kPacked);
// [Get patches]
// MatMul(
// MatFromBatch(kVitSeqLen, image_patches),
// MatFromWeights(weights.vit_img_embedding_kernel),
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
// RowPtrF(activations.x.All(), kVitModelDim));
// RowPtrF(activations.x.Row(0), kVitModelDim));
// However, MatMul currently requires that
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
// which is not the case here. We should relax that requirement on MatMul and
@ -1216,11 +1207,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
image_patches[i].get(),
weights.vit_img_embedding_bias.PackedScale1(),
activations.x.Batch(i), activations.env->ctx.pools.Pool(0));
activations.x.Row(i), activations.env->ctx.pools.Pool(0));
}
// Add position embeddings.
AddFrom(weights.vit_img_pos_embedding.PackedScale1(), activations.x.All(),
seq_len * model_dim);
AddFromBatched(weights.vit_img_pos_embedding, activations.x);
}
// Prefills the image tokens with the ViT encoder.
@ -1232,7 +1222,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
PROFILER_ZONE("Gen.PrefillVit");
const size_t num_tokens = weights.weights_config.vit_config.seq_len;
const size_t vit_model_dim = weights.weights_config.vit_config.model_dim;
HWY_ASSERT(num_tokens == activations.x.BatchSize());
HWY_ASSERT(num_tokens == activations.x.Rows());
// Embed the image patches.
EmbedImagePatches(image, weights, activations);
// Go through all layers.
@ -1243,24 +1233,21 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
VitTransformerLayer(num_tokens, layer, layer_weights, activations);
}
// Final Layernorm.
LayerNormBatched(num_tokens, activations.x.All(),
weights.vit_encoder_norm_scale.PackedScale1(),
weights.vit_encoder_norm_bias.PackedScale1(),
activations.x.All(), vit_model_dim);
LayerNormBatched(activations.x, weights.vit_encoder_norm_scale,
weights.vit_encoder_norm_bias, activations.x);
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
activations.x = AvgPool4x4(activations.x);
// Apply soft embedding norm before input projection.
RMSNormInplace(weights.mm_embed_norm.PackedScale1(), activations.x.All(),
vit_model_dim);
RMSNormInplace(weights.mm_embed_norm.PackedScale1(), 0,
activations.x.Row(0), vit_model_dim);
}
// Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
ConstMatFromWeights(weights.vit_img_head_kernel),
MatMul(activations.x, weights.vit_img_head_kernel,
weights.vit_img_head_bias.PackedScale1(), *activations.env,
RowPtrFromBatch(activations.env->ctx.allocator, image_tokens));
RowPtrFromMat(activations.env->ctx.allocator, image_tokens));
}
// Generates one token for each query. `queries_token` is the previous token
@ -1272,7 +1259,6 @@ HWY_NOINLINE void Transformer(
Activations& activations, const hwy::Divisor& div_seq_len,
const KVCaches& kv_caches, const LayersOutputFunc& layers_output,
const ActivationsObserverFunc& activations_observer) {
const size_t model_dim = weights.weights_config.model_dim;
const size_t num_queries = queries_token.size();
HWY_DASSERT(queries_pos.size() == num_queries);
HWY_DASSERT(queries_prefix_end.size() == num_queries);
@ -1302,8 +1288,7 @@ HWY_NOINLINE void Transformer(
}
}
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.PackedScale1(),
activations.x.All(), model_dim);
RMSNormInplaceBatched(weights.final_norm_scale, activations.x);
if (activations_observer) {
activations_observer(queries_pos, -1, activations);
@ -1395,18 +1380,18 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
runtime_config.activations_observer);
// queries_pos are incremented by Transformer.
HWY_DASSERT(num_queries == activations.x.Rows());
bool all_queries_eos = true;
{
PROFILER_ZONE("Gen.EmbeddingMatmul");
// Compute logits from last layer activations.
MatMul(ConstMatFromBatch(num_queries, activations.x),
ConstMatFromWeights(weights.embedder_input_embedding),
MatMul(activations.x, weights.embedder_input_embedding,
/*add=*/nullptr, *activations.env,
RowPtrFromBatch(activations.env->ctx.allocator, activations.logits));
RowPtrFromMat(activations.env->ctx.allocator, activations.logits));
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
float* HWY_RESTRICT logits = activations.logits.Row(query_idx);
MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
const TokenAndProb tp = sample_token(logits, vocab_size);
timing_info.NotifyGenerated();
@ -1460,7 +1445,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096.
HWY_ASSERT(num_queries <= activations.x.BatchSize());
HWY_ASSERT(num_queries <= activations.x.Rows());
HWY_ASSERT(queries_pos_in.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries);
const hwy::Divisor div_seq_len(static_cast<uint32_t>(kv_caches[0].seq_len));
@ -1475,12 +1460,11 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
// If tbatch is larger than the qbatch we already have in `activations`, then
// allocate prefill_activations, otherwise reuse.
const bool use_prefill_activations =
runtime_config.prefill_tbatch_size > activations.x.BatchSize();
Activations prefill_activations(weights.weights_config);
if (use_prefill_activations) {
prefill_activations.Allocate(runtime_config.prefill_tbatch_size,
activations.env);
}
runtime_config.prefill_tbatch_size > activations.x.Rows();
Activations prefill_activations(
weights.weights_config,
use_prefill_activations ? runtime_config.prefill_tbatch_size : 0,
activations.env);
Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end,
query_idx_start, weights,
use_prefill_activations ? prefill_activations : activations,
@ -1534,8 +1518,7 @@ void GenerateSingleT(const ModelStore& model,
const size_t qbatch_start = 0;
// TODO: move into Gemma?
Activations activations(model.Config());
activations.Allocate(kNumQueries, env);
Activations activations(model.Config(), kNumQueries, env);
const QueriesPromptTokens queries_prompt(&prompt, kNumQueries);
QueriesPos queries_pos(&pos, kNumQueries);
@ -1558,7 +1541,7 @@ void GenerateBatchT(const ModelStore& model,
TimingInfo& timing_info) {
const size_t num_queries = queries_prompt.size();
HWY_ASSERT(queries_pos.size() == num_queries);
HWY_ASSERT(kv_caches.size() == num_queries);
HWY_ASSERT(kv_caches.size() >= num_queries);
// Griffin does not support query batching.
size_t max_qbatch_size = runtime_config.decode_qbatch_size;
for (const LayerConfig& layer_config : model.Config().layer_configs) {
@ -1568,14 +1551,14 @@ void GenerateBatchT(const ModelStore& model,
}
}
Activations activations(model.Config());
activations.Allocate(max_qbatch_size, env);
Activations activations(model.Config(), max_qbatch_size, env);
for (size_t qbatch_start = 0; qbatch_start < num_queries;
qbatch_start += max_qbatch_size) {
// Generate one batch of tokens from `qbatch_size` queries.
const size_t qbatch_size =
HWY_MIN(num_queries - qbatch_start, max_qbatch_size);
activations.SetBatchSize(qbatch_size);
const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start],
qbatch_size);
QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size);
@ -1601,8 +1584,7 @@ void GenerateImageTokensT(const ModelStore& model,
ModelConfig vit_config = GetVitConfig(model.Config());
prefill_runtime_config.prefill_tbatch_size =
vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim);
Activations prefill_activations(vit_config);
prefill_activations.Allocate(vit_config.seq_len, env);
Activations prefill_activations(vit_config, vit_config.seq_len, env);
// Weights are for the full PaliGemma model, not just the ViT part.
PrefillVit(weights, prefill_runtime_config, image, image_tokens,
prefill_activations);

View File

@ -32,7 +32,6 @@
#include "ops/matmul.h" // MatMulEnv
#include "paligemma/image.h"
#include "util/basics.h" // TokenAndProb
#include "util/mat.h" // RowVectorBatch
#include "util/threading_context.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports

View File

@ -25,10 +25,11 @@
#include <random>
#include <string>
#include "io/io.h" // Path
#include "ops/matmul.h" // MMStorage::kMax*
#include "io/io.h" // Path
#include "ops/matmul.h" // MMStorage::kMax*
#include "util/args.h"
#include "util/basics.h" // Tristate
#include "util/basics.h" // Tristate
#include "util/mat.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // HWY_ABORT
@ -74,9 +75,9 @@ using QueriesPromptTokens = hwy::Span<const PromptTokens>;
using QueriesToken = hwy::Span<const int>;
using QueriesPos = hwy::Span<const size_t>;
// ImageTokens are represented as a RowVectorBatch, where each "batch" index
// corresponds to a token for an image patch as computed by the image encoder.
using ImageTokens = RowVectorBatch<float>;
// ImageTokens are represented as a matrix, where each row corresponds
// to a token for an image patch as computed by the image encoder.
using ImageTokens = MatStorageT<float>;
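// Construction sketch (mirrors run.cc later in this commit; `config` is a
// ModelConfig and the extents depend on its ViT settings):
//   const size_t pool_dim = config.vit_config.pool_dim;
//   ImageTokens image_tokens(
//       "image_tokens",
//       Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim),
//                 config.model_dim),
//       MatPadding::kOdd);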
// StreamFunc is called with (token, probability). For prompt tokens,
// probability is 0.0f. StreamFunc should return false to stop generation and

View File

@ -15,91 +15,69 @@
#include "gemma/kv_cache.h"
#include <algorithm>
#include <algorithm> // std::copy
#include "gemma/configs.h"
#include "util/mat.h" // ZeroInit
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // ZeroBytes
namespace gcpp {
void KVCache::ZeroGriffinCache() {
if (conv1d_cache_size != 0) {
hwy::ZeroBytes(conv1d_cache.get(),
conv1d_cache_size * sizeof(conv1d_cache[0]));
}
if (rglru_cache_size != 0) {
hwy::ZeroBytes(rglru_cache.get(),
rglru_cache_size * sizeof(rglru_cache[0]));
if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache);
if (rglru_cache.HasPtr()) ZeroInit(rglru_cache);
}
static size_t GriffinConv1dCols(const ModelConfig& config) {
size_t conv1d_width = 0;
for (const auto& layer_config : config.layer_configs) {
conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width);
}
return conv1d_width == 0 ? 0 : conv1d_width - 1;
}
// prefill_tbatch_size is the maximum number of tokens from one query to
// prefill at a time.
KVCache KVCache::Create(const ModelConfig& weights_config,
size_t prefill_tbatch_size) {
KVCache kv_cache = {};
const size_t size_cache_pos = weights_config.CachePosSize();
KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size)
: griffin_layers(
config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)),
griffin_conv1d_cols(GriffinConv1dCols(config)),
// TODO(patrickms): Add query batching support for Griffin.
conv1d_cache(
"conv1d_cache",
Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim),
MatPadding::kOdd),
rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim),
MatPadding::kOdd) {
// TODO: move to MatStorageT.
const size_t size_cache_pos = config.CachePosSize();
if (size_cache_pos != 0) {
// Allocate more so that prefill can always access one batch, even if
// near the end of the sequence.
kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size;
kv_cache.kv_cache =
hwy::AllocateAligned<float>(kv_cache.seq_len * size_cache_pos);
seq_len = config.seq_len + prefill_tbatch_size;
kv_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
}
const size_t num_griffin_layers = weights_config.NumLayersOfType(
LayerAttentionType::kGriffinRecurrentBlock);
// TODO(patrickms): Add query batching support for Griffin.
if (num_griffin_layers > 0) {
uint32_t conv1d_width = 0;
for (const auto& layer_config : weights_config.layer_configs) {
conv1d_width = std::max(conv1d_width, layer_config.conv1d_width);
}
const size_t conv1d_cache_size =
num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) *
weights_config.model_dim;
kv_cache.conv1d_cache_size = conv1d_cache_size;
if (conv1d_cache_size != 0) {
kv_cache.conv1d_cache = hwy::AllocateAligned<float>(conv1d_cache_size);
}
const size_t rglru_cache_size =
num_griffin_layers * weights_config.model_dim;
kv_cache.rglru_cache_size = rglru_cache_size;
if (rglru_cache_size != 0) {
kv_cache.rglru_cache = hwy::AllocateAligned<float>(rglru_cache_size);
}
} // num_griffin_layers
return kv_cache;
}
KVCache KVCache::Copy(const ModelConfig& weights_config,
size_t prefill_tbatch_size) {
KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size);
KVCache copy(weights_config, prefill_tbatch_size);
const size_t size_cache_pos = weights_config.CachePosSize();
if (size_cache_pos != 0) {
std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len,
kv_cache_copy.kv_cache.get());
copy.kv_cache.get());
}
const size_t num_griffin_layers = weights_config.NumLayersOfType(
LayerAttentionType::kGriffinRecurrentBlock);
if (num_griffin_layers > 0) {
if (conv1d_cache_size != 0) {
std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size,
kv_cache_copy.conv1d_cache.get());
}
if (rglru_cache_size != 0) {
std::copy(rglru_cache.get(),
rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]),
kv_cache_copy.rglru_cache.get());
}
if (conv1d_cache.HasPtr()) {
CopyMat(conv1d_cache, copy.conv1d_cache);
}
return kv_cache_copy;
if (rglru_cache.HasPtr()) {
CopyMat(rglru_cache, copy.rglru_cache);
}
return copy;
}
} // namespace gcpp

View File

@ -19,33 +19,31 @@
#include <stddef.h>
#include "gemma/configs.h" // ModelConfig
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
namespace gcpp {
struct KVCache {
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
KVCache() = default; // for std::vector.
KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size);
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
// (kConv1dWidth - 1) * kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]> conv1d_cache;
size_t conv1d_cache_size = 0;
// kModelDim * kGriffinLayers
hwy::AlignedFreeUniquePtr<float[]> rglru_cache;
size_t rglru_cache_size = 0;
// Returns a deep copy of the KVCache.
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
size_t griffin_layers = 0;
size_t griffin_conv1d_cols = 0;
// griffin_layers, griffin_conv1d_cols * config.model_dim
MatStorageT<float> conv1d_cache;
MatStorageT<float> rglru_cache; // griffin_layers, config.model_dim
// Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache
// and rglru_cache.
void ZeroGriffinCache();
static KVCache Create(const ModelConfig& weights_config,
size_t prefill_tbatch_size);
size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size
// Returns a deep copy of the KVCache.
KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size);
// seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2
hwy::AlignedFreeUniquePtr<float[]> kv_cache;
};
} // namespace gcpp
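
A minimal usage sketch of the new constructor that replaces KVCache::Create;
`config` is assumed to be a populated ModelConfig and the prefill batch size is
illustrative:

#include "gemma/configs.h"
#include "gemma/kv_cache.h"

void ExampleKVCacheUse(const gcpp::ModelConfig& config) {
  gcpp::KVCache cache(config, /*prefill_tbatch_size=*/64);
  cache.ZeroGriffinCache();  // no-op unless the model has Griffin layers
  // Deep copy, e.g. for an independent second conversation.
  gcpp::KVCache clone = cache.Copy(config, /*prefill_tbatch_size=*/64);
  (void)clone;
}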

View File

@ -95,24 +95,25 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0;
const ModelConfig& config = gemma.GetModelConfig();
std::mt19937 gen;
InitGenerator(inference, gen);
const bool have_image = !inference.image_file.path.empty();
Image image;
ImageTokens image_tokens;
const size_t pool_dim = config.vit_config.pool_dim;
ImageTokens image_tokens(
"image_tokens",
have_image ? Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim),
config.model_dim)
: Extents2D(0, 0),
MatPadding::kOdd);
if (have_image) {
size_t pool_dim = gemma.GetModelConfig().vit_config.pool_dim;
image_tokens =
ImageTokens(gemma.Env().ctx.allocator,
Extents2D(gemma.GetModelConfig().vit_config.seq_len /
(pool_dim * pool_dim),
gemma.GetModelConfig().model_dim));
HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA ||
gemma.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM);
HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA ||
config.wrapping == PromptWrapping::GEMMA_VLM);
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
const size_t image_size = config.vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
@ -138,7 +139,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
std::cerr << "." << std::flush;
}
return true;
} else if (gemma.GetModelConfig().IsEOS(token)) {
} else if (config.IsEOS(token)) {
if (inference.verbosity >= 2) {
std::cout << "\n[ End ]\n";
}
@ -191,8 +192,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
size_t prefix_end = 0;
if (have_image) {
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
gemma.GetModelConfig().wrapping, abs_pos,
prompt_string, image_tokens.BatchSize());
config.wrapping, abs_pos, prompt_string,
image_tokens.Rows());
runtime_config.image_tokens = &image_tokens;
prompt_size = prompt.size();
// The end of the prefix for prefix-LM style attention in Paligemma.
@ -203,8 +204,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
// runtime_config.prefill_tbatch_size = prompt_size;
} else {
prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(),
gemma.GetModelConfig().wrapping, abs_pos,
prompt_string);
config.wrapping, abs_pos, prompt_string);
prompt_size = prompt.size();
}
@ -228,8 +228,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
}
// Prepare for the next turn. Works only for PaliGemma.
if (!inference.multiturn ||
gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) {
if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(inference, gen);
} else {
@ -254,8 +253,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
MatMulEnv env(MakeMatMulEnv(threading));
if (inference.verbosity >= 2) env.print_best = true;
const Gemma gemma(loader, env);
KVCache kv_cache =
KVCache::Create(gemma.GetModelConfig(), inference.prefill_tbatch_size);
KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size);
if (inference.verbosity >= 1) {
std::string instructions =

View File

@ -92,9 +92,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N);
RowVectorBatch<TC> c_slow_batch =
AllocateAlignedRows<TC>(allocator, C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
if (add) {
@ -104,11 +103,9 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
MatStorageT<TA> a = GenerateMat<TA>(A_extents, pool);
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
const auto A = ConstMatFromWeights(a);
const auto B = ConstMatFromWeights(b_trans);
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
// Fewer reps for large batch sizes, which take longer.
const size_t num_samples = M < 32 ? 20 : 12;
@ -118,7 +115,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// Ensure usage conditions are set before autotuning. Both binding and
// spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op.
BindB(allocator, B_extents.rows, sizeof(TC), B, env.parallel);
BindB(allocator, B_extents.rows, sizeof(TC), ConstMat<TB>(b_trans),
env.parallel);
BindC(allocator, A_extents.rows, C, env.parallel);
Tristate use_spinning = Tristate::kDefault;
@ -133,7 +131,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// Until enough samples collected *after* autotuning finished:
while (times.size() < num_samples) {
const double t0 = hwy::platform::Now();
per_key = MatMul(A, B, add_row, env, C);
per_key = MatMul(a, b_trans, add_row, env, C);
const double t1 = hwy::platform::Now();
double elapsed = t1 - t0;
keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]);

View File

@ -26,6 +26,7 @@
#include <cmath>
#include <random>
#include "compression/compress.h"
#include "compression/shared.h"
#include "util/allocator.h"
#include "util/test_util.h"
@ -999,7 +1000,6 @@ struct TestShortDotsT {
const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df; // for CallDot
const Allocator& allocator = gcpp::ThreadingContext::Get().allocator;
CompressWorkingSet work;
std::mt19937 rng;
rng.seed(12345);
@ -1010,22 +1010,22 @@ struct TestShortDotsT {
// GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`,
// hence they require padding to one vector.
const size_t padded_num = hwy::RoundUpTo(num, N);
const size_t packed_num = CompressedArrayElements<Packed>(num);
RowVectorBatch<float> raw_w(allocator, Extents2D(1, padded_num));
RowVectorBatch<float> raw_v(allocator, Extents2D(1, padded_num));
RowVectorBatch<Packed> weights(allocator, Extents2D(1, packed_num));
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
RowVectorBatch<T> vectors(allocator, Extents2D(1, num));
const PackedSpan<T> v(vectors.Batch(0), num);
MatStorageT<float> raw_w("raw_w", padded_num);
MatStorageT<float> raw_v("raw_v", padded_num);
MatStorageT<Packed> weights("weights", padded_num);
const PackedSpan<Packed> w = weights.Span();
MatStorageT<T> vectors("vectors", padded_num);
const PackedSpan<T> v = vectors.Span();
RowVectorBatch<double> bufs(allocator, Extents2D(1, num));
double* HWY_RESTRICT buf = bufs.Batch(0);
MatStorageT<double> bufs("bufs", num);
double* HWY_RESTRICT buf = bufs.Packed();
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
GenerateWellConditionedInputs(num, raw_w.All(), rng, w, work);
GenerateWellConditionedInputs(num, raw_v.All(), rng, v, work);
GenerateWellConditionedInputs(num, raw_w.Packed(), rng, w, work);
GenerateWellConditionedInputs(num, raw_v.Packed(), rng, v, work);
const float dot_exact = ExactDot(raw_w.All(), raw_v.All(), num, buf);
const float dot_exact =
ExactDot(raw_w.Packed(), raw_v.Packed(), num, buf);
float dots[kVariants];
for (size_t variant = 0; variant < kVariants; ++variant) {
// Here Packed is not always float, so we must not call kDouble.
@ -1106,7 +1106,6 @@ void TestAllDot() {
threading_args.max_lps = kMaxWorkers - 1;
ThreadingContext::SetArgs(threading_args);
ThreadingContext& ctx = ThreadingContext::Get();
const Allocator& allocator = ctx.allocator;
{ // ensure no profiler zones are active
const hn::ScalableTag<float> df;
@ -1118,16 +1117,17 @@ void TestAllDot() {
constexpr size_t kReps = hn::AdjustedReps(40);
const size_t num = 24 * 1024;
RowVectorBatch<float> a(allocator, Extents2D(kMaxWorkers, num));
RowVectorBatch<float> b(allocator, Extents2D(kMaxWorkers, num));
RowVectorBatch<double> bufs(allocator, Extents2D(kMaxWorkers, num));
MatStorageT<float> a("a", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
MatStorageT<float> b("b", Extents2D(kMaxWorkers, num), MatPadding::kOdd);
MatStorageT<double> bufs("bufs", Extents2D(kMaxWorkers, num),
MatPadding::kOdd);
std::array<DotStats, kMaxWorkers> all_stats;
ctx.pools.Cluster(0, 0).Run(
0, kReps, [&](const uint32_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Batch(thread);
float* HWY_RESTRICT pb = b.Batch(thread);
double* HWY_RESTRICT buf = bufs.Batch(thread);
float* HWY_RESTRICT pa = a.Row(thread);
float* HWY_RESTRICT pb = b.Row(thread);
double* HWY_RESTRICT buf = bufs.Row(thread);
const PackedSpan<const float> a_span(pa, num);
DotStats& stats = all_stats[thread];
const double cond =

View File

@ -693,7 +693,6 @@ class MMScaleDemoteAdd {
// We manually unroll 2x for higher IPC in batch=1.
size_t col_c = range_nc.begin();
if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) {
HWY_UNROLL(1)
for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) {
VD a0, a1; // unused if !kAdd
if constexpr (kAdd) {
@ -860,9 +859,8 @@ class MMScaleDemoteAdd {
class MMPerPackage {
public:
template <typename TA>
MMPerPackage(const ConstMat<TA>& A, const MMArgs& args,
const MMConfig& config, size_t pkg_idx,
const IndexRange& range_np)
MMPerPackage(const MatPtrT<TA>& A, const MMArgs& args, const MMConfig& config,
size_t pkg_idx, const IndexRange& range_np)
: args_(args),
pkg_idx_(pkg_idx),
// May be overwritten with a view of A, if already BF16.
@ -1114,12 +1112,12 @@ class MMPerPackage {
});
}
// Decompresses all `M x K` from `A` into `pkg_A`. Assumes `TA` is a seekable
// Decompresses all `M x K` from `A` into `A_`. Assumes `TA` is a seekable
// type (i.e., not NUQ) so we can use pointer arithmetic.
template <typename TA>
HWY_NOINLINE void DoDecompressA(const ConstMat<TA>& A, MMParA par_a) const {
const IndexRange all_M(0, A.extents.rows);
const IndexRange all_K(0, A.extents.cols);
HWY_NOINLINE void DoDecompressA(const MatPtrT<TA>& A, MMParA par_a) const {
const IndexRange all_M(0, A.Rows());
const IndexRange all_K(0, A.Cols());
HWY_DASSERT(all_K.Num() == A_.Cols());
const hn::ScalableTag<BF16> dbf;
@ -1131,10 +1129,9 @@ class MMPerPackage {
const size_t col0 = range_K.begin();
const size_t cols = range_K.Num();
// otherwise, padding overwrites neighbors
HWY_DASSERT(cols % NBF == 0 || cols == A.extents.cols);
HWY_DASSERT(cols % NBF == 0 || cols == A.Cols());
for (size_t row_a : range_M) {
const PackedSpan<const TA> from =
MakeSpan(A.ptr + A.Row(row_a) + col0, cols);
const PackedSpan<const TA> from = MakeSpan(A.Row(row_a) + col0, cols);
BF16* HWY_RESTRICT to = A_.Row(row_a) + col0;
DecompressAndZeroPad(dbf, from, 0, to, cols);
// Verify that we zero-padded.
@ -1174,18 +1171,14 @@ class MMPerPackage {
// Autotuning wrapper for `DoDecompressA`.
template <typename TA>
HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const {
HWY_INLINE RowPtrBF DecompressA(const MatPtrT<TA>& A) const {
const Allocator& allocator = args_.env->ctx.allocator;
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
// If already BF16, maybe return a view:
if constexpr (hwy::IsSame<TA, BF16>()) {
// Only if no zero-padding required.
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
if (HWY_LIKELY(A.extents.cols % NBF == 0)) {
const BF16* pos = A.ptr + A.Row(0);
return RowPtrBF(allocator, const_cast<BF16*>(pos), A.extents.cols,
A.Stride());
}
if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(allocator, A);
}
if (HWY_LIKELY(autotune.Best())) {
@ -1196,7 +1189,7 @@ class MMPerPackage {
// First call: generate candidates.
if (HWY_UNLIKELY(!autotune.HasCandidates())) {
std::vector<MMParA> candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4};
if (A.extents.rows == 1) {
if (A.Rows() == 1) {
candidates.push_back(MMParA::kNone);
} else {
candidates.push_back(MMParA::kM);
@ -1247,7 +1240,7 @@ class MMPerPackage {
const MMArgs args_; // copy for locality
const size_t pkg_idx_;
RowPtrBF A_; // points into A or storage.
RowPtrBF A_; // points into A or pkg_A.
const IndexRange range_np_;
// From MMConfig:
@ -1276,9 +1269,8 @@ struct MMImpl {
// Called from `MatMul` from two places: either with the next autotune config,
// or with the best config.
template <typename TA, typename TB, typename TC>
static HWY_NOINLINE void DoMatMul(const ConstMat<TA>& A,
const ConstMat<TB>& B, const RowPtr<TC>& C,
const MMArgs& args,
static HWY_NOINLINE void DoMatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
const RowPtr<TC>& C, const MMArgs& args,
const MMConfig& config) {
MMZone matmul_zone;
matmul_zone.MaybeEnter("MM.DoMatMul", args);
@ -1313,7 +1305,7 @@ struct MMImpl {
//
// Uses considerable stack space: at least 40 KiB per thread.
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const ConstMat<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<TC>& C) {
const Allocator& allocator = env.ctx.allocator;
@ -1340,7 +1332,7 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
MMPerKey& per_key = env.per_key[index];
MMAutoTune<MMConfig>& tuner = per_key.autotune;
const MMArgs args(env, per_key, static_cast<double>(A.scale) * B.scale, add,
const MMArgs args(env, per_key, static_cast<double>(A.Scale()) * B.scale, add,
env.storage.Partial());
if (HWY_LIKELY(tuner.Best())) {
MMImpl::DoMatMul(A, B, C, args, *tuner.Best());
@ -1396,6 +1388,13 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
return &per_key;
}
template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<TC>& C) {
return MatMul(A, ConstMat<TB>(B), add, env, C);
}
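
A hedged call sketch for this convenience overload, mirroring matmul_test.cc;
the extents, element types and the MatMulEnv `env` are illustrative
assumptions:

MatStorageT<float> a("a", Extents2D(4, 256), MatPadding::kOdd);
MatStorageT<BF16> b_trans("b_trans", Extents2D(512, 256), MatPadding::kOdd);
MatStorageT<float> c("c", Extents2D(4, 512), MatPadding::kOdd);
const RowPtr<float> C = RowPtrFromMat(env.ctx.allocator, c);
// B is passed already transposed; the add row is optional.
MMPerKey* per_key = MatMul(a, b_trans, /*add=*/nullptr, env, C);
(void)per_key;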
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp

View File

@ -21,6 +21,7 @@
#include <stddef.h>
#include <stdint.h>
#include <memory> // std::unique_ptr
#include <vector>
// IWYU pragma: begin_exports
@ -207,24 +208,28 @@ class MMStorage {
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
static constexpr size_t kMaxKC = 8 * 1024;
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext` (used via `parallel`).
MMStorage(const Allocator& allocator, MMParallel& parallel)
// Per-worker copies of `partial` would be wasteful. We instead allocate
// one instance of the maximum matrix extents because threads write at
// false-sharing-free granularity.
: partial_storage_(
AllocateAlignedRows<double>(allocator, Extents2D(kMaxM, kMaxN))),
: partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN),
MatPadding::kOdd),
// Same stride independent of the actual C.Cols() so we can pre-bind.
partial_(allocator, partial_storage_.All(), kMaxN,
StrideForCyclicOffsets(kMaxN, allocator.Quantum<double>())) {
partial_(allocator, partial_storage_.Row(0), kMaxN,
partial_storage_.Stride()) {
// Per-package allocation so each can decompress A into its own copy.
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
pkg_A_[pkg_idx] =
AllocateAlignedRows<BF16>(allocator, Extents2D(kMaxM, kMaxK));
pkg_A_[pkg_idx].reset(new MatStorageT<BF16>(
"pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd));
if (allocator.ShouldBind()) {
const size_t node = parallel.Node(pkg_idx);
if (!allocator.BindMemory(pkg_A_[pkg_idx].All(),
pkg_A_[pkg_idx].NumBytes(), node)) {
size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() *
pkg_A_[pkg_idx]->ElementBytes();
bytes = hwy::RoundDownTo(bytes, allocator.QuantumBytes());
if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) {
HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
}
}
@ -234,22 +239,20 @@ class MMStorage {
BindC(allocator, kMaxM, partial_, parallel);
}
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is
// non-const, because `RowPtr` requires a non-const pointer.
// Returns per-package matrix view.
RowPtrBF A(const Allocator& allocator, size_t pkg_idx,
const Extents2D& extents) {
const Extents2D& extents) const {
HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK);
const size_t stride =
StrideForCyclicOffsets(extents.cols, allocator.Quantum<BF16>());
return RowPtrBF(allocator, pkg_A_[pkg_idx].All(), extents.cols, stride);
return RowPtrBF(allocator, const_cast<BF16*>(pkg_A_[pkg_idx]->Row(0)),
extents.cols, pkg_A_[pkg_idx]->Stride());
}
RowPtrD Partial() const { return partial_; }
private:
RowVectorBatch<BF16> pkg_A_[MMParallel::kMaxPackages];
RowVectorBatch<double> partial_storage_;
std::unique_ptr<MatStorageT<BF16>> pkg_A_[MMParallel::kMaxPackages];
MatStorageT<double> partial_storage_;
RowPtrD partial_;
};
@ -608,6 +611,8 @@ struct MMPerKey {
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
// `MatMulEnv`.
struct MatMulEnv {
// Internally threaded; must not be called concurrently with the same
// `ThreadingContext`.
explicit MatMulEnv(ThreadingContext& ctx);
ThreadingContext& ctx;
@ -679,8 +684,8 @@ struct MMZone {
#endif // PROFILER_ENABLED
// Used for the A and B arguments of `MatMul`, which are always const.
// Create via MakeConstMat. This differs from `RowPtr` in that it supports the
// `ofs` required for compressed T.
// This differs from `RowPtr` in supporting the `ofs` required for compressed T.
// TODO: remove after splitting W1/W2 and updating QDotK to RowPtr.
template <typename T>
struct ConstMat {
ConstMat() = default;
@ -689,6 +694,12 @@ struct ConstMat {
HWY_DASSERT(ptr != nullptr);
HWY_DASSERT(stride >= extents.cols);
}
// Non-explicit so that we can pass `MatPtr` directly to MatMul.
ConstMat(const MatPtrT<T>& m)
: ConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride()) {
scale = m.Scale();
}
size_t Row(size_t r) const {
if constexpr (HWY_IS_DEBUG_BUILD) {
if (r >= extents.rows) {
@ -727,31 +738,6 @@ struct ConstMat {
size_t ofs;
};
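
A small sketch of what the non-explicit constructor enables; `RowSum` is a
hypothetical helper, and the point is only the implicit MatPtrT -> ConstMat
conversion at the call site:

float RowSum(const ConstMat<float>& m, size_t r) {
  const float* row = m.ptr + m.Row(r);  // Row() returns an element offset
  float sum = 0.0f;
  for (size_t c = 0; c < m.extents.cols; ++c) sum += row[c];
  return sum * m.scale;
}
// Given MatPtrT<float> activations, RowSum(activations, 0) now compiles via
// the ConstMat(const MatPtrT<T>&) constructor above.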
// For deducing T.
template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
size_t stride) {
return ConstMat<T>(ptr, extents, stride);
}
// For A argument to MatMul (activations).
template <typename T>
ConstMat<T> ConstMatFromBatch(size_t batch_size,
const RowVectorBatch<T>& row_vectors) {
HWY_DASSERT(batch_size <= row_vectors.BatchSize());
return MakeConstMat(const_cast<T*>(row_vectors.Const()),
Extents2D(batch_size, row_vectors.Cols()),
row_vectors.Stride());
}
template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
ConstMat<T> mat =
MakeConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride());
mat.scale = m.Scale();
return mat;
}
template <typename TB>
void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
const ConstMat<TB>& B, MMParallel& parallel) {

View File

@ -57,10 +57,10 @@ namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Returns 1-norm, used for estimating tolerable numerical differences.
double MaxRowAbsSum(const RowVectorBatch<float>& a) {
double MaxRowAbsSum(const MatStorageT<float>& a) {
double max_row_abs_sum = 0.0;
for (size_t r = 0; r < a.BatchSize(); r++) {
const float* row = a.Batch(r);
for (size_t r = 0; r < a.Rows(); r++) {
const float* row = a.Row(r);
double row_abs_sum = 0.0;
for (size_t c = 0; c < a.Cols(); c++) {
row_abs_sum += hwy::ScalarAbs(row[c]);
@ -71,11 +71,11 @@ double MaxRowAbsSum(const RowVectorBatch<float>& a) {
}
// Returns the maximum absolute value of `a`.
float MaxAbs(const RowVectorBatch<float>& a) {
float MaxAbs(const MatStorageT<float>& a) {
float max_abs = 0.0f;
for (size_t c = 0; c < a.Cols(); c++) {
for (size_t r = 0; r < a.BatchSize(); r++) {
const float* row = a.Batch(r);
for (size_t r = 0; r < a.Rows(); r++) {
const float* row = a.Row(r);
max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
}
}
@ -84,33 +84,29 @@ float MaxAbs(const RowVectorBatch<float>& a) {
// B is already transposed.
template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
const Allocator& allocator = ThreadingContext::Get().allocator;
const hn::ScalableTag<float> df;
const size_t cols = A.extents.cols;
const size_t B_rows = B.extents.rows;
const size_t cols = A.Cols();
const size_t B_rows = B.Rows();
// Round up for DecompressAndZeroPad.
RowVectorBatch<float> a_batch =
AllocateAlignedRows<float>(allocator, A.extents);
RowVectorBatch<float> b_trans_batch =
AllocateAlignedRows<float>(allocator, B.extents);
RowVectorBatch<float> c_batch =
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
RowVectorBatch<float> c_slow_batch =
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
for (size_t m = 0; m < A.extents.rows; ++m) {
DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
a_batch.Batch(m), cols);
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m),
MatStorageT<float> a_batch("a_batch", A.Extents(), MatPadding::kOdd);
MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(),
MatPadding::kOdd);
MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
MatPadding::kOdd);
MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
MatPadding::kOdd);
for (size_t m = 0; m < A.Rows(); ++m) {
DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
B_rows);
DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
c_slow_batch.Batch(m), B_rows);
c_slow_batch.Row(m), B_rows);
}
for (size_t n = 0; n < B_rows; ++n) {
DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0,
b_trans_batch.Batch(n), cols);
DecompressAndZeroPad(df, MakeSpan(B.Row(n), cols), 0, b_trans_batch.Row(n),
cols);
}
// MatMul rounds inputs to BF16, so error is proportional to the max input
@ -130,10 +126,10 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
}
const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
for (size_t r = 0; r < A.extents.rows; r++) {
const float* expected_row = c_slow_batch.Batch(r);
const float* actual_row = c_batch.Batch(r);
for (size_t c = 0; c < B.extents.rows; c++) {
for (size_t r = 0; r < A.Rows(); r++) {
const float* expected_row = c_slow_batch.Row(r);
const float* actual_row = c_batch.Row(r);
for (size_t c = 0; c < B.Rows(); c++) {
const double expected_value = static_cast<double>(expected_row[c]);
const double actual_value = static_cast<double>(actual_row[c]);
const bool in_range = expected_value - tolerance <= actual_value &&
@ -157,18 +153,17 @@ void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
// B is already transposed.
template <typename TA, typename TB, typename TC>
HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
const float* HWY_RESTRICT add_row, MatMulEnv& env,
const RowPtr<TC>& C) {
// TA can be any Packed except NuqStream because it uses pointer
// arithmetic, because it is the second argument to Dot, which does not
// support a v_ofs.
static_assert(sizeof(TA) >= sizeof(BF16), "A matrix must be BF16/f32");
const float scale = A.scale * B.scale;
const float scale = A.Scale() * B.Scale();
const hn::ScalableTag<float> df; // lane type is ignored
const PackedSpan<const TB> b_span =
MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows);
const PackedSpan<const TB> b_span = B.Span();
const IndexRange all_rows_c(0, A.Extents().rows);
const IndexRange all_cols_c(0, C.Cols());
@ -191,8 +186,8 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
for (size_t c : cols_c) {
const float add = add_row ? add_row[c] : 0.0f;
C_row[c] = hwy::ConvertScalarTo<TC>(
add + scale * Dot(df, b_span, c * B.Stride(),
A.ptr + A.Row(r), A.extents.cols));
add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r),
A.Cols()));
}
}
});
@ -225,26 +220,23 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool));
MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool));
RowVectorBatch<TC> c_slow_batch =
AllocateAlignedRows<TC>(allocator, C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
MatStorageT<TC> c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd);
MatStorageT<TC> c_batch("c_batch", C_extents, MatPadding::kOdd);
MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
: MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
add_storage.SetScale(1.0f);
const auto A = ConstMatFromWeights(a);
const auto B = ConstMatFromWeights(b_trans);
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
const RowPtr<TC> C_slow = RowPtrFromBatch(allocator, c_slow_batch);
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
const RowPtr<TC> C_slow = RowPtrFromMat(allocator, c_slow_batch);
const RowPtr<TC> C = RowPtrFromMat(allocator, c_batch);
MatMulSlow(A, B, add_row, env, C_slow);
MatMulSlow(a, b_trans, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMul(A, B, add_row, env, C);
AssertClose(A, B, C_slow, C, line);
MMPerKey* per_key = MatMul(a, b_trans, add_row, env, C);
AssertClose(a, b_trans, C_slow, C, line);
if (per_key->autotune.Best()) break;
}
}

View File

@ -189,10 +189,11 @@ float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) {
} // namespace detail
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
const WeightT* HWY_RESTRICT weight,
OutT* HWY_RESTRICT out,
// `w_ofs` is the offset within `weight`, required for NuqStream.
template <typename XT, typename WT, typename OT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x,
const WT* HWY_RESTRICT weight,
size_t w_ofs, OT* HWY_RESTRICT out,
const size_t size) {
PROFILER_FUNC;
@ -203,17 +204,17 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
const VF mul = hn::Set(df, detail::RMSNormMul(x, size));
const auto packed_w = MakeSpan(weight, size);
const auto packed_v = MakeSpan(x, size);
const auto packed_x = MakeSpan(x, size);
const auto packed_w = MakeSpan(weight, w_ofs + size);
const auto packed_out = MakeSpan(out, size);
HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
HWY_DASSERT(size % (2 * NF) == 0);
for (size_t i = 0; i < size; i += 2 * NF) {
VF v0, v1, w0, w1;
Decompress2(df, packed_v, i, v0, v1);
Decompress2(df, packed_w, i, w0, w1);
const VF m0 = hn::Mul(mul, v0);
const VF m1 = hn::Mul(mul, v1);
VF x0, x1, w0, w1;
Decompress2(df, packed_x, i, x0, x1);
Decompress2(df, packed_w, w_ofs + i, w0, w1);
const VF m0 = hn::Mul(mul, x0);
const VF m1 = hn::Mul(mul, x1);
// (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1);
@ -222,10 +223,11 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x,
}
// Same as RMSNorm, but in-place: RMSNorm's HWY_RESTRICT forbids passing the
// same pointer for x and out.
template <typename WeightT, typename VecT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const WeightT* HWY_RESTRICT weight, VecT* HWY_RESTRICT inout,
const size_t size) {
template <typename WT, typename XT>
HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight,
size_t w_ofs,
XT* HWY_RESTRICT inout,
const size_t size) {
PROFILER_FUNC;
namespace hn = hwy::HWY_NAMESPACE;
@ -235,72 +237,112 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const VF mul = hn::Set(df, detail::RMSNormMul(inout, size));
const auto packed_w = MakeSpan(weight, size);
const auto packed_v = MakeSpan(inout, size);
const auto packed_w = MakeSpan(weight, w_ofs + size);
const auto packed_x = MakeSpan(inout, size);
HWY_DASSERT(size % (2 * MaxLanes(df)) == 0);
HWY_DASSERT(size % (2 * NF) == 0);
for (size_t i = 0; i < size; i += 2 * NF) {
VF v0, v1, w0, w1;
Decompress2(df, MakeConst(packed_v), i, v0, v1);
Decompress2(df, packed_w, i, w0, w1);
const VF m0 = hn::Mul(mul, v0);
const VF m1 = hn::Mul(mul, v1);
VF x0, x1, w0, w1;
Decompress2(df, packed_x, i, x0, x1);
Decompress2(df, packed_w, w_ofs + i, w0, w1);
const VF m0 = hn::Mul(mul, x0);
const VF m1 = hn::Mul(mul, x1);
// (1+weight) * m = m + weight*m = one FMA.
const VF out0 = hn::MulAdd(m0, w0, m0);
const VF out1 = hn::MulAdd(m1, w1, m1);
Compress2(df, out0, out1, packed_v, i);
Compress2(df, out0, out1, packed_x, i);
}
}
// Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm.
template <typename T>
HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, T& mu,
T& mu2) {
template <typename XT>
HWY_NOINLINE void ComputeMoments(const XT* HWY_RESTRICT x, size_t size,
double& mu, double& mu2) {
HWY_ASSERT(size > 0);
double sum = 0.0;
double sum2 = 0.0;
for (size_t i = 0; i < size; ++i) {
const float f = hwy::ConvertScalarTo<float>(a[i]);
sum += f;
sum2 += f * f;
}
mu = sum / size;
mu2 = sum2 / size;
const hn::ScalableTag<float> df;
// Use the existing Sum and Dot kernels for simplicity. The second pass
// is likely not too expensive because it will be in L1.
const double sum = Sum(df, x, size);
// We only have one array, so calling `DecompressAndCall` instead of `Dot`
// avoids loading the 'second' vector again.
const double sum2 =
DecompressAndCall(df, MakeSpan(x, size), DotKernelDouble());
const double inv_size = 1.0 / static_cast<double>(size);
mu = sum * inv_size;
mu2 = sum2 * inv_size;
}
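// Worked example: for x = {1, 2, 3, 4}, mu = 10 / 4 = 2.5 and
// mu2 = (1 + 4 + 9 + 16) / 4 = 7.5, so the variance used by LayerNorm below is
// mu2 - mu * mu = 7.5 - 6.25 = 1.25.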
// Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE void ScalarLayerNorm(const VecT* x,
const WeightT* HWY_RESTRICT scale,
const WeightT* HWY_RESTRICT bias,
OutT* out,
size_t size) {
constexpr float kEps = 1e-6f;
VecT mu, mu2;
ScalarMus(x, size, mu, mu2);
VecT var = mu2 - mu * mu;
VecT zero = 0.0f;
var = HWY_MAX(var, zero);
var = 1.0f / sqrtf(var + kEps);
for (size_t j = 0; j < size; j++) {
const float v = hwy::ConvertScalarTo<float>(x[j]);
const float s = hwy::ConvertScalarTo<float>(scale[j]);
const float b = hwy::ConvertScalarTo<float>(bias[j]);
out[j] = hwy::ConvertScalarTo<OutT>((v - mu) * s * var + b);
}
}
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE HWY_MAYBE_UNUSED void LayerNorm(const VecT* x,
const WeightT* HWY_RESTRICT weight,
const WeightT* HWY_RESTRICT bias,
OutT* out,
const size_t size) {
// x and out may be the same.
template <typename XT, typename WT, typename OT>
HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
const WT* HWY_RESTRICT bias, OT* out, size_t size) {
PROFILER_FUNC;
// For now we only delegate to the scalar version.
// TODO: implement vectorized version.
ScalarLayerNorm(x, weight, bias, out, size);
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
const size_t NF = hn::Lanes(df);
double mu, mu2;
ComputeMoments(x, size, mu, mu2);
double var = mu2 - mu * mu;
var = HWY_MAX(var, 0.0);
var = 1.0 / sqrt(var + 1E-6);
const VF vmu = hn::Set(df, static_cast<float>(mu));
const VF vvar = hn::Set(df, static_cast<float>(var));
const VF* HWY_RESTRICT pmu = &vmu;
const VF* HWY_RESTRICT pvar = &vvar;
const auto packed_x = MakeSpan(x, size);
const auto packed_scale = MakeSpan(scale, size);
const auto packed_bias = MakeSpan(bias, size);
const auto packed_out = MakeSpan(out, size);
// Loop body for one vector, called from main loop and remainder loop.
const auto norm = [pmu, pvar](VF x, VF s, VF add) HWY_ATTR -> VF {
const VF centered = hn::Sub(x, *pmu);
const VF mul = hn::Mul(s, *pvar);
return hn::MulAdd(centered, mul, add);
};
size_t i = 0;
if (size >= 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
VF x0, x1, s0, s1, add0, add1;
Decompress2(df, packed_x, i, x0, x1);
Decompress2(df, packed_scale, i, s0, s1);
Decompress2(df, packed_bias, i, add0, add1);
const VF n0 = norm(x0, s0, add0);
const VF n1 = norm(x1, s1, add1);
Compress2(df, n0, n1, packed_out, i);
}
}
const size_t remaining = size - i;
HWY_DASSERT(remaining < 2 * NF);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)];
HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)];
HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)];
DecompressAndZeroPad(df, packed_x, i, buf_x, remaining);
DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining);
DecompressAndZeroPad(df, packed_bias, i, buf_bias, remaining);
const VF x0 = hn::Load(df, buf_x);
const VF x1 = hn::Load(df, buf_x + NF);
const VF s0 = hn::Load(df, buf_scale);
const VF s1 = hn::Load(df, buf_scale + NF);
const VF add0 = hn::Load(df, buf_bias);
const VF add1 = hn::Load(df, buf_bias + NF);
const VF n0 = norm(x0, s0, add0);
const VF n1 = norm(x1, s1, add1);
Compress2(df, n0, n1, MakeSpan(buf_out, 2 * NF), 0);
hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT));
}
}
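
A minimal call sketch for the vectorized LayerNorm above; the size is
illustrative, and per the comment, x and out may alias, so normalizing in
place is allowed:

void ExampleLayerNorm() {
  HWY_ALIGN float x[512];
  HWY_ALIGN float scale[512];
  HWY_ALIGN float bias[512];
  // ... fill x, scale and bias ...
  LayerNorm(x, scale, bias, /*out=*/x, /*size=*/512);
}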
static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
@ -447,39 +489,56 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
}
// Simple loops unless/until batch sizes are large enough to parallelize.
template <typename WeightT, typename OutT>
void RMSNormBatched(size_t num_tokens, const float* activations,
const WeightT* weights, OutT* out, const size_t model_dim) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNorm(activations + token_idx * model_dim, weights,
out + token_idx * model_dim, model_dim);
}
template <typename XT, typename OT>
void RMSNormBatched(const MatPtrT<XT>& activations, const MatPtr& weights,
MatPtrT<OT>& out) {
HWY_DASSERT(weights.Rows() == 1);
HWY_DASSERT(weights.Cols() == activations.Cols());
HWY_DASSERT(activations.SameShape(out));
CallUpcasted(&weights, [&](const auto* weights_t) {
for (size_t token_idx = 0; token_idx < activations.Rows(); ++token_idx) {
RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0,
out.Row(token_idx), activations.Cols());
}
});
}
// TODO: pass RowVectorBatch argument.
template <typename WeightT, typename InOutT>
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
InOutT* inout, const size_t model_dim) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
}
template <typename XT>
void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT<XT>& inout) {
HWY_DASSERT(weights.Rows() == 1);
HWY_DASSERT(weights.Cols() == inout.Cols());
CallUpcasted(&weights, [&](const auto* weights_t) {
for (size_t token_idx = 0; token_idx < inout.Rows(); ++token_idx) {
RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx),
inout.Cols());
}
});
}
template <typename VecT, typename WeightT, typename OutT>
void LayerNormBatched(size_t num_tokens, const VecT* x,
const WeightT* HWY_RESTRICT weight,
const WeightT* HWY_RESTRICT bias, OutT* out,
const size_t size) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
LayerNorm(x + token_idx * size, weight, bias, out + token_idx * size, size);
}
// x and out may be the same.
template <typename XT, typename OT>
void LayerNormBatched(const MatPtrT<XT>& x, const MatPtr& weight,
const MatPtr& bias, MatPtrT<OT>& out) {
HWY_DASSERT(weight.Cols() == bias.Cols());
HWY_DASSERT(weight.Cols() == x.Cols());
HWY_DASSERT(x.SameShape(out));
CallUpcastedSame(
&weight, &bias, [&](const auto* weight_t, const auto* bias_t) {
for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
LayerNorm(x.Row(token_idx), weight_t->PackedScale1(),
bias_t->PackedScale1(), out.Row(token_idx), x.Cols());
}
});
}
static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other,
float* x, const size_t model_dim) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
model_dim);
static HWY_INLINE void AddFromBatched(const MatPtrT<float>& other,
MatPtrT<float>& x) {
HWY_DASSERT(x.SameShape(other));
for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) {
AddFrom(other.Row(token_idx), x.Row(token_idx), x.Cols());
}
}
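
A sketch of the new matrix-based calling convention for the batched helpers
above; names and extents are illustrative, and the weight row may be any
supported type because CallUpcasted dispatches on it at runtime:

void ExampleBatchedNorms() {
  MatStorageT<float> act("act", Extents2D(4, 2048), MatPadding::kOdd);
  MatStorageT<float> out("out", Extents2D(4, 2048), MatPadding::kOdd);
  MatStorageT<BF16> w("w_scale", Extents2D(1, 2048), MatPadding::kPacked);
  RMSNormBatched(act, w, out);    // out = RMSNorm(act), row by row
  RMSNormInplaceBatched(w, out);  // same, but overwrites `out`
  AddFromBatched(act, out);       // out += act, row by row
}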
@ -743,8 +802,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
HWY_ASSERT(k != 0);
HWY_ASSERT(k <= vocab_size);
std::vector<double> packed_token_probs;
for (int32_t i = 0; i < vocab_size; ++i) {
if (accept_token && !accept_token(StaticCast<int>(i), probabilities[i])) {
for (int32_t i = 0; i < static_cast<int32_t>(vocab_size); ++i) {
if (accept_token && !accept_token(i, probabilities[i])) {
continue;
}
packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i]));
@ -756,7 +815,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector<TokenAndProb> TopK(
std::vector<TokenAndProb> token_probs;
token_probs.reserve(k);
for (int32_t i = 0; i < k; ++i) {
for (int32_t i = 0; i < static_cast<int32_t>(k); ++i) {
token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i]));
}
return token_probs;
@ -770,7 +829,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
TopK(probabilities, vocab_size, k, accept_token);
std::vector<int> topk_indices(k);
std::vector<float> topk_probs(k);
for (int i = 0; i < k; ++i) {
for (size_t i = 0; i < k; ++i) {
topk_indices[i] = token_probs[i].token;
topk_probs[i] = token_probs[i].prob;
}
@ -788,7 +847,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
TopK(logits, vocab_size, k, accept_token);
std::vector<int> topk_indices(k);
std::vector<float> topk_logits(k);
for (int i = 0; i < token_logits.size(); ++i) {
for (size_t i = 0; i < token_logits.size(); ++i) {
topk_indices[i] = token_logits[i].token;
topk_logits[i] = token_logits[i].prob;
}
@ -807,20 +866,20 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
// Input has 4096 (64*64) rows, output has 256 (16*16) rows
// Each output row is the average of a 4x4 block of input rows
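// (Output token (r, c) averages the 16 input rows at index
// (4 * r + i) * 64 + (4 * c + j) for i, j in [0, 4).)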
template <typename T>
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
const Allocator& allocator = ThreadingContext::Get().allocator;
MatStorageT<T> AvgPool4x4(MatStorageT<T>& input) {
const Extents2D extents = input.Extents();
// Input validation
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows
// Create output with 256 rows and same number of columns
const size_t out_rows = 256; // 16 * 16 = 256 output rows
RowVectorBatch<T> result(allocator, Extents2D(out_rows, extents.cols));
MatStorageT<T> result("pool4x4", Extents2D(out_rows, extents.cols),
MatPadding::kOdd);
const size_t input_dim = 64; // Input is 64×64
const size_t output_dim = 16; // Output is 16×16
for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) {
size_t out_idx = out_row_idx * output_dim + out_col_idx;
T* output_row = result.Batch(out_idx);
T* output_row = result.Row(out_idx);
// Initialize output row to zeros
std::fill(output_row, output_row + extents.cols, 0);
// Average 16 row vectors from a 4x4 block
@ -829,9 +888,9 @@ RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
size_t in_row_idx = out_row_idx * 4 + i;
size_t in_col_idx = out_col_idx * 4 + j;
size_t in_idx = in_row_idx * input_dim + in_col_idx;
const T* input_row = input.Batch(in_idx);
const T* input_row = input.Row(in_idx);
// Add each input row to the output
// TODO(philculliton): use AddFrom in ops-inl for a vectorized loop.
// TODO(philculliton): use AddFrom in `ops-inl` for a vectorized loop.
for (size_t col = 0; col < extents.cols; ++col) {
output_row[col] += input_row[col];
}
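A minimal caller sketch for the pooling helper above, assuming the 64x64 ViT patch grid it expects; `model_dim` and the tensor names are illustrative, and the elided remainder of the loop is presumed to divide each accumulated row by 16:
  // Illustrative only: pools 64*64 = 4096 patch embeddings down to 16*16 = 256.
  const size_t model_dim = 1152;  // example width; any column count works
  MatStorageT<float> patches("patches", Extents2D(4096, model_dim),
                             MatPadding::kOdd);
  // ... fill `patches` with row-major patch embeddings ...
  MatStorageT<float> pooled = AvgPool4x4(patches);
  HWY_DASSERT(pooled.Rows() == 256 && pooled.Cols() == patches.Cols());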

View File

@ -20,23 +20,22 @@
#include <cmath>
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/base.h"
namespace gcpp {
static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
static inline HWY_MAYBE_UNUSED MatStorageT<float> CreateInvTimescale(
const Allocator& allocator, size_t qkv_dim, bool half_rope,
double base_frequency = 10000.0) {
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2));
MatStorageT<float> inv_timescale("inv_timescale", rope_dim / 2);
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
const double freq_exponents =
static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
// Replacing with expf(ln(1E4) * freq_exponents) changes results
// noticeably.
inv_timescale.Batch(0)[dim] =
inv_timescale.Packed()[dim] =
static_cast<float>(1.0 / std::pow(base_frequency, freq_exponents));
}
return inv_timescale;
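  // Worked example, assuming the default base_frequency = 10000 and full RoPE
  // with rope_dim = 256: entries decay geometrically, e.g.
  //   inv_timescale[0]  = 1 / 10000^(  0/256) = 1.0
  //   inv_timescale[64] = 1 / 10000^(128/256) = 0.01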

View File

@ -34,7 +34,7 @@
#include "gemma/common.h" // ChooseQueryScale
#include "util/allocator.h"
#include "util/basics.h" // BF16
#include "util/mat.h" // RowVectorBatch
#include "util/mat.h" // MatStorageT
#include "util/test_util.h"
#include "util/threading_context.h"
#include "hwy/tests/hwy_gtest.h"
@ -391,7 +391,7 @@ void TestRopeAndMulBy() {
ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B));
int dim_qkv = config.layer_configs[0].qkv_dim;
RowVectorBatch<float> x(allocator, Extents2D(1, dim_qkv));
MatStorageT<float> x("x", dim_qkv);
std::mt19937 gen;
gen.seed(0x12345678);
@ -399,43 +399,43 @@ void TestRopeAndMulBy() {
auto random_float = [&r, &gen] { return r(gen); };
for (int i = 0; i < dim_qkv; ++i) {
x.All()[i] = random_float();
x.Packed()[i] = random_float();
}
const float qmul = ChooseQueryScale(config);
const float kmul = 1.0;
std::vector<float> qexpected(dim_qkv);
std::vector<float> qactual(dim_qkv);
std::vector<float> kexpected(dim_qkv);
std::vector<float> kactual(dim_qkv);
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
MatStorageT<float> qexpected("qexpected", dim_qkv);
MatStorageT<float> qactual("qactual", dim_qkv);
MatStorageT<float> kexpected("kexpected", dim_qkv);
MatStorageT<float> kactual("kactual", dim_qkv);
MatStorageT<float> inv_timescale = CreateInvTimescale(
allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
  // Assert that the vectorized Rope computation matches scalar Rope at various positions.

for (int pos = 1; pos < 500; pos++) {
// Rope'd Q embeddings
hwy::CopyBytes(x.Const(), qactual.data(), dim_qkv);
hwy::CopyBytes(x.Const(), qexpected.data(), dim_qkv);
ScalarRopeAndMulBy(qmul, qexpected.data(), dim_qkv, inv_timescale.Const(),
pos);
RopeAndMulBy(qmul, qactual.data(), dim_qkv, inv_timescale.Const(), pos);
CopyMat(x, qactual);
CopyMat(x, qexpected);
ScalarRopeAndMulBy(qmul, qexpected.Packed(), dim_qkv,
inv_timescale.Packed(), pos);
RopeAndMulBy(qmul, qactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
for (int i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(qactual[i], qexpected[i], 1e-4)
<< "qIndex:" << i << "qInput:" << qactual[i];
EXPECT_NEAR(qactual.Packed()[i], qexpected.Packed()[i], 1e-4)
<< "qIndex:" << i << "qInput:" << qactual.Packed()[i];
}
// Rope'd K embeddings
hwy::CopyBytes(x.Const(), kactual.data(), dim_qkv);
hwy::CopyBytes(x.Const(), kexpected.data(), dim_qkv);
ScalarRopeAndMulBy(kmul, kexpected.data(), dim_qkv, inv_timescale.Const(),
pos);
RopeAndMulBy(kmul, kactual.data(), dim_qkv, inv_timescale.Const(), pos);
CopyMat(x, kactual);
CopyMat(x, kexpected);
ScalarRopeAndMulBy(kmul, kexpected.Packed(), dim_qkv,
inv_timescale.Packed(), pos);
RopeAndMulBy(kmul, kactual.Packed(), dim_qkv, inv_timescale.Packed(), pos);
for (int i = 0; i < dim_qkv; ++i) {
EXPECT_NEAR(kactual[i], kexpected[i], 1e-4)
<< "kIndex:" << i << "kInput:" << kactual[i];
EXPECT_NEAR(kactual.Packed()[i], kexpected.Packed()[i], 1e-4)
<< "kIndex:" << i << "kInput:" << kactual.Packed()[i];
}
}
}
@ -451,10 +451,9 @@ HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) {
}
// Supports bf16 and f32 inputs/outputs, which can be in-place.
template <typename VecT, typename WeightT, typename OutT>
HWY_NOINLINE void ScalarRMSNorm(const VecT* x,
const WeightT* HWY_RESTRICT weight, OutT* out,
size_t size) {
template <typename XT, typename WT, typename OT>
HWY_NOINLINE void ScalarRMSNorm(const XT* x, const WT* HWY_RESTRICT weight,
OT* out, size_t size) {
constexpr float kEps = 1e-6f;
float ss = ScalarSquaredL2(x, size);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + kEps);
@ -462,32 +461,32 @@ HWY_NOINLINE void ScalarRMSNorm(const VecT* x,
const float v = hwy::ConvertScalarTo<float>(x[j]);
const float w = hwy::ConvertScalarTo<float>(weight[j]);
// Note 1.0f centering here
out[j] = hwy::ConvertScalarTo<OutT>((1.0f + w) * (ss * v));
out[j] = hwy::ConvertScalarTo<OT>((1.0f + w) * (ss * v));
}
}
template <typename VecT, typename WeightT, typename OutT>
template <typename XT, typename WT, typename OT>
void TestRMSNorm(hwy::RandomState& rng) {
constexpr size_t kSize = 128;
HWY_ALIGN VecT vec[kSize];
HWY_ALIGN WeightT weight[kSize];
HWY_ALIGN OutT expected[kSize];
HWY_ALIGN OutT actual[kSize];
HWY_ALIGN XT vec[kSize];
HWY_ALIGN WT weight[kSize];
HWY_ALIGN OT expected[kSize];
HWY_ALIGN OT actual[kSize];
for (size_t i = 0; i < kSize; ++i) {
vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
}
ScalarRMSNorm(vec, weight, expected, kSize);
RMSNorm(vec, weight, actual, kSize);
RMSNorm(vec, weight, 0, actual, kSize);
for (size_t i = 0; i < kSize; i++) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(),
TypeName<WeightT>(), TypeName<OutT>(), i, e, a);
HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
TypeName<WT>(), TypeName<OT>(), i, e, a);
}
}
}
@ -526,24 +525,64 @@ void TestLayerNormSimple() {
}
}
// Note: there is no vectorized implementation of LayerNorm yet. So this test
// currently only checks that the scalar version can be called for the below
// combinations of float/BF16 inputs and outputs.
template <typename VecT, typename WeightT, typename OutT>
// Computes mean mu and mean of squares mu2 of a vector. Used in
// ScalarLayerNorm.
template <typename T>
HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, double& mu,
double& mu2) {
HWY_ASSERT(size > 0);
double sum = 0.0;
double sum2 = 0.0;
for (size_t i = 0; i < size; ++i) {
const float f = hwy::ConvertScalarTo<float>(a[i]);
sum += f;
sum2 += f * f;
}
mu = sum / size;
mu2 = sum2 / size;
}
// Compare py/flax/linen/normalization.py.
// out = (x - mean) * scale * rsqrt(var + epsilon) + bias
template <typename XT, typename WT, typename OT>
HWY_NOINLINE void ScalarLayerNorm(const XT* x, const WT* HWY_RESTRICT scale,
const WT* HWY_RESTRICT bias, OT* out,
size_t size) {
constexpr double kEps = 1e-6;
double mu, mu2;
ScalarMus(x, size, mu, mu2);
double var = mu2 - mu * mu;
constexpr double kZero = 0.0;
var = HWY_MAX(var, kZero);
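  // (Worked check, for reference: x = {1, 3} gives mu = 2, mu2 = 5, hence
  // var = 5 - 4 = 1. Mathematically mu2 - mu*mu >= 0, so the HWY_MAX clamp
  // only guards against tiny negative values from floating-point rounding.)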
var = 1.0 / sqrt(var + kEps);
for (size_t j = 0; j < size; j++) {
const float v = hwy::ConvertScalarTo<float>(x[j]);
const float s = hwy::ConvertScalarTo<float>(scale[j]);
const float b = hwy::ConvertScalarTo<float>(bias[j]);
out[j] = hwy::ConvertScalarTo<OT>((v - mu) * s * var + b);
}
}
template <typename XT, typename WT, typename OT>
void TestLayerNorm(hwy::RandomState& rng) {
constexpr size_t kSize = 128;
VecT vec[kSize];
WeightT weight[kSize];
WeightT bias[kSize];
OutT expected[kSize];
OutT actual[kSize];
XT vec[kSize];
WT weight[kSize];
WT bias[kSize];
OT expected[kSize];
OT actual[kSize];
for (size_t i = 0; i < kSize; ++i) {
vec[i] = hwy::ConvertScalarTo<VecT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
bias[i] = hwy::ConvertScalarTo<WeightT>(RandomGaussian(rng));
vec[i] = hwy::ConvertScalarTo<XT>(RandomGaussian(rng));
weight[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
bias[i] = hwy::ConvertScalarTo<WT>(RandomGaussian(rng));
}
double expected_mu, expected_mu2;
ScalarMus(vec, kSize, expected_mu, expected_mu2);
double actual_mu, actual_mu2;
ComputeMoments(vec, kSize, actual_mu, actual_mu2);
ScalarLayerNorm(vec, weight, bias, expected, kSize);
LayerNorm(vec, weight, bias, actual, kSize);
@ -551,8 +590,8 @@ void TestLayerNorm(hwy::RandomState& rng) {
const float e = hwy::ConvertScalarTo<float>(expected[i]);
const float a = hwy::ConvertScalarTo<float>(actual[i]);
if (!IsNear(e, a, 1e-5f)) {
HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<VecT>(),
TypeName<WeightT>(), TypeName<OutT>(), i, e, a);
HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName<XT>(),
TypeName<WT>(), TypeName<OT>(), i, e, a);
}
}
}

View File

@ -55,8 +55,7 @@ PYBIND11_MODULE(configs, py_module) {
.value("kSFP", Type::kSFP)
.value("kNUQ", Type::kNUQ)
.value("kF64", Type::kF64)
.value("kC64", Type::kC64)
.value("kU128", Type::kU128);
.value("kC64", Type::kC64);
enum_<LayerAttentionType>(py_module, "LayerAttentionType")
.value("kGemma", LayerAttentionType::kGemma)

View File

@ -168,9 +168,9 @@ class GemmaModel {
void SetImage(const py::array_t<float, py::array::c_style |
py::array::forcecast>& image) {
const gcpp::Gemma& gemma = *gemma_.GetGemma();
const gcpp::Allocator& allocator = gemma_.Env().ctx.allocator;
if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
const gcpp::ModelConfig& config = gemma.GetModelConfig();
if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA &&
config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
throw std::invalid_argument("Not a PaliGemma model.");
}
py::buffer_info buffer = image.request();
@ -182,14 +182,15 @@ class GemmaModel {
float* ptr = static_cast<float*>(buffer.ptr);
gcpp::Image c_image;
c_image.Set(height, width, ptr);
const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
const size_t image_size = config.vit_config.image_size;
c_image.Resize(image_size, image_size);
image_tokens_ = gcpp::ImageTokens(
allocator, gcpp::Extents2D(gemma.GetModelConfig().vit_config.seq_len,
gemma.GetModelConfig().model_dim));
image_tokens_.reset(new gcpp::ImageTokens(
"image_tokens",
gcpp::Extents2D(config.vit_config.seq_len, config.model_dim),
gcpp::MatPadding::kOdd));
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
.verbosity = 0};
gemma.GenerateImageTokens(runtime_config, c_image, image_tokens_);
gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_);
}
// Generates a response to the given prompt, using the last set image.
@ -197,9 +198,7 @@ class GemmaModel {
std::pair<std::string, std::vector<int>> GenerateWithImage(
std::string prompt, size_t max_generated_tokens, float temperature,
float seed, gcpp::AcceptFunc accept, std::vector<int> prompt_tokens) {
if (image_tokens_.Cols() == 0) {
throw std::invalid_argument("No image set.");
}
if (!image_tokens_) throw std::invalid_argument("No image set.");
const gcpp::Gemma& model = *gemma_.GetGemma();
gemma_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
@ -207,7 +206,7 @@ class GemmaModel {
config.temperature = temperature;
config.verbosity = 0;
config.accept_token = accept;
config.image_tokens = &image_tokens_;
config.image_tokens = image_tokens_.get();
std::vector<int> tokens;
if (!prompt_tokens.empty()) {
if (!prompt.empty()) {
@ -219,7 +218,7 @@ class GemmaModel {
} else {
tokens = gemma_.WrapAndTokenize(prompt);
}
tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0);
tokens.insert(tokens.begin(), image_tokens_->Rows(), 0);
size_t num_tokens = tokens.size();
size_t prefix_end = num_tokens;
config.prefill_tbatch_size = num_tokens;
@ -252,7 +251,7 @@ class GemmaModel {
private:
gcpp::GemmaEnv gemma_;
gcpp::ImageTokens image_tokens_;
std::unique_ptr<gcpp::ImageTokens> image_tokens_;
float last_prob_;
};

View File

@ -117,11 +117,11 @@ static size_t Stride(const Allocator& allocator, const MatPtr& mat,
}
}
void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
void MatOwner::AllocateFor(MatPtr& mat, const MatPadding padding) {
const bool is_nuq = mat.GetType() == Type::kNUQ;
const Allocator& allocator = ThreadingContext::Get().allocator;
const size_t stride = Stride(allocator, mat, padding);
const size_t num = mat.Rows() * stride;
const size_t stride = is_nuq ? mat.Cols() : Stride(allocator, mat, padding);
const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride;
// `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
// might not be enough, hence add extra. `MatT` is at least one byte, which
// is half of BF16, hence adding `VectorBytes` *elements* is enough.

View File

@ -28,7 +28,7 @@
#include "compression/shared.h" // Type
#include "gemma/tensor_info.h"
#include "io/fields.h"
#include "util/allocator.h"
#include "util/allocator.h" // AlignedPtr2
#include "util/basics.h" // Extents2D
// IWYU pragma: end_exports
#include "hwy/base.h"
@ -47,7 +47,7 @@ class MatPtr : public IFields {
// `name`: see `SetName`. Note that `stride` is initially `cols` and only
// differs after deserializing, or calling `SetPtr`.
MatPtr(const char* name, Type type, Extents2D extents)
: rows_(static_cast<uint32_t>(extents.rows)),
: private_rows_(static_cast<uint32_t>(extents.rows)),
cols_(static_cast<uint32_t>(extents.cols)) {
SetName(name);
SetType(type);
@ -74,7 +74,7 @@ class MatPtr : public IFields {
bool HasPtr() const { return ptr_ != nullptr; }
// A single row counts as packed because there is no padding between rows.
bool IsPacked() const { return (stride_ == cols_) || (rows_ == 1); }
bool IsPacked() const { return (stride_ == cols_) || (Rows() == 1); }
const void* Packed() const {
HWY_DASSERT_M(IsPacked(), name_.c_str());
@ -96,17 +96,17 @@ class MatPtr : public IFields {
// Works for any kind of padding.
template <typename T>
T* MutableRowT(size_t row) const {
HWY_DASSERT(row < rows_);
HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
}
template <typename T>
T* RowT(size_t row) {
HWY_DASSERT(row < rows_);
HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
}
template <typename T>
const T* RowT(size_t row) const {
HWY_DASSERT(row < rows_);
HWY_DASSERT(row < Rows());
return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
}
@ -118,10 +118,22 @@ class MatPtr : public IFields {
HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16);
}
bool IsEmpty() const { return rows_ == 0 || cols_ == 0; }
size_t Rows() const { return rows_; }
size_t Rows() const {
return override_rows_ == 0 ? private_rows_ : override_rows_;
}
size_t Cols() const { return cols_; }
Extents2D Extents() const { return Extents2D(rows_, cols_); }
Extents2D Extents() const { return Extents2D(Rows(), cols_); }
bool IsEmpty() const { return Rows() == 0 || cols_ == 0; }
bool SameShape(const MatPtr& other) const {
return Rows() == other.Rows() && cols_ == other.cols_;
}
  // Future calls to `Rows()` during this object's lifetime will return this
  // value; the override itself is not serialized. Used to set the actual
  // number of rows for activations preallocated according to the batch size.
void OverrideRows(size_t rows) {
HWY_ASSERT(rows <= private_rows_);
override_rows_ = static_cast<uint32_t>(rows);
}
// Offset by which to advance pointers to the next row.
size_t Stride() const { return stride_; }
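A sketch of how `OverrideRows` above is meant to be used, per the commit description ("mat: add OverrideRows, used by SetBatchSize"); the `Activations` struct, `max_batch_size` and `model_dim` below are illustrative, not the actual gemma.cpp implementation:
  struct Activations {
    MatStorageT<float> x;
    // Allocate once for the largest batch we will ever see.
    Activations(size_t max_batch_size, size_t model_dim)
        : x("x", Extents2D(max_batch_size, model_dim), MatPadding::kOdd) {}
    // Narrows the logical row count; OverrideRows() itself asserts that the
    // new count does not exceed the allocated rows.
    void SetBatchSize(size_t batch_size) { x.OverrideRows(batch_size); }
  };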
@ -150,7 +162,7 @@ class MatPtr : public IFields {
visitor(type_);
visitor(element_bytes_);
visitor(num_elements_);
visitor(rows_);
visitor(private_rows_);
visitor(cols_);
visitor(scale_);
visitor(stride_);
@ -164,11 +176,11 @@ class MatPtr : public IFields {
// padding, which is anyway not supported for NUQ because `compress-inl.h`
// assumes a contiguous stream for its group indexing.
static size_t ComputeNumElements(Type type, Extents2D extents) {
const size_t num_elements = extents.Area();
size_t num_elements = extents.Area();
if (type == Type::kNUQ) {
// `CompressedArrayElements` is a wrapper function that has the same
// effect, but that requires a template argument, not `type`.
return NuqStream::PackedEnd(num_elements);
num_elements = NuqStream::PackedEnd(num_elements);
}
return num_elements;
}
@ -184,9 +196,10 @@ class MatPtr : public IFields {
// Number of elements to store (including NUQ tables but not padding).
  // This is a function of `type_` and `Extents()`, stored for compatibility.
uint32_t num_elements_ = 0;
uint32_t rows_ = 0;
uint32_t private_rows_ = 0; // Only access via Rows()! See OverrideRows().
uint32_t cols_ = 0;
float scale_ = 1.0f; // multiplier for each value, for MatMul.
uint32_t override_rows_ = 0; // not serialized
// Non-owning pointer, must not be freed. The underlying memory must outlive
// this object.
@ -194,6 +207,8 @@ class MatPtr : public IFields {
// Offset by which to advance pointers to the next row, >= `cols_`.
uint32_t stride_;
float scale_ = 1.0f; // multiplier for each value, for MatMul.
};
// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
@ -202,6 +217,8 @@ class MatPtr : public IFields {
template <typename MatT>
class MatPtrT : public MatPtr {
public:
using T = MatT;
// Called by `MatStorageT`.
MatPtrT(const char* name, Extents2D extents)
: MatPtr(name, TypeEnum<MatT>(), extents) {}
@ -253,26 +270,67 @@ class MatPtrT : public MatPtr {
};
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`. Currently unused but may be used after we move toward
// type-erased `WeightsPtrs`.
// optional `args`. This supports all types used as weights, which excludes
// `kC64` and `kF64` (used only in `backprop/`).
template <class Func, typename... Args>
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
decltype(auto) CallUpcasted(const MatPtr* base, const Func& func,
Args&&... args) {
HWY_ASSERT(base != nullptr);
if (type == Type::kF32) {
return func(dynamic_cast<MatPtrT<float>*>(base),
if (base->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base),
std::forward<Args>(args)...);
} else if (type == Type::kBF16) {
return func(dynamic_cast<MatPtrT<BF16>*>(base),
} else if (base->GetType() == Type::kBF16) {
return func(dynamic_cast<const MatPtrT<BF16>*>(base),
std::forward<Args>(args)...);
} else if (type == Type::kSFP) {
return func(dynamic_cast<MatPtrT<SfpStream>*>(base),
} else if (base->GetType() == Type::kSFP) {
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base),
std::forward<Args>(args)...);
} else if (type == Type::kNUQ) {
return func(dynamic_cast<MatPtrT<NuqStream>*>(base),
} else if (base->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base),
std::forward<Args>(args)...);
} else {
HWY_ABORT("Type %d unknown.", static_cast<int>(type));
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
}
// Calls `func(base1, base2, args...)`.
template <class Func, typename... Args>
decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2,
const Func& func, Args&&... args) {
HWY_ASSERT(base1->GetType() == base2->GetType());
if (base1->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base1),
dynamic_cast<const MatPtrT<float>*>(base2),
std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kBF16) {
return func(dynamic_cast<const MatPtrT<BF16>*>(base1),
dynamic_cast<const MatPtrT<BF16>*>(base2),
std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kSFP) {
return func(dynamic_cast<const MatPtrT<SfpStream>*>(base1),
dynamic_cast<const MatPtrT<SfpStream>*>(base2),
std::forward<Args>(args)...);
} else if (base1->GetType() == Type::kNUQ) {
return func(dynamic_cast<const MatPtrT<NuqStream>*>(base1),
dynamic_cast<const MatPtrT<NuqStream>*>(base2),
std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType()));
}
}
// Like CallUpcasted, but only for activation types: kBF16 and kF32.
template <class Func, typename... Args>
decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func,
Args&&... args) {
HWY_ASSERT(base != nullptr);
if (base->GetType() == Type::kF32) {
return func(dynamic_cast<const MatPtrT<float>*>(base),
std::forward<Args>(args)...);
} else if (base->GetType() == Type::kBF16) {
return func(dynamic_cast<const MatPtrT<BF16>*>(base),
std::forward<Args>(args)...);
} else {
HWY_ABORT("Unhandled type %s.", TypeName(base->GetType()));
}
}
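A hedged usage sketch of the dispatch helpers above; `ElementBytes` is an illustrative free function (requires <type_traits>), not part of this change:
// Illustrative only: the generic lambda is instantiated with the concrete
// `MatPtrT<T>`, so the element type T is recovered at compile time.
size_t ElementBytes(const MatPtr& weight) {
  return CallUpcasted(&weight, [](const auto* mat) {
    using Mat = std::remove_const_t<std::remove_pointer_t<decltype(mat)>>;
    return sizeof(typename Mat::T);
  });
}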
@ -362,8 +420,11 @@ class MatStorageT : public MatPtrT<MatT> {
public:
MatStorageT(const char* name, Extents2D extents, MatPadding padding)
: MatPtrT<MatT>(name, extents) {
owner_.AllocateFor(*this, padding);
if (extents.Area() != 0) owner_.AllocateFor(*this, padding);
}
  // Shorthand for 1-D tensors: padding a single row does not help, hence `kPacked`.
MatStorageT(const char* name, size_t cols)
: MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {}
~MatStorageT() = default;
// Allow move for backprop/activations.
@ -467,81 +528,14 @@ using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory. Unlike `MatPtr`, this lacks metadata.
// TODO: replace with `MatStorageT`.
// TODO: remove allocator arg once kCyclic is removed.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() = default;
// Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
// we default to tightly packed rows (`stride = cols`).
// WARNING: not all call sites support `stride` != cols.
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
RowVectorBatch(const Allocator& allocator, Extents2D extents,
size_t stride = 0)
: extents_(extents) {
if (stride == 0) {
stride_ = extents_.cols;
} else {
HWY_ASSERT(stride >= extents_.cols);
stride_ = stride;
}
// Allow binding the entire matrix.
const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
allocator.QuantumBytes() / sizeof(T));
mem_ = allocator.Alloc<T>(padded);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return extents_.rows; }
size_t Cols() const { return extents_.cols; }
size_t Stride() const { return stride_; }
Extents2D Extents() const { return extents_; }
// Returns the given row vector of length `Cols()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
// For MatMul or other operations that process the entire batch at once.
// TODO: remove once we only use Mat.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
private:
AlignedPtr2<T[]> mem_;
Extents2D extents_;
size_t stride_;
};
template <typename T>
RowPtr<T> RowPtrFromBatch(const Allocator& allocator,
RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
row_vectors.Stride());
}
template <typename T>
RowVectorBatch<T> AllocateAlignedRows(const Allocator& allocator,
Extents2D extents) {
return RowVectorBatch<T>(
allocator, extents,
StrideForCyclicOffsets(extents.cols,
allocator.QuantumBytes() / sizeof(T)));
RowPtr<T> RowPtrFromMat(const Allocator& allocator,
const MatPtrT<T>& row_vectors) {
  // RowPtr is non-const because MatMul writes its output C through it, but the
  // same view is also used for the const input A. Callers are responsible for
  // not writing through the RowPtr when the underlying matrix is const.
return RowPtr<T>(allocator, const_cast<T*>(row_vectors.Row(0)),
row_vectors.Cols(), row_vectors.Stride());
}
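For reference, a hedged example of adapting a float activation matrix to the `RowPtr` view above; `activations_x` is an illustrative `MatPtrT<float>`:
  // Illustrative only.
  const Allocator& allocator = ThreadingContext::Get().allocator;
  RowPtrF x_view = RowPtrFromMat(allocator, activations_x);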
} // namespace gcpp