diff --git a/BUILD.bazel b/BUILD.bazel index 9475ccc..83671e4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -415,6 +415,7 @@ cc_library( hdrs = ["gemma/kv_cache.h"], deps = [ ":configs", + ":mat", "@highway//:hwy", ], ) @@ -425,6 +426,7 @@ cc_library( deps = [ ":args", ":basics", + ":mat", ":ops", # matmul.h "//io", "@highway//:hwy", diff --git a/backprop/activations.h b/backprop/activations.h index d0446cd..7e42032 100644 --- a/backprop/activations.h +++ b/backprop/activations.h @@ -38,8 +38,8 @@ struct ForwardLayer { att_post1(MakePacked("att_post1", seq_len, config.model_dim)), attention_out( MakePacked("attention_out", seq_len, config.model_dim)), - bf_pre_ffw_rms_out( - MakePacked("bf_preFF_rms_out", seq_len, config.model_dim)), + pre_ffw_rms_out( + MakePacked("preFF_rms_out", seq_len, config.model_dim)), ffw_hidden( MakePacked("ffw_hidden", seq_len, config.ff_hidden_dim * 2)), ffw_hidden_gated( @@ -53,7 +53,7 @@ struct ForwardLayer { MatStorageT att_out; MatStorageT att_post1; MatStorageT attention_out; - MatStorageT bf_pre_ffw_rms_out; + MatStorageT pre_ffw_rms_out; MatStorageT ffw_hidden; MatStorageT ffw_hidden_gated; const LayerConfig& layer_config; diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index fbc59e2..783ed0b 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -170,8 +170,7 @@ void LayerVJP(const LayerWeightsPtrs& weights, const ForwardLayer& forward, const float* HWY_RESTRICT next_layer_grad, size_t num_tokens, LayerWeightsPtrs& grad, ForwardLayer& backward, - const RowVectorBatch& inv_timescale, - hwy::ThreadPool& pool) { + const MatStorageT& inv_timescale, hwy::ThreadPool& pool) { const LayerConfig& config = weights.layer_config; const size_t model_dim = config.model_dim; const size_t qkv_dim = config.qkv_dim; @@ -207,15 +206,14 @@ void LayerVJP(const LayerWeightsPtrs& weights, } } - MatMulVJP(weights.gating_einsum_w.Packed(), - forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(), - model_dim, ff_hidden_dim * 2, num_tokens, - grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(), - pool); - RMSNormVJP( - weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(), - backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens, - grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool); + MatMulVJP(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(), + backward.ffw_hidden.Packed(), model_dim, ff_hidden_dim * 2, + num_tokens, grad.gating_einsum_w.Packed(), + backward.pre_ffw_rms_out.Packed(), pool); + RMSNormVJP(weights.pre_ffw_norm_scale.Packed(), + forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(), + model_dim, num_tokens, grad.pre_ffw_norm_scale.Packed(), + backward.attention_out.Packed(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { AddFrom(next_layer_grad + pos * model_dim, @@ -275,7 +273,7 @@ void LayerVJP(const LayerWeightsPtrs& weights, for (int pos = 0; pos < static_cast(num_tokens); ++pos) { float* HWY_RESTRICT b_kv = backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim; - Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos); + Rope(b_kv, qkv_dim, inv_timescale.PackedScale1(), -pos); } for (size_t head = 0; head < heads; ++head) { @@ -283,7 +281,7 @@ void LayerVJP(const LayerWeightsPtrs& weights, float* HWY_RESTRICT b_q = backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim; MulByConst(query_scale, b_q, qkv_dim); - Rope(b_q, qkv_dim, inv_timescale.Const(), -pos); + Rope(b_q, qkv_dim, inv_timescale.PackedScale1(), -pos); } } 
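The backward-pass hunks above switch `inv_timescale` from `RowVectorBatch` to `MatStorageT` and make `Rope` read the table through `PackedScale1()` instead of `Const()`. Below is a minimal sketch of the migrated call pattern; the `<float>` template argument, the wrapper function, and the include set are assumptions for illustration, not part of the patch.

```c++
#include "ops/ops.h"   // CreateInvTimescale (see gemma/activations.h below)
#include "util/mat.h"  // MatStorageT (replaces util/allocator.h in forward.h)

// Hypothetical helper showing the new types/accessors; names are illustrative.
void RopeWithInvTimescale(float* qk, size_t qkv_dim, bool half_rope, int pos) {
  // CreateInvTimescale now returns MatStorageT (element type assumed float).
  MatStorageT<float> inv_timescale = CreateInvTimescale(
      ThreadingContext::Get().allocator, qkv_dim, half_rope);
  // The timescale table is read via PackedScale1() rather than the old Const().
  Rope(qk, qkv_dim, inv_timescale.PackedScale1(), pos);
}
```

The same `PackedScale1()` accessor replaces `Const()` at every `Rope` call site touched by this patch, in both the forward and backward passes.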
@@ -342,7 +340,7 @@ void CrossEntropyLossBackwardPassInl(const Prompt& prompt, const ForwardPass& forward, ModelWeightsPtrs& grad, ForwardPass& backward, - RowVectorBatch& inv_timescale, + MatStorageT& inv_timescale, hwy::ThreadPool& pool) { const ModelConfig& config = weights.weights_config; const size_t kVocabSize = config.vocab_size; diff --git a/backprop/backward.cc b/backprop/backward.cc index d89da45..36edf74 100644 --- a/backprop/backward.cc +++ b/backprop/backward.cc @@ -42,7 +42,7 @@ void CrossEntropyLossBackwardPassT(const Prompt& prompt, const ForwardPass& forward, ModelWeightsPtrs& grad, ForwardPass& backward, - RowVectorBatch& inv_timescale, + MatStorageT& inv_timescale, hwy::ThreadPool& pool) { CrossEntropyLossBackwardPassInl(prompt, weights, forward, grad, backward, inv_timescale, pool); @@ -62,7 +62,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const ForwardPass& forward, ModelWeightsPtrs& grad, ForwardPass& backward, - RowVectorBatch& inv_timescale, + MatStorageT& inv_timescale, hwy::ThreadPool& pool) { return HWY_DYNAMIC_DISPATCH(CrossEntropyLossBackwardPassT)( prompt, weights, forward, grad, backward, inv_timescale, pool); diff --git a/backprop/backward.h b/backprop/backward.h index 5a08f5c..f8de706 100644 --- a/backprop/backward.h +++ b/backprop/backward.h @@ -29,7 +29,7 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const ForwardPass& forward, ModelWeightsPtrs& grad, ForwardPass& backward, - RowVectorBatch& inv_timescale, + MatStorageT& inv_timescale, hwy::ThreadPool& pool); } // namespace gcpp diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h index 20b43ed..b0c7f13 100644 --- a/backprop/backward_scalar.h +++ b/backprop/backward_scalar.h @@ -218,16 +218,15 @@ void LayerVJP(const LayerWeightsPtrs& weights, GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(), backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens); - MatMulVJPT(weights.gating_einsum_w.Packed(), - forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(), - grad.gating_einsum_w.Packed(), - backward.bf_pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim, + MatMulVJPT(weights.gating_einsum_w.Packed(), forward.pre_ffw_rms_out.Packed(), + backward.ffw_hidden.Packed(), grad.gating_einsum_w.Packed(), + backward.pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim, num_tokens); - RMSNormVJPT( - weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(), - backward.bf_pre_ffw_rms_out.Packed(), grad.pre_ffw_norm_scale.Packed(), - backward.attention_out.Packed(), model_dim, num_tokens); + RMSNormVJPT(weights.pre_ffw_norm_scale.Packed(), + forward.attention_out.Packed(), backward.pre_ffw_rms_out.Packed(), + grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), + model_dim, num_tokens); AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim); diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index 9f3a1d6..5b220ca 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -202,7 +202,7 @@ void TestEndToEnd() { ReverseSequenceSampler training_task({0, 0, 1, 1}); std::vector batch = training_task.SampleBatch(3, gen); - RowVectorBatch inv_timescale = CreateInvTimescale( + MatStorageT inv_timescale = CreateInvTimescale( ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk == PostQKType::HalfRope); for (const Prompt& prompt : batch) { diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index 0730dbe..954243c 100644 --- 
a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -74,7 +74,7 @@ void ApplyRMSNorm(const WT* HWY_RESTRICT weights, const XT* HWY_RESTRICT x, hwy::ThreadPool& pool) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t offset = pos * model_dim; - RMSNorm(x + offset, weights, output + offset, model_dim); + RMSNorm(x + offset, weights, 0, output + offset, model_dim); } } @@ -100,7 +100,7 @@ template void ApplyForwardLayer(const LayerWeightsPtrs& weights, ForwardLayer& activations, size_t num_tokens, float* HWY_RESTRICT output, - const RowVectorBatch& inv_timescale, + const MatStorageT& inv_timescale, hwy::ThreadPool& pool) { const LayerConfig& config = weights.layer_config; const size_t model_dim = config.model_dim; @@ -125,14 +125,14 @@ void ApplyForwardLayer(const LayerWeightsPtrs& weights, for (size_t pos = 0; pos < num_tokens; ++pos) { float* HWY_RESTRICT k = activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim; - Rope(k, kQKVDim, inv_timescale.Const(), pos); + Rope(k, kQKVDim, inv_timescale.PackedScale1(), pos); } pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; const size_t pos = task / kHeads; float* HWY_RESTRICT q = activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim; - Rope(q, kQKVDim, inv_timescale.Const(), pos); + Rope(q, kQKVDim, inv_timescale.PackedScale1(), pos); MulByConst(query_scale, q, kQKVDim); }); @@ -194,11 +194,11 @@ void ApplyForwardLayer(const LayerWeightsPtrs& weights, ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(), activations.attention_out.Packed(), model_dim, num_tokens, - activations.bf_pre_ffw_rms_out.Packed(), pool); + activations.pre_ffw_rms_out.Packed(), pool); const size_t kFFHiddenDim = config.ff_hidden_dim; for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim, - activations.bf_pre_ffw_rms_out.Packed() + pos * model_dim, + activations.pre_ffw_rms_out.Packed() + pos * model_dim, activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { @@ -233,7 +233,7 @@ float CrossEntropyLossForwardPass(const std::vector& prompt, size_t context_size, const ModelWeightsPtrs& weights, ForwardPass& forward, - const RowVectorBatch& inv_timescale, + const MatStorageT& inv_timescale, hwy::ThreadPool& pool) { const ModelConfig& config = weights.weights_config; const size_t vocab_size = config.vocab_size; diff --git a/backprop/forward.cc b/backprop/forward.cc index 8f85e81..c31f359 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -38,7 +38,7 @@ namespace HWY_NAMESPACE { float CrossEntropyLossForwardPassT(const Prompt& prompt, const ModelWeightsPtrs& weights, ForwardPass& forward, - RowVectorBatch& inv_timescale, + MatStorageT& inv_timescale, hwy::ThreadPool& pool) { return CrossEntropyLossForwardPass(prompt.tokens, prompt.context_size, weights, forward, inv_timescale, pool); @@ -56,7 +56,7 @@ HWY_EXPORT(CrossEntropyLossForwardPassT); float CrossEntropyLossForwardPass(const Prompt& prompt, const ModelWeightsPtrs& weights, ForwardPass& forward, - RowVectorBatch& inv_timescale, + MatStorageT& inv_timescale, hwy::ThreadPool& pool) { return HWY_DYNAMIC_DISPATCH(CrossEntropyLossForwardPassT)( prompt, weights, forward, inv_timescale, pool); diff --git a/backprop/forward.h b/backprop/forward.h index 3b42298..042d40d 100644 --- a/backprop/forward.h +++ b/backprop/forward.h @@ -19,7 +19,7 @@ #include "backprop/activations.h" #include "backprop/prompt.h" #include 
"gemma/weights.h" -#include "util/allocator.h" +#include "util/mat.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { @@ -27,7 +27,7 @@ namespace gcpp { float CrossEntropyLossForwardPass(const Prompt& prompt, const ModelWeightsPtrs& weights, ForwardPass& forward, - RowVectorBatch& inv_timescale, + MatStorageT& inv_timescale, hwy::ThreadPool& pool); } // namespace gcpp diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h index d81ae30..45d0f18 100644 --- a/backprop/forward_scalar.h +++ b/backprop/forward_scalar.h @@ -219,12 +219,11 @@ void ApplyLayer(const LayerWeightsPtrs& weights, RMSNormT(weights.pre_ffw_norm_scale.Packed(), activations.attention_out.Packed(), - activations.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens); + activations.pre_ffw_rms_out.Packed(), model_dim, num_tokens); MatMulT(weights.gating_einsum_w.Packed(), - activations.bf_pre_ffw_rms_out.Packed(), - activations.ffw_hidden.Packed(), ff_hidden_dim * 2, model_dim, - num_tokens); + activations.pre_ffw_rms_out.Packed(), activations.ffw_hidden.Packed(), + ff_hidden_dim * 2, model_dim, num_tokens); GatedGelu(activations.ffw_hidden.Packed(), activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens); diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index b23d404..be1723e 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -62,9 +62,9 @@ TEST(OptimizeTest, GradientDescent) { grad_m.ZeroInit(); grad_v.ZeroInit(); ForwardPass forward(config), backward(config); - KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16); + KVCache kv_cache(config, /*prefill_tbatch_size=*/16); - RowVectorBatch inv_timescale = CreateInvTimescale( + MatStorageT inv_timescale = CreateInvTimescale( allocator, config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk == PostQKType::HalfRope); @@ -147,7 +147,7 @@ TEST(OptimizeTest, GradientDescent) { printf("Num steps: %zu\n", steps); printf("Final weights:\n"); gemma.MutableWeights().LogWeightStatsF32(); - EXPECT_LT(steps, 50); + EXPECT_LT(steps, 80); EXPECT_EQ(num_ok, kBatchSize); } diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc index 6dd5982..8f1bf91 100644 --- a/compression/nuq_test.cc +++ b/compression/nuq_test.cc @@ -23,6 +23,7 @@ #include #include // std::shuffle +#include #include #include "compression/distortion.h" @@ -104,7 +105,7 @@ struct TestPlateaus { HWY_ASSERT(-0.5f <= in[i] && in[i] < 0.5f); } - std::random_device rd; + std::random_device rd; // NOLINT std::mt19937 rng(rd()); std::shuffle(in.get(), in.get() + kGroupSize, rng); @@ -151,7 +152,7 @@ struct TestRamp { HWY_ASSERT(-0.45f <= in[i] && in[i] < 0.55f); } - std::random_device rd; + std::random_device rd; // NOLINT std::mt19937 rng(rd()); std::shuffle(in.get(), in.get() + kGroupSize, rng); @@ -246,7 +247,8 @@ struct TestOffset { auto in = hwy::AllocateAligned(total); // Enc() requires f32 auto dec1 = hwy::AllocateAligned(total); auto dec2 = hwy::AllocateAligned(kMidLen); - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(total)); + auto nuq = hwy::AllocateAligned( + hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes())); HWY_ASSERT(in && dec1 && dec2 && nuq); const auto nuq_span = MakeSpan(nuq.get(), total); @@ -296,7 +298,8 @@ struct TestUnalignedOffset { auto in = hwy::AllocateAligned(total); // Enc() requires f32 auto dec1 = hwy::AllocateAligned(total); - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(total)); + auto nuq = hwy::AllocateAligned( + hwy::RoundUpTo(NuqStream::PackedEnd(total), 
hwy::VectorBytes())); auto dec2 = hwy::AllocateAligned(num_decompressed); HWY_ASSERT(in && dec1 && dec2 && nuq); const auto nuq_span = MakeSpan(nuq.get(), total); @@ -347,7 +350,8 @@ struct TestDec2 { auto dec0 = hwy::AllocateAligned(total); auto dec1 = hwy::AllocateAligned(total); auto dec2 = hwy::AllocateAligned(kMidLen); - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(total)); + auto nuq = hwy::AllocateAligned( + hwy::RoundUpTo(NuqStream::PackedEnd(total), hwy::VectorBytes())); HWY_ASSERT(in && dec0 && dec1 && dec2 && nuq); const auto nuq_span = MakeSpan(nuq.get(), total); @@ -449,7 +453,8 @@ struct TestEncDec { const size_t num = 4 * kGroupSize; auto in = hwy::AllocateAligned(num); // Enc() requires f32 auto out = hwy::AllocateAligned(num); // already padded - auto nuq = hwy::AllocateAligned(NuqStream::PackedEnd(num)); + auto nuq = hwy::AllocateAligned( + hwy::RoundUpTo(NuqStream::PackedEnd(num), hwy::VectorBytes())); HWY_ASSERT(in && out && nuq); const auto nuq_span = MakeSpan(nuq.get(), num); diff --git a/compression/shared.h b/compression/shared.h index c5b7ad6..00cab4c 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -164,11 +164,11 @@ constexpr bool IsNuqStream() { // weights for a model, but can be used for other purposes, such as types for // `WeightsPtrs`. When adding a new type that is supported, also // update gemma.cc, weights.*, and add instantiations/new_one.cc. -enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 }; +enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64 }; // These are used in `ModelConfig.Specifier`, hence the strings will not // change, though new ones may be added. -static constexpr const char* kTypeStrings[] = { - "unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"}; +static constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", + "nuq", "f64", "c64"}; static constexpr size_t kNumTypes = sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); static constexpr size_t kTypeBits[] = {0, @@ -177,8 +177,7 @@ static constexpr size_t kTypeBits[] = {0, 8 * sizeof(SfpStream), 4 /* NuqStream, actually 4.5 */, 8 * sizeof(double), - 8 * sizeof(std::complex), - 8 * sizeof(hwy::uint128_t)}; + 8 * sizeof(std::complex)}; static inline bool EnumValid(Type type) { return static_cast(type) < kNumTypes; @@ -200,8 +199,6 @@ Type TypeEnum() { return Type::kF64; } else if constexpr (hwy::IsSame>()) { return Type::kC64; - } else if constexpr (hwy::IsSame()) { - return Type::kU128; } else { HWY_DASSERT(false); return Type::kUnknown; diff --git a/evals/benchmark.cc b/evals/benchmark.cc index a225d64..c899642 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -73,8 +73,8 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, size_t num_tokens = std::min(prompt.size() - pos, batch_tokens); std::vector prompt_slice(prompt.begin() + pos, prompt.begin() + pos + num_tokens); - KVCache kv_cache = KVCache::Create(env.GetGemma()->GetModelConfig(), - env.MutableConfig().prefill_tbatch_size); + KVCache kv_cache(env.GetGemma()->GetModelConfig(), + env.MutableConfig().prefill_tbatch_size); float entropy = ComputeCrossEntropy( *env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); total_entropy += entropy; diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 884e616..3ea3baf 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -52,9 +52,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference) : 
env_(MakeMatMulEnv(threading_args)), gemma_(loader, env_) { // Only allocate one for starters because GenerateBatch might not be called. - kv_caches_.resize(1); - kv_caches_[0] = - KVCache::Create(gemma_.GetModelConfig(), inference.prefill_tbatch_size); + kv_caches_.push_back( + KVCache(gemma_.GetModelConfig(), inference.prefill_tbatch_size)); InitGenerator(inference, gen_); @@ -131,15 +130,10 @@ std::vector GemmaEnv::BatchQueryModel( runtime_config_.decode_qbatch_size); } - // Ensure we have one KVCache per query. - if (kv_caches_.size() < num_queries) { - kv_caches_.resize(num_queries); - } - for (size_t i = 1; i < num_queries; ++i) { - if (kv_caches_[i].seq_len == 0) { - kv_caches_[i] = KVCache::Create(gemma_.GetModelConfig(), - runtime_config_.prefill_tbatch_size); - } + // Ensure we have at least one KVCache per query. + while (kv_caches_.size() < num_queries) { + kv_caches_.push_back( + KVCache(gemma_.GetModelConfig(), runtime_config_.prefill_tbatch_size)); } gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity}; diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 5082402..683ff88 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -53,8 +53,7 @@ int main(int argc, char** argv) { // Instantiate model and KV Cache gcpp::MatMulEnv env(MakeMatMulEnv(threading)); gcpp::Gemma gemma(loader, env); - gcpp::KVCache kv_cache = gcpp::KVCache::Create(gemma.GetModelConfig(), - inference.prefill_tbatch_size); + gcpp::KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size); size_t generated = 0; // Initialize random number generator diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index 738319b..0899096 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -39,11 +39,8 @@ class SimplifiedGemma { threading_(threading), inference_(inference), env_(MakeMatMulEnv(threading_)), - gemma_(loader_, env_) { - // Instantiate model and KV Cache - kv_cache_ = gcpp::KVCache::Create(gemma_.GetModelConfig(), - inference_.prefill_tbatch_size); - + gemma_(loader_, env_), + kv_cache_(gemma_.GetModelConfig(), inference_.prefill_tbatch_size) { // Initialize random number generator std::random_device rd; gen_.seed(rd()); diff --git a/gemma/activations.h b/gemma/activations.h index 3b77791..d766ef7 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -23,106 +23,127 @@ #include "ops/ops.h" // CreateInvTimescale #include "util/allocator.h" // Allocator #include "util/basics.h" // BF16 -#include "util/mat.h" // RowVectorBatch +#include "util/mat.h" // MatStorageT namespace gcpp { struct Activations { - explicit Activations(const ModelConfig& config) + Activations(const ModelConfig& config, size_t batch_size, MatMulEnv* env) : weights_config(config), layer_config(config.layer_configs[0]), seq_len(config.seq_len), - cache_pos_size(config.CachePosSize()) {} + cache_pos_size(config.CachePosSize()), + is_griffin(layer_config.type == + LayerAttentionType::kGriffinRecurrentBlock), - RowVectorBatch x; // input - RowVectorBatch q; // query, also KV if MHA. 
- RowVectorBatch logits; + x("x", Extents2D(batch_size, config.model_dim), pad_), + q("q", + Extents2D(batch_size, layer_config.heads * layer_config.QStride()), + pad_), + logits("logits", Extents2D(batch_size, config.vocab_size), pad_), - // Attention - RowVectorBatch pre_att_rms_out; - RowVectorBatch att; // attention vector - RowVectorBatch att_out; // attention output - // Accumulation of attention outputs over heads - RowVectorBatch att_sums; + pre_att_rms_out("pre_att_rms_out", + Extents2D(batch_size, config.model_dim), pad_), + att("att", Extents2D(batch_size, layer_config.heads * config.seq_len), + pad_), + att_out( + "att_out", + Extents2D(batch_size, layer_config.heads * layer_config.qkv_dim), + pad_), + att_sums("att_sums", Extents2D(batch_size, config.model_dim), pad_), - // Gated FFW - RowVectorBatch bf_pre_ffw_rms_out; - RowVectorBatch C1; - RowVectorBatch C2; - RowVectorBatch ffw_out; + pre_ffw_rms_out("pre_ffw_rms_out", + Extents2D(batch_size, config.model_dim), pad_), + C1("C1", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), + C2("C2", Extents2D(batch_size, layer_config.ff_hidden_dim), pad_), + ffw_out("ffw_out", Extents2D(batch_size, config.model_dim), pad_), - // Griffin - RowVectorBatch griffin_x; - RowVectorBatch griffin_y; - RowVectorBatch griffin_gate_x; - RowVectorBatch griffin_multiplier; + // No padding for Griffin because it does not always use Row(). + griffin_x("griffin_x", + is_griffin ? Extents2D(batch_size, config.model_dim) : none_, + MatPadding::kPacked), + griffin_y("griffin_y", + is_griffin ? Extents2D(batch_size, config.model_dim) : none_, + MatPadding::kPacked), + griffin_gate_x( + "griffin_gate_x", + is_griffin ? Extents2D(batch_size, config.model_dim) : none_, + MatPadding::kPacked), + griffin_multiplier( + "griffin_mul", + is_griffin ? Extents2D(batch_size, config.model_dim) : none_, + MatPadding::kPacked), - // Rope - RowVectorBatch inv_timescale; - RowVectorBatch inv_timescale_global; + inv_timescale( + CreateInvTimescale(env->ctx.allocator, layer_config.qkv_dim, + layer_config.post_qk == PostQKType::HalfRope)), + inv_timescale_global(CreateInvTimescale( + env->ctx.allocator, layer_config.qkv_dim, + layer_config.post_qk == PostQKType::HalfRope, 1000000.0)), - // Dynamic because no default ctor and only initialized in `Allocate`. - MatMulEnv* env; + env(env) { + HWY_ASSERT(batch_size != 0); + } + + void SetBatchSize(size_t batch_size) { + x.OverrideRows(batch_size); + q.OverrideRows(batch_size); + logits.OverrideRows(batch_size); + + pre_att_rms_out.OverrideRows(batch_size); + att.OverrideRows(batch_size); + att_out.OverrideRows(batch_size); + att_sums.OverrideRows(batch_size); + + pre_ffw_rms_out.OverrideRows(batch_size); + C1.OverrideRows(batch_size); + C2.OverrideRows(batch_size); + ffw_out.OverrideRows(batch_size); + + if (is_griffin) { + griffin_x.OverrideRows(batch_size); + griffin_y.OverrideRows(batch_size); + griffin_gate_x.OverrideRows(batch_size); + griffin_multiplier.OverrideRows(batch_size); + } + } - PostQKType post_qk = PostQKType::Rope; - // And the config. const ModelConfig& weights_config; const LayerConfig& layer_config; size_t seq_len; - size_t cache_pos_size = 0; + size_t cache_pos_size = 0; // TODO: after moving KVCache to MatStorageT. + bool is_griffin = false; + const Extents2D none_ = Extents2D(); + const MatPadding pad_ = MatPadding::kOdd; - void Allocate(size_t batch_size, MatMulEnv* env) { - const Allocator& allocator = env->ctx.allocator; + MatStorageT x; // input + MatStorageT q; // query, also KV if MHA. 
+ MatStorageT logits; - post_qk = layer_config.post_qk; - const size_t model_dim = weights_config.model_dim; - const size_t ff_hidden_dim = layer_config.ff_hidden_dim; - const size_t vocab_size = weights_config.vocab_size; - const size_t qkv_dim = layer_config.qkv_dim; - const size_t heads = layer_config.heads; + // Attention + MatStorageT pre_att_rms_out; + MatStorageT att; // attention vector + MatStorageT att_out; // attention output + // Accumulation of attention outputs over heads + MatStorageT att_sums; - x = RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); - q = RowVectorBatch( - allocator, Extents2D(batch_size, heads * layer_config.QStride())); - if (vocab_size > 0) { - logits = - RowVectorBatch(allocator, Extents2D(batch_size, vocab_size)); - } + // Gated FFW + MatStorageT pre_ffw_rms_out; + MatStorageT C1; + MatStorageT C2; + MatStorageT ffw_out; - pre_att_rms_out = - RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); - att = RowVectorBatch( - allocator, Extents2D(batch_size, heads * weights_config.seq_len)); - att_out = RowVectorBatch(allocator, - Extents2D(batch_size, heads * qkv_dim)); - att_sums = - RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); + // Griffin + MatStorageT griffin_x; + MatStorageT griffin_y; + MatStorageT griffin_gate_x; + MatStorageT griffin_multiplier; - bf_pre_ffw_rms_out = - RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); - C1 = RowVectorBatch(allocator, Extents2D(batch_size, ff_hidden_dim)); - C2 = RowVectorBatch(allocator, Extents2D(batch_size, ff_hidden_dim)); - ffw_out = - RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); + // Rope + MatStorageT inv_timescale; + MatStorageT inv_timescale_global; - if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) { - griffin_x = - RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); - griffin_y = - RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); - griffin_gate_x = - RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); - griffin_multiplier = - RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); - } - - inv_timescale = CreateInvTimescale(allocator, layer_config.qkv_dim, - post_qk == PostQKType::HalfRope); - inv_timescale_global = CreateInvTimescale( - allocator, qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0); - - this->env = env; - } + MatMulEnv* env; }; } // namespace gcpp diff --git a/gemma/bindings/context.cc b/gemma/bindings/context.cc index 38ca070..d81f0b8 100644 --- a/gemma/bindings/context.cc +++ b/gemma/bindings/context.cc @@ -46,8 +46,7 @@ ConversationData::ConversationData(const ModelConfig& model_config, size_t prefill_tbatch_size) : model_config_ref_(model_config), prefill_tbatch_size_(prefill_tbatch_size), - kv_cache(std::make_unique( - KVCache::Create(model_config, prefill_tbatch_size))), + kv_cache(std::make_unique(model_config, prefill_tbatch_size)), abs_pos(0) {} // ConversationData copy constructor implementation @@ -184,25 +183,28 @@ int GemmaContext::GenerateInternal(const char* prompt_string, inference_args.CopyTo(runtime_config); size_t prefix_end = 0; + const ModelConfig& model_config = model.GetModelConfig(); + // generate std::vector prompt; - ImageTokens image_tokens; + const size_t pool_dim = model_config.vit_config.pool_dim; + ImageTokens image_tokens( + "image_tokens", + image_data + ? 
Extents2D(model_config.vit_config.seq_len / (pool_dim * pool_dim), + model_config.model_dim) + : Extents2D(0, 0), + MatPadding::kOdd); if (image_data != nullptr) { - size_t pool_dim = model.GetModelConfig().vit_config.pool_dim; - image_tokens = - ImageTokens(model.Env().ctx.allocator, - Extents2D(model.GetModelConfig().vit_config.seq_len / - (pool_dim * pool_dim), - model.GetModelConfig().model_dim)); - HWY_ASSERT(model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA || - model.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM); + HWY_ASSERT(model_config.wrapping == PromptWrapping::PALIGEMMA || + model_config.wrapping == PromptWrapping::GEMMA_VLM); Image image; image.Set(image_width, image_height, static_cast(image_data)); // We may need to resize the supplied image depending on whether we're using // PaliGemma or Gemma 3. - const size_t image_size = model.GetModelConfig().vit_config.image_size; + const size_t image_size = model_config.vit_config.image_size; image.Resize(image_size, image_size); // Use the existing runtime_config defined earlier in the function. @@ -217,10 +219,9 @@ int GemmaContext::GenerateInternal(const char* prompt_string, ss << static_cast(image_tokens_duration * 1000) << " ms\n", LogDebug(ss.str().c_str()); - prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), - model.GetModelConfig().wrapping, - active_conversation->abs_pos, prompt_string, - image_tokens.BatchSize()); + prompt = WrapAndTokenize( + model.Tokenizer(), model.ChatTemplate(), model_config.wrapping, + active_conversation->abs_pos, prompt_string, image_tokens.Rows()); runtime_config.image_tokens = &image_tokens; prompt_size = prompt.size(); // The end of the prefix for prefix-LM style attention in Paligemma. @@ -230,7 +231,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // Text-only case (original logic) // Use abs_pos from the active conversation prompt = WrapAndTokenize(model.Tokenizer(), model.ChatTemplate(), - model.GetModelConfig().wrapping, + model_config.wrapping, active_conversation->abs_pos, prompt_string); prompt_size = prompt.size(); } @@ -251,7 +252,7 @@ int GemmaContext::GenerateInternal(const char* prompt_string, // prepare for next turn if (!inference_args.multiturn || - model.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) { + model_config.wrapping == PromptWrapping::PALIGEMMA) { // If not multiturn, or Paligemma (which handles turns differently), // reset the *active* conversation's position. active_conversation->abs_pos = 0; diff --git a/gemma/bindings/context.h b/gemma/bindings/context.h index ba44c1b..f3b295f 100644 --- a/gemma/bindings/context.h +++ b/gemma/bindings/context.h @@ -188,8 +188,8 @@ class GemmaContext { // rewind to initial state. active_conversation->abs_pos = 0; // Replace the cache within the current ConversationData object - active_conversation->kv_cache = std::make_unique(KVCache::Create( - model.GetModelConfig(), inference_args.prefill_tbatch_size)); + active_conversation->kv_cache = std::make_unique( + model.GetModelConfig(), inference_args.prefill_tbatch_size); LogDebug((log_prefix + "Successfully rewound to initial state.").c_str()); } else { diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 3345ecb..009d5a5 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -89,11 +89,11 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, // X / Y linear layers. 
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx); - float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); + float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx); + float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx); TwoMatVecAdd(layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0, model_dim, model_dim, - activations.pre_att_rms_out.Batch(batch_idx), + activations.pre_att_rms_out.Row(batch_idx), /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), /*out0=*/x, /*out1=*/y, pool); @@ -103,17 +103,16 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, // Conv1D. for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { const size_t pos = batch_start + batch_idx; - float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); + float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx); HWY_FULL(float) df; HWY_DASSERT(model_dim % hn::Lanes(df) == 0); - const size_t layer_offset = layer * model_dim * (conv_1d_width - 1); // cache[i] = input at time t-i. float* HWY_RESTRICT cache[kMaxConv1DWidth]; cache[0] = x; for (size_t i = 1; i < conv_1d_width; i++) { cache[i] = - kv_cache.conv1d_cache.get() + layer_offset + + kv_cache.conv1d_cache.Row(layer) + ((pos + conv_1d_width - 1 - i) % (conv_1d_width - 1)) * model_dim; } for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { @@ -140,12 +139,11 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, // RGLRU for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { const size_t pos = batch_start + batch_idx; - float* HWY_RESTRICT y = activations.griffin_y.Batch(batch_idx); - float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); - float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Batch(batch_idx); - float* HWY_RESTRICT a = activations.griffin_multiplier.Batch(batch_idx); - float* HWY_RESTRICT rnn_state = - kv_cache.rglru_cache.get() + layer * model_dim; + float* HWY_RESTRICT y = activations.griffin_y.Row(batch_idx); + float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx); + float* HWY_RESTRICT gate_x = activations.griffin_gate_x.Row(batch_idx); + float* HWY_RESTRICT a = activations.griffin_multiplier.Row(batch_idx); + float* HWY_RESTRICT rnn_state = kv_cache.rglru_cache.Row(layer); pool.Run(0, heads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR { const size_t kHeadDim = model_dim / heads; @@ -193,8 +191,8 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, // Final linear layer. for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { - float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); - float* out_ptr = activations.att_sums.Batch(batch_idx); + float* HWY_RESTRICT x = activations.griffin_x.Row(batch_idx); + float* out_ptr = activations.att_sums.Row(batch_idx); MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x, layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr, pool); @@ -217,7 +215,7 @@ class GemmaAttention { const float mul) { // qk is either q or k, so qkv_dim is the length we operate on. 
const size_t qkv_dim = layer_config_.qkv_dim; - const float* inv_timescale = activations_.inv_timescale.Const(); + const float* inv_timescale = activations_.inv_timescale.Packed(); bool is_global_layer = activations_.weights_config.attention_window_sizes[layer] == activations_.seq_len; @@ -227,7 +225,7 @@ class GemmaAttention { activations_.weights_config.model == Model::GEMMA3_12B || activations_.weights_config.model == Model::GEMMA3_27B || activations_.weights_config.model == Model::GEMMA3_1B)) { - inv_timescale = activations_.inv_timescale_global.Const(); + inv_timescale = activations_.inv_timescale_global.Packed(); } // PostQKType::Rope (void)layer; @@ -249,11 +247,10 @@ class GemmaAttention { const size_t heads = layer_config_.heads; const size_t kv_heads = layer_config_.kv_heads; - const auto pre_att_rms_out = - ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out); - auto w_q1 = layer_weights_.qkv_einsum_w.HasPtr() - ? ConstMatFromWeights(layer_weights_.qkv_einsum_w) - : ConstMatFromWeights(layer_weights_.qkv_einsum_w1); + using WeightT = typename decltype(layer_weights_.qkv_einsum_w)::T; + ConstMat w_q1(layer_weights_.qkv_einsum_w.HasPtr() + ? layer_weights_.qkv_einsum_w + : layer_weights_.qkv_einsum_w1); // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim, // model_dim], which we reshaped to (heads + kv_heads * 2) * kKQVDim rows. // We must shrink to the actual size because MatMul verifies @@ -262,20 +259,19 @@ class GemmaAttention { // computed in the second MatMul. const size_t w1_rows = heads * layer_config_.QStride(); w_q1.ShrinkRows(w1_rows); - MatMul(pre_att_rms_out, w_q1, + MatMul(activations_.pre_att_rms_out, w_q1, /*add=*/nullptr, *activations_.env, - RowPtrFromBatch(allocator_, activations_.q)); + RowPtrFromMat(allocator_, activations_.q)); if (is_mha_) { // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. } else { - decltype(w_q1) w_q2; + decltype(w_q1) w_q2(layer_weights_.qkv_einsum_w.HasPtr() + ? layer_weights_.qkv_einsum_w + : layer_weights_.qkv_einsum_w2); if (layer_weights_.qkv_einsum_w.HasPtr()) { - w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w); // Skip first half of the matrix. w_q2.ofs = w_q2.Row(w1_rows); - } else { - w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w2); } // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v). const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim; @@ -290,13 +286,13 @@ class GemmaAttention { float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols); kv_rows.SetStride(cache_pos_size_); - MatMul(pre_att_rms_out, w_q2, + MatMul(activations_.pre_att_rms_out, w_q2, /*add=*/nullptr, *activations_.env, kv_rows); } else { // Proceed row by row because there will be wraparound. for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved; ++interleaved_idx) { - const float* x = activations_.pre_att_rms_out.Batch(interleaved_idx); + const float* x = activations_.pre_att_rms_out.Row(interleaved_idx); const size_t query_idx = interleaved_idx % num_queries_; const size_t batch_idx = interleaved_idx / num_queries_; KVCache& kv_cache = kv_caches_[query_idx]; @@ -327,15 +323,15 @@ class GemmaAttention { // If MHA, copy computed K and V into KVCache. 
if (is_mha_) { const float* HWY_RESTRICT mha_kv = - activations_.q.Batch(interleaved_idx) + head * q_stride_ + + activations_.q.Row(interleaved_idx) + head * q_stride_ + qkv_dim; hwy::CopyBytes(mha_kv, kv, 2 * qkv_dim * sizeof(*kv)); } // Apply further processing to K. if (layer_weights_.key_norm_scale.HasPtr()) { - RMSNormInplace(layer_weights_.key_norm_scale.Row(0), kv, - qkv_dim); + RMSNormInplace(layer_weights_.key_norm_scale.PackedScale1(), + 0, kv, qkv_dim); } PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f); }); @@ -402,7 +398,8 @@ class GemmaAttention { // Apply rope and scaling to Q. if (layer_weights_.query_norm_scale.HasPtr()) { - RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim); + RMSNormInplace(layer_weights_.query_norm_scale.PackedScale1(), 0, q, + qkv_dim); } PositionalEncodingQK(q, pos, layer_, query_scale); @@ -435,13 +432,12 @@ class GemmaAttention { const size_t head_offset = (head / kHeadGroups) * qkv_dim * 2; float* HWY_RESTRICT q = - activations_.q.Batch(interleaved_idx) + head * q_stride_; + activations_.q.Row(interleaved_idx) + head * q_stride_; float* HWY_RESTRICT att = - activations_.att.Batch(interleaved_idx) + + activations_.att.Row(interleaved_idx) + head * activations_.seq_len; float* HWY_RESTRICT att_out = - activations_.att_out.Batch(interleaved_idx) + - head * qkv_dim; + activations_.att_out.Row(interleaved_idx) + head * qkv_dim; // Make strided views into the kv cache entries for the current // query and head. @@ -476,28 +472,25 @@ class GemmaAttention { private: // Sums encoded (`att_out`) over num_heads (`layer_config_.heads`) and // head_dim (`qkv_dim`) into output (`layer_out`). - HWY_NOINLINE void SumHeads(const size_t num_interleaved) { + HWY_NOINLINE void SumHeads() { PROFILER_ZONE("Gen.Attention.SumHeads"); // att_weights and att_out are concatenated heads, each of length // layer_config_.qkv_dim. Thus the [num_interleaved, // layer_config_.model_dim] matmul output is the sum over heads. Compare // gemma/modules.py: attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', // encoded) - HWY_DASSERT(layer_config_.model_dim > 0); - HWY_DASSERT(layer_config_.heads > 0); - HWY_DASSERT(layer_config_.qkv_dim > 0); + HWY_DASSERT(layer_config_.model_dim != 0 && layer_config_.heads != 0 && + layer_config_.qkv_dim != 0); HWY_DASSERT(layer_weights_.att_weights.HasPtr()); - HWY_DASSERT(activations_.att_out.All() != nullptr); - HWY_DASSERT(activations_.att_sums.All() != nullptr); + HWY_DASSERT(activations_.att_out.HasPtr()); + HWY_DASSERT(activations_.att_sums.HasPtr()); const float* add = layer_weights_.layer_config.softmax_attn_output_biases ? 
layer_weights_.attention_output_biases.PackedScale1() : nullptr; - MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out), - ConstMatFromWeights(layer_weights_.att_weights), add, - *activations_.env, - RowPtrFromBatch(allocator_, activations_.att_sums)); + MatMul(activations_.att_out, layer_weights_.att_weights, add, + *activations_.env, RowPtrFromMat(allocator_, activations_.att_sums)); } public: @@ -524,7 +517,7 @@ class GemmaAttention { const size_t num_interleaved = num_tokens_ * num_queries_; ComputeQKV(num_interleaved); DotSoftmaxWeightedSum(num_interleaved); - SumHeads(num_interleaved); + SumHeads(); } private: @@ -618,12 +611,11 @@ class VitAttention { HWY_NOINLINE void ComputeQKV() { PROFILER_ZONE("Gen.VitAttention.QKV"); auto& qkv = activations_.q; - HWY_ASSERT(qkv.BatchSize() == num_tokens_); + HWY_ASSERT(qkv.Rows() == num_tokens_); HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); - MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out), - ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w), + MatMul(activations_.pre_att_rms_out, layer_weights_.vit.qkv_einsum_w, layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, - RowPtrFromBatch(allocator_, qkv)); + RowPtrFromMat(allocator_, qkv)); } // TODO(philculliton): transition fully to MatMul. @@ -635,52 +627,49 @@ class VitAttention { const float query_scale = 1.0f / sqrtf(static_cast(qkv_dim)); PROFILER_ZONE("Gen.VitAttention.DotSoftmax"); - // Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents) - RowVectorBatch Q = - AllocateAlignedRows(allocator_, Extents2D(num_tokens_, qkv_dim)); - RowVectorBatch K = - AllocateAlignedRows(allocator_, Extents2D(seq_len, qkv_dim)); - RowVectorBatch C(allocator_, Extents2D(num_tokens_, seq_len)); + // Shift Q, K, VT to MatStorageT. + MatStorageT Q("Q2", Extents2D(num_tokens_, qkv_dim), + MatPadding::kPacked); + MatStorageT K("K2", Extents2D(seq_len, qkv_dim), + MatPadding::kPacked); + MatStorageT C("C2", Extents2D(num_tokens_, seq_len), + MatPadding::kPacked); // Initialize att_out to zero prior to head loop. 
- hwy::ZeroBytes(activations_.att_out.All(), - num_tokens_ * heads * qkv_dim * sizeof(float)); + ZeroInit(activations_.att_out); for (size_t head = 0; head < heads; ++head) { pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { const size_t token = task; - float* HWY_RESTRICT q = - activations_.q.Batch(token) + head * 3 * qkv_dim; + float* HWY_RESTRICT q = activations_.q.Row(token) + head * 3 * qkv_dim; // TODO: shift to MatMul with A.scale once MatMul is confirmed working MulByConst(query_scale, q, qkv_dim); - hwy::CopyBytes(q, Q.Batch(token), qkv_dim * sizeof(float)); + hwy::CopyBytes(q, Q.Row(token), qkv_dim * sizeof(float)); }); pool_.Run(0, seq_len, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { const size_t seq_idx = task; float* HWY_RESTRICT k = - activations_.q.Batch(seq_idx) + head * 3 * qkv_dim + qkv_dim; - hwy::CopyBytes(k, K.Batch(seq_idx), qkv_dim * sizeof(float)); + activations_.q.Row(seq_idx) + head * 3 * qkv_dim + qkv_dim; + hwy::CopyBytes(k, K.Row(seq_idx), qkv_dim * sizeof(float)); }); // this produces C, a (num_tokens_, seq_len) matrix of dot products - MatMul(ConstMatFromBatch(Q.BatchSize(), Q), - ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env, - RowPtrFromBatch(allocator_, C)); + MatMul(Q, K, nullptr, *activations_.env, RowPtrFromMat(allocator_, C)); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { - float* HWY_RESTRICT c = C.Batch(task); + float* HWY_RESTRICT c = C.Row(task); Softmax(c, C.Cols()); }); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { size_t token = task; float* HWY_RESTRICT att_out = - activations_.att_out.Batch(token) + head * qkv_dim; + activations_.att_out.Row(token) + head * qkv_dim; for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT v = - activations_.q.Batch(i) + head * 3 * qkv_dim + 2 * qkv_dim; - MulByConstAndAdd(C.Batch(token)[i], v, att_out, qkv_dim); + activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; + MulByConstAndAdd(C.Row(token)[i], v, att_out, qkv_dim); } }); } @@ -701,24 +690,24 @@ class VitAttention { const size_t token = task / layer_config_.heads; // Compute Q.K scores, which are "logits" stored in head_att. float* HWY_RESTRICT q = - activations_.q.Batch(token) + head * 3 * qkv_dim; + activations_.q.Row(token) + head * 3 * qkv_dim; MulByConst(query_scale, q, qkv_dim); float* HWY_RESTRICT head_att = - activations_.att.Batch(token) + head * activations_.seq_len; + activations_.att.Row(token) + head * activations_.seq_len; for (size_t i = 0; i < seq_len; ++i) { float* HWY_RESTRICT k = - activations_.q.Batch(i) + head * 3 * qkv_dim + qkv_dim; + activations_.q.Row(i) + head * 3 * qkv_dim + qkv_dim; head_att[i] = Dot(q, k, qkv_dim); // score = q.k } // SoftMax yields "probabilities" in head_att. Softmax(head_att, seq_len); // Compute weighted sum of v into att_out. float* HWY_RESTRICT att_out = - activations_.att_out.Batch(token) + head * qkv_dim; + activations_.att_out.Row(token) + head * qkv_dim; hwy::ZeroBytes(att_out, qkv_dim * sizeof(*att_out)); for (size_t i = 0; i < seq_len; ++i) { - float* HWY_RESTRICT v = activations_.q.Batch(i) + - head * 3 * qkv_dim + 2 * qkv_dim; + float* HWY_RESTRICT v = + activations_.q.Row(i) + head * 3 * qkv_dim + 2 * qkv_dim; MulByConstAndAdd(head_att[i], v, att_out, qkv_dim); } }); @@ -732,10 +721,9 @@ class VitAttention { // att_weights and att_out are concatenated heads, each of length // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // matmul output is the sum over heads. 
- auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out); - auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w); - auto att_sums = RowPtrFromBatch(allocator_, activations_.att_sums); - MatMul(att_out, att_weights, bias, *activations_.env, att_sums); + auto att_sums = RowPtrFromMat(allocator_, activations_.att_sums); + MatMul(activations_.att_out, layer_weights_.vit.attn_out_w, bias, + *activations_.env, att_sums); } public: @@ -771,7 +759,7 @@ class VitAttention { template HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1, - T* HWY_RESTRICT c2, size_t count) { + const T* HWY_RESTRICT c2, size_t count) { PROFILER_ZONE("Gen.Activation"); namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; @@ -787,12 +775,38 @@ HWY_NOINLINE void Activation(ActivationType activation, T* HWY_RESTRICT c1, }); } +// No C2 multiplier. +template +void ActivationBatched(ActivationType activation, Mat& c1) { + using T = typename Mat::T; + for (size_t i = 0; i < c1.Rows(); ++i) { + // Cast to correct type so type deduction works. + Activation(activation, c1.Row(i), static_cast(nullptr), + c1.Cols()); + } +} + +template +void ActivationBatched(ActivationType activation, Mat& c1, const Mat* c2) { + using T = typename Mat::T; + HWY_DASSERT(c1.SameShape(*c2)); + if (c2 && c2->HasPtr()) { + for (size_t i = 0; i < c1.Rows(); ++i) { + Activation(activation, c1.Row(i), c2->Row(i), c1.Cols()); + } + } else { // No multiplier + for (size_t i = 0; i < c1.Rows(); ++i) { + Activation(activation, c1.Row(i), static_cast(nullptr), + c1.Cols()); + } + } +} + template -HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved, +HWY_NOINLINE void FFWNoVit(Activations& activations, const LayerWeightsPtrs* layer_weights) { PROFILER_ZONE("Gen.FFW"); const size_t ffh_hidden_dim = layer_weights->layer_config.ff_hidden_dim; - HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize()); const bool add_bias = layer_weights->layer_config.ff_biases; const float* bias1 = @@ -802,56 +816,48 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved, add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr; // Define slightly more readable names for the weights and activations. - const auto x = - ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out); - const Allocator& allocator = activations.env->ctx.allocator; - auto hidden_activations = RowPtrFromBatch(allocator, activations.C1); - auto multiplier = RowPtrFromBatch(allocator, activations.C2); - auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out); + auto hidden_activations = RowPtrFromMat(allocator, activations.C1); + auto multiplier = RowPtrFromMat(allocator, activations.C2); + auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out); + + using WeightT = typename decltype(layer_weights->gating_einsum_w)::T; // gating_einsum_w holds two half-matrices. We plan to change the importer to // avoid this confusion by splitting into gating_einsum_w1 and - // gating_einsum_w2. + // gating_einsum_w2. TODO: move into Reshape(). const bool split = layer_weights->gating_einsum_w.HasPtr(); - auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w) - : ConstMatFromWeights(layer_weights->gating_einsum_w1); - decltype(w1) w2; + ConstMat w1(split ? layer_weights->gating_einsum_w + : layer_weights->gating_einsum_w1); + ConstMat w2(split ? 
layer_weights->gating_einsum_w + : layer_weights->gating_einsum_w2); if (split) { - w2 = ConstMatFromWeights(layer_weights->gating_einsum_w); w2.ofs = w2.Row(ffh_hidden_dim); // Ensure that B.Extents().row matches C.Cols() because MatMul checks that. w1.ShrinkRows(ffh_hidden_dim); w2.ShrinkRows(ffh_hidden_dim); - } else { - w2 = ConstMatFromWeights(layer_weights->gating_einsum_w2); } - auto w_output = ConstMatFromWeights(layer_weights->linear_w); // Compute the hidden layer activations. - MatMul(x, w1, bias1, *activations.env, hidden_activations); - MatMul(x, w2, bias2, *activations.env, multiplier); + MatMul(activations.pre_ffw_rms_out, w1, bias1, *activations.env, + hidden_activations); + MatMul(activations.pre_ffw_rms_out, w2, bias2, *activations.env, multiplier); // Activation (Gelu) and maybe multiply by gate. Store activations in act. - Activation(layer_weights->layer_config.activation, hidden_activations.Row(0), - multiplier.Row(0), ffh_hidden_dim * num_interleaved); + ActivationBatched(layer_weights->layer_config.activation, activations.C1, + &activations.C2); // Hidden layer -> output layer. - auto activations_mat = MakeConstMat( - hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim), - hidden_activations.Stride()); - - MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out); + MatMul(activations.C1, layer_weights->linear_w, output_bias, *activations.env, + ffw_out); } // Same as FFWNoVit, but with different layer_weights members and no second // gating matrix. template -HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, +HWY_NOINLINE void FFWVit(Activations& activations, const LayerWeightsPtrs* layer_weights) { - PROFILER_ZONE("Gen.FFW"); - const size_t ff_hidden_dim = layer_weights->layer_config.ff_hidden_dim; - HWY_DASSERT(num_interleaved <= activations.bf_pre_ffw_rms_out.BatchSize()); + PROFILER_ZONE("Gen.FFW.ViT"); const bool add_bias = layer_weights->layer_config.ff_biases; const float* bias1 = @@ -860,30 +866,21 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr; // Define slightly more readable names for the weights and activations. - const auto x = - ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out); - const Allocator& allocator = activations.env->ctx.allocator; - auto hidden_activations = RowPtrFromBatch(allocator, activations.C1); - auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out); - - auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w); - auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w); + auto hidden_activations = RowPtrFromMat(allocator, activations.C1); + auto ffw_out = RowPtrFromMat(allocator, activations.ffw_out); // Compute the hidden layer activations. - MatMul(x, w1, bias1, *activations.env, hidden_activations); + MatMul(activations.pre_ffw_rms_out, layer_weights->vit.linear_0_w, bias1, + *activations.env, hidden_activations); // Activation (Gelu), store in act. RowPtrF multiplier = RowPtrF(allocator, nullptr, 0); - Activation(layer_weights->layer_config.activation, hidden_activations.Row(0), - multiplier.Row(0), ff_hidden_dim * num_interleaved); + ActivationBatched(layer_weights->layer_config.activation, activations.C1); // Hidden layer -> output layer. 
- auto activations_mat = MakeConstMat(hidden_activations.Row(0), - Extents2D(num_interleaved, ff_hidden_dim), - hidden_activations.Stride()); - - MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out); + MatMul(activations.C1, layer_weights->vit.linear_1_w, output_bias, + *activations.env, ffw_out); } // `batch_idx` indicates which row of `x` to write to. @@ -898,23 +895,23 @@ template HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt, const ModelWeightsPtrs& weights, - RowVectorBatch& x, + MatStorageT& x, const ImageTokens* image_tokens, size_t& image_token_position) { // Image tokens just need to be copied. if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM && image_tokens != nullptr && token == -2 && - image_token_position < image_tokens->BatchSize()) { - hwy::CopyBytes(image_tokens->Batch(image_token_position), - x.Batch(batch_idx), x.Cols() * sizeof(x.Const()[0])); + image_token_position < image_tokens->Rows()) { + hwy::CopyBytes(image_tokens->Row(image_token_position), x.Row(batch_idx), + x.Cols() * x.ElementBytes()); image_token_position++; return; } if (weights.weights_config.wrapping == PromptWrapping::PALIGEMMA && - image_tokens != nullptr && pos_in_prompt < image_tokens->BatchSize()) { - hwy::CopyBytes(image_tokens->Batch(pos_in_prompt), x.Batch(batch_idx), - x.Cols() * sizeof(x.Const()[0])); + image_tokens != nullptr && pos_in_prompt < image_tokens->Rows()) { + hwy::CopyBytes(image_tokens->Row(pos_in_prompt), x.Row(batch_idx), + x.Cols() * x.ElementBytes()); return; } @@ -934,12 +931,12 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos, HWY_ASSERT(weights.embedder_input_embedding.Cols() == model_dim); const auto embedding_span = MakeSpan(weights.embedder_input_embedding.Row(0), embedding_ofs + model_dim); - DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Batch(batch_idx), + DecompressAndZeroPad(df, embedding_span, embedding_ofs, x.Row(batch_idx), model_dim); MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(), - x.Batch(batch_idx), model_dim); + x.Row(batch_idx), model_dim); if (weights.weights_config.absolute_pe) { - AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos); + AddAbsolutePositionalEmbeddings(x.Row(batch_idx), model_dim, pos); } } @@ -951,29 +948,28 @@ template HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos, size_t pos_in_prompt, const ModelWeightsPtrs& weights, - RowVectorBatch& x, + MatStorageT& x, const ImageTokens* image_tokens) { size_t image_token_position = 0; EmbedMMToken(token, batch_idx, pos, pos_in_prompt, weights, x, image_tokens, image_token_position); } -template -HWY_NOINLINE void ResidualConnection( - size_t num_interleaved, const T* HWY_RESTRICT other, T* HWY_RESTRICT x, - const LayerWeightsPtrs* layer_weights, bool is_attention) { +template +HWY_NOINLINE void ResidualConnection(const MatPtrT& other, + MatPtrT& HWY_RESTRICT x, + const LayerWeights* layer_weights, + bool is_attention) { // ResidualType::Add - AddFromBatched(num_interleaved, other, x, - layer_weights->layer_config.model_dim); + AddFromBatched(other, x); } template -void PostNorm(PostNormType post_norm, size_t num_interleaved, - const WeightT& weights, InOutT* inout) { +void PostNorm(PostNormType post_norm, const MatPtrT& weights, + MatPtrT& inout) { HWY_DASSERT(weights.Rows() == 1); if (post_norm == PostNormType::Scale) { - RMSNormInplaceBatched(num_interleaved, weights.PackedScale1(), inout, - weights.Cols()); + 
RMSNormInplaceBatched(weights, inout); } } @@ -985,39 +981,33 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos, Activations& activations, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches) { - const size_t model_dim = activations.weights_config.model_dim; - const size_t num_interleaved = num_tokens * queries_pos.size(); auto type = layer_weights->layer_config.type; - RMSNormBatched(num_interleaved, activations.x.All(), - layer_weights->pre_attention_norm_scale.PackedScale1(), - activations.pre_att_rms_out.All(), model_dim); + RMSNormBatched(activations.x, layer_weights->pre_attention_norm_scale, + activations.pre_att_rms_out); Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx, activations, layer_weights, div_seq_len, kv_caches); - PostNorm(layer_weights->layer_config.post_norm, num_interleaved, - layer_weights->post_attention_norm_scale, - activations.att_sums.All()); + PostNorm(layer_weights->layer_config.post_norm, + layer_weights->post_attention_norm_scale, activations.att_sums); - ResidualConnection(num_interleaved, activations.att_sums.All(), - activations.x.All(), layer_weights, /*is_attention=*/true); + ResidualConnection(activations.att_sums, activations.x, layer_weights, + /*is_attention=*/true); - RMSNormBatched(num_interleaved, activations.x.All(), - layer_weights->pre_ffw_norm_scale.PackedScale1(), - activations.bf_pre_ffw_rms_out.All(), model_dim); + RMSNormBatched(activations.x, layer_weights->pre_ffw_norm_scale, + activations.pre_ffw_rms_out); if (layer_weights->layer_config.type == LayerAttentionType::kVit) { - FFWVit(activations, num_interleaved, layer_weights); + FFWVit(activations, layer_weights); } else { - FFWNoVit(activations, num_interleaved, layer_weights); + FFWNoVit(activations, layer_weights); } - PostNorm(layer_weights->layer_config.post_norm, num_interleaved, - layer_weights->post_ffw_norm_scale, activations.ffw_out.All()); + PostNorm(layer_weights->layer_config.post_norm, + layer_weights->post_ffw_norm_scale, activations.ffw_out); - ResidualConnection(num_interleaved, activations.ffw_out.All(), - activations.x.All(), layer_weights, + ResidualConnection(activations.ffw_out, activations.x, layer_weights, /*is_attention=*/false); } @@ -1034,38 +1024,37 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer, auto type = layer_weights->layer_config.type; HWY_DASSERT(type == LayerAttentionType::kVit); (void)type; + (void)model_dim; auto& x = activations.x; - HWY_DASSERT(x.BatchSize() == num_tokens); + HWY_DASSERT(x.Rows() == num_tokens); HWY_DASSERT(x.Cols() == model_dim); // y = nn.LayerNorm()(x) // y ~ pre_att_rms_out - LayerNormBatched(num_tokens, x.All(), - layer_weights->vit.layer_norm_0_scale.PackedScale1(), - layer_weights->vit.layer_norm_0_bias.PackedScale1(), - activations.pre_att_rms_out.All(), model_dim); + LayerNormBatched(x, layer_weights->vit.layer_norm_0_scale, + layer_weights->vit.layer_norm_0_bias, + activations.pre_att_rms_out); // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) // y ~ att_sums VitAttention(num_tokens, layer, activations, layer_weights)(); // x = out["+sa"] = x + y - AddFromBatched(num_tokens, activations.att_sums.All(), x.All(), model_dim); + AddFromBatched(activations.att_sums, x); // y = nn.LayerNorm()(x) - // y ~ bf_pre_ffw_rms_out - LayerNormBatched(num_tokens, x.All(), - layer_weights->vit.layer_norm_1_scale.PackedScale1(), - layer_weights->vit.layer_norm_1_bias.PackedScale1(), - activations.bf_pre_ffw_rms_out.All(), model_dim); + // y ~ 
pre_ffw_rms_out + LayerNormBatched(x, layer_weights->vit.layer_norm_1_scale, + layer_weights->vit.layer_norm_1_bias, + activations.pre_ffw_rms_out); // y = out["mlp"] = MlpBlock(...)(y) // y ~ ffw_out - FFWVit(activations, num_tokens, layer_weights); + FFWVit(activations, layer_weights); // x = out["+mlp"] = x + y - AddFromBatched(num_tokens, activations.ffw_out.All(), x.All(), model_dim); + AddFromBatched(activations.ffw_out, x); } // Prefill() and Transformer() increment positions in-place. @@ -1094,7 +1083,7 @@ HWY_NOINLINE void Prefill( // intensity, and so we are eventually compute-limited. We could devote some // threads to parallelizing over queries, but for simplicity we assign them // all to MatMul. - const size_t max_tbatch_size = activations.x.BatchSize(); + const size_t max_tbatch_size = activations.x.Rows(); // For each query. `qi` is within the batch, not the global query index. for (size_t qi = 0; qi < num_queries; ++qi) { @@ -1131,6 +1120,7 @@ HWY_NOINLINE void Prefill( tbatch_start += max_tbatch_size) { const size_t tbatch_size = HWY_MIN(max_tbatch_size, prefill_this_query - tbatch_start); + activations.SetBatchSize(tbatch_size); // Fill activations.x (much faster than TransformerLayer). size_t image_token_position = 0; @@ -1201,13 +1191,14 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, // H x W x C x D transposed to D x (H x W x C) so here (1152, 14 * 14 * 3) // image_patches is (256, 14 * 14 * 3) // This could be done as one MatMul like: - // RowVectorBatch image_patches(kSeqLen, kPatchSize); + // MatStorageT image_patches("patches", Extents2D(kSeqLen, + // kPatchSize), MatPadding::kPacked); // [Get patches] // MatMul( // MatFromBatch(kVitSeqLen, image_patches), // MatFromWeights(weights.vit_img_embedding_kernel), // weights.vit_img_embedding_bias.PackedScale1(), *activations.env, - // RowPtrF(activations.x.All(), kVitModelDim)); + // RowPtrF(activations.x.Row(0), kVitModelDim)); // However, MatMul currently requires that // A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0 // which is not the case here. We should relax that requirement on MatMul and @@ -1216,11 +1207,10 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size, image_patches[i].get(), weights.vit_img_embedding_bias.PackedScale1(), - activations.x.Batch(i), activations.env->ctx.pools.Pool(0)); + activations.x.Row(i), activations.env->ctx.pools.Pool(0)); } // Add position embeddings. - AddFrom(weights.vit_img_pos_embedding.PackedScale1(), activations.x.All(), - seq_len * model_dim); + AddFromBatched(weights.vit_img_pos_embedding, activations.x); } // Prefills the image tokens with the ViT encoder. @@ -1232,7 +1222,7 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, PROFILER_ZONE("Gen.PrefillVit"); const size_t num_tokens = weights.weights_config.vit_config.seq_len; const size_t vit_model_dim = weights.weights_config.vit_config.model_dim; - HWY_ASSERT(num_tokens == activations.x.BatchSize()); + HWY_ASSERT(num_tokens == activations.x.Rows()); // Embed the image patches. EmbedImagePatches(image, weights, activations); // Go through all layers. @@ -1243,24 +1233,21 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, VitTransformerLayer(num_tokens, layer, layer_weights, activations); } // Final Layernorm. 
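  // The final LayerNorm is applied in place: activations.x is passed as both
  // input and output, which LayerNorm explicitly allows.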
- LayerNormBatched(num_tokens, activations.x.All(), - weights.vit_encoder_norm_scale.PackedScale1(), - weights.vit_encoder_norm_bias.PackedScale1(), - activations.x.All(), vit_model_dim); + LayerNormBatched(activations.x, weights.vit_encoder_norm_scale, + weights.vit_encoder_norm_bias, activations.x); if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) { activations.x = AvgPool4x4(activations.x); // Apply soft embedding norm before input projection. - RMSNormInplace(weights.mm_embed_norm.PackedScale1(), activations.x.All(), - vit_model_dim); + RMSNormInplace(weights.mm_embed_norm.PackedScale1(), 0, + activations.x.Row(0), vit_model_dim); } // Apply head embedding into image_tokens of size of the LLM kModelDim. - MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x), - ConstMatFromWeights(weights.vit_img_head_kernel), + MatMul(activations.x, weights.vit_img_head_kernel, weights.vit_img_head_bias.PackedScale1(), *activations.env, - RowPtrFromBatch(activations.env->ctx.allocator, image_tokens)); + RowPtrFromMat(activations.env->ctx.allocator, image_tokens)); } // Generates one token for each query. `queries_token` is the previous token @@ -1272,7 +1259,6 @@ HWY_NOINLINE void Transformer( Activations& activations, const hwy::Divisor& div_seq_len, const KVCaches& kv_caches, const LayersOutputFunc& layers_output, const ActivationsObserverFunc& activations_observer) { - const size_t model_dim = weights.weights_config.model_dim; const size_t num_queries = queries_token.size(); HWY_DASSERT(queries_pos.size() == num_queries); HWY_DASSERT(queries_prefix_end.size() == num_queries); @@ -1302,8 +1288,7 @@ HWY_NOINLINE void Transformer( } } - RMSNormInplaceBatched(num_queries, weights.final_norm_scale.PackedScale1(), - activations.x.All(), model_dim); + RMSNormInplaceBatched(weights.final_norm_scale, activations.x); if (activations_observer) { activations_observer(queries_pos, -1, activations); @@ -1395,18 +1380,18 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs& weights, runtime_config.activations_observer); // queries_pos are incremented by Transformer. + HWY_DASSERT(num_queries == activations.x.Rows()); bool all_queries_eos = true; { PROFILER_ZONE("Gen.EmbeddingMatmul"); // Compute logits from last layer activations. - MatMul(ConstMatFromBatch(num_queries, activations.x), - ConstMatFromWeights(weights.embedder_input_embedding), + MatMul(activations.x, weights.embedder_input_embedding, /*add=*/nullptr, *activations.env, - RowPtrFromBatch(activations.env->ctx.allocator, activations.logits)); + RowPtrFromMat(activations.env->ctx.allocator, activations.logits)); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { - float* HWY_RESTRICT logits = activations.logits.Batch(query_idx); + float* HWY_RESTRICT logits = activations.logits.Row(query_idx); MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size); const TokenAndProb tp = sample_token(logits, vocab_size); timing_info.NotifyGenerated(); @@ -1460,7 +1445,7 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, const size_t num_queries = queries_prompt.size(); HWY_ASSERT(num_queries <= 4096); // TokenStreamer uses BitSet4096. 
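  // The query-batch capacity is bounded by the row count of the activations
  // matrix itself; callers shrink the live batch per tile or query batch via
  // Activations::SetBatchSize(), as in Prefill and GenerateBatchT.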
- HWY_ASSERT(num_queries <= activations.x.BatchSize()); + HWY_ASSERT(num_queries <= activations.x.Rows()); HWY_ASSERT(queries_pos_in.size() == num_queries); HWY_ASSERT(kv_caches.size() == num_queries); const hwy::Divisor div_seq_len(static_cast(kv_caches[0].seq_len)); @@ -1475,12 +1460,11 @@ void GenerateT(const ModelStore& model, const ModelWeightsPtrs& weights, // If tbatch is larger than the qbatch we already have in `activations`, then // allocate prefill_activations, otherwise reuse. const bool use_prefill_activations = - runtime_config.prefill_tbatch_size > activations.x.BatchSize(); - Activations prefill_activations(weights.weights_config); - if (use_prefill_activations) { - prefill_activations.Allocate(runtime_config.prefill_tbatch_size, - activations.env); - } + runtime_config.prefill_tbatch_size > activations.x.Rows(); + Activations prefill_activations( + weights.weights_config, + use_prefill_activations ? runtime_config.prefill_tbatch_size : 0, + activations.env); Prefill(queries_prompt, queries_mutable_pos, queries_prefix_end, query_idx_start, weights, use_prefill_activations ? prefill_activations : activations, @@ -1534,8 +1518,7 @@ void GenerateSingleT(const ModelStore& model, const size_t qbatch_start = 0; // TODO: move into Gemma? - Activations activations(model.Config()); - activations.Allocate(kNumQueries, env); + Activations activations(model.Config(), kNumQueries, env); const QueriesPromptTokens queries_prompt(&prompt, kNumQueries); QueriesPos queries_pos(&pos, kNumQueries); @@ -1558,7 +1541,7 @@ void GenerateBatchT(const ModelStore& model, TimingInfo& timing_info) { const size_t num_queries = queries_prompt.size(); HWY_ASSERT(queries_pos.size() == num_queries); - HWY_ASSERT(kv_caches.size() == num_queries); + HWY_ASSERT(kv_caches.size() >= num_queries); // Griffin does not support query batching. size_t max_qbatch_size = runtime_config.decode_qbatch_size; for (const LayerConfig& layer_config : model.Config().layer_configs) { @@ -1568,14 +1551,14 @@ void GenerateBatchT(const ModelStore& model, } } - Activations activations(model.Config()); - activations.Allocate(max_qbatch_size, env); + Activations activations(model.Config(), max_qbatch_size, env); for (size_t qbatch_start = 0; qbatch_start < num_queries; qbatch_start += max_qbatch_size) { // Generate one batch of tokens from `qbatch_size` queries. const size_t qbatch_size = HWY_MIN(num_queries - qbatch_start, max_qbatch_size); + activations.SetBatchSize(qbatch_size); const QueriesPromptTokens qbatch_prompts(&queries_prompt[qbatch_start], qbatch_size); QueriesPos qbatch_pos(&queries_pos[qbatch_start], qbatch_size); @@ -1601,8 +1584,7 @@ void GenerateImageTokensT(const ModelStore& model, ModelConfig vit_config = GetVitConfig(model.Config()); prefill_runtime_config.prefill_tbatch_size = vit_config.seq_len / (vit_config.pool_dim * vit_config.pool_dim); - Activations prefill_activations(vit_config); - prefill_activations.Allocate(vit_config.seq_len, env); + Activations prefill_activations(vit_config, vit_config.seq_len, env); // Weights are for the full PaliGemma model, not just the ViT part. 
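  // prefill_activations was allocated above with vit_config.seq_len rows,
  // which PrefillVit asserts matches the number of ViT tokens.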
PrefillVit(weights, prefill_runtime_config, image, image_tokens, prefill_activations); diff --git a/gemma/gemma.h b/gemma/gemma.h index b18eb60..cf027d1 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -32,7 +32,6 @@ #include "ops/matmul.h" // MatMulEnv #include "paligemma/image.h" #include "util/basics.h" // TokenAndProb -#include "util/mat.h" // RowVectorBatch #include "util/threading_context.h" #include "hwy/timer.h" // IWYU pragma: end_exports diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index b5a8148..d1e1964 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -25,10 +25,11 @@ #include #include -#include "io/io.h" // Path -#include "ops/matmul.h" // MMStorage::kMax* +#include "io/io.h" // Path +#include "ops/matmul.h" // MMStorage::kMax* #include "util/args.h" -#include "util/basics.h" // Tristate +#include "util/basics.h" // Tristate +#include "util/mat.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" // HWY_ABORT @@ -74,9 +75,9 @@ using QueriesPromptTokens = hwy::Span; using QueriesToken = hwy::Span; using QueriesPos = hwy::Span; -// ImageTokens are represented as a RowVectorBatch, where each "batch" index -// corresponds to a token for an image patch as computed by the image encoder. -using ImageTokens = RowVectorBatch; +// ImageTokens are represented as a matrix, where each row corresponds +// to a token for an image patch as computed by the image encoder. +using ImageTokens = MatStorageT; // StreamFunc is called with (token, probability). For prompt tokens, // probability is 0.0f. StreamFunc should return false to stop generation and diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index d3c2372..3992270 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -15,91 +15,69 @@ #include "gemma/kv_cache.h" -#include +#include // std::copy #include "gemma/configs.h" +#include "util/mat.h" // ZeroInit #include "hwy/aligned_allocator.h" #include "hwy/base.h" // ZeroBytes namespace gcpp { void KVCache::ZeroGriffinCache() { - if (conv1d_cache_size != 0) { - hwy::ZeroBytes(conv1d_cache.get(), - conv1d_cache_size * sizeof(conv1d_cache[0])); - } - if (rglru_cache_size != 0) { - hwy::ZeroBytes(rglru_cache.get(), - rglru_cache_size * sizeof(rglru_cache[0])); + if (conv1d_cache.HasPtr()) ZeroInit(conv1d_cache); + if (rglru_cache.HasPtr()) ZeroInit(rglru_cache); +} + +static size_t GriffinConv1dCols(const ModelConfig& config) { + size_t conv1d_width = 0; + for (const auto& layer_config : config.layer_configs) { + conv1d_width = HWY_MAX(conv1d_width, layer_config.conv1d_width); } + return conv1d_width == 0 ? 0 : conv1d_width - 1; } // prefill_tbatch_size is the maximum number of tokens from one query to // prefill at a time. -KVCache KVCache::Create(const ModelConfig& weights_config, - size_t prefill_tbatch_size) { - KVCache kv_cache = {}; - - const size_t size_cache_pos = weights_config.CachePosSize(); +KVCache::KVCache(const ModelConfig& config, size_t prefill_tbatch_size) + : griffin_layers( + config.NumLayersOfType(LayerAttentionType::kGriffinRecurrentBlock)), + griffin_conv1d_cols(GriffinConv1dCols(config)), + // TODO(patrickms): Add query batching support for Griffin. + conv1d_cache( + "conv1d_cache", + Extents2D(griffin_layers, griffin_conv1d_cols * config.model_dim), + MatPadding::kOdd), + rglru_cache("rglru_cache", Extents2D(griffin_layers, config.model_dim), + MatPadding::kOdd) { + // TODO: move to MatStorageT. 
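  // The Griffin members initialized above are per-layer matrices:
  // `conv1d_cache` is griffin_layers x (griffin_conv1d_cols * model_dim) and
  // `rglru_cache` is griffin_layers x model_dim; both have zero rows when the
  // model has no Griffin layers. The positional KV cache below is still a raw
  // aligned allocation.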
+ const size_t size_cache_pos = config.CachePosSize(); if (size_cache_pos != 0) { // Allocate more so that prefill can always access one batch, even if // near the end of the sequence. - kv_cache.seq_len = weights_config.seq_len + prefill_tbatch_size; - kv_cache.kv_cache = - hwy::AllocateAligned(kv_cache.seq_len * size_cache_pos); + seq_len = config.seq_len + prefill_tbatch_size; + kv_cache = hwy::AllocateAligned(seq_len * size_cache_pos); } - - const size_t num_griffin_layers = weights_config.NumLayersOfType( - LayerAttentionType::kGriffinRecurrentBlock); - // TODO(patrickms): Add query batching support for Griffin. - if (num_griffin_layers > 0) { - uint32_t conv1d_width = 0; - for (const auto& layer_config : weights_config.layer_configs) { - conv1d_width = std::max(conv1d_width, layer_config.conv1d_width); - } - const size_t conv1d_cache_size = - num_griffin_layers * (conv1d_width == 0 ? 0 : conv1d_width - 1) * - weights_config.model_dim; - kv_cache.conv1d_cache_size = conv1d_cache_size; - if (conv1d_cache_size != 0) { - kv_cache.conv1d_cache = hwy::AllocateAligned(conv1d_cache_size); - } - - const size_t rglru_cache_size = - num_griffin_layers * weights_config.model_dim; - kv_cache.rglru_cache_size = rglru_cache_size; - if (rglru_cache_size != 0) { - kv_cache.rglru_cache = hwy::AllocateAligned(rglru_cache_size); - } - } // num_griffin_layers - - return kv_cache; } KVCache KVCache::Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size) { - KVCache kv_cache_copy = Create(weights_config, prefill_tbatch_size); + KVCache copy(weights_config, prefill_tbatch_size); const size_t size_cache_pos = weights_config.CachePosSize(); if (size_cache_pos != 0) { std::copy(kv_cache.get(), kv_cache.get() + size_cache_pos * seq_len, - kv_cache_copy.kv_cache.get()); + copy.kv_cache.get()); } - const size_t num_griffin_layers = weights_config.NumLayersOfType( - LayerAttentionType::kGriffinRecurrentBlock); - if (num_griffin_layers > 0) { - if (conv1d_cache_size != 0) { - std::copy(conv1d_cache.get(), conv1d_cache.get() + conv1d_cache_size, - kv_cache_copy.conv1d_cache.get()); - } - if (rglru_cache_size != 0) { - std::copy(rglru_cache.get(), - rglru_cache.get() + rglru_cache_size * sizeof(rglru_cache[0]), - kv_cache_copy.rglru_cache.get()); - } + if (conv1d_cache.HasPtr()) { + CopyMat(conv1d_cache, copy.conv1d_cache); } - return kv_cache_copy; + if (rglru_cache.HasPtr()) { + CopyMat(rglru_cache, copy.rglru_cache); + } + + return copy; } } // namespace gcpp diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index 907bee3..028a6f1 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -19,33 +19,31 @@ #include #include "gemma/configs.h" // ModelConfig +#include "util/mat.h" #include "hwy/aligned_allocator.h" namespace gcpp { struct KVCache { - size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size + KVCache() = default; // for std::vector. + KVCache(const ModelConfig& weights_config, size_t prefill_tbatch_size); - // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2 - hwy::AlignedFreeUniquePtr kv_cache; - - // (kConv1dWidth - 1) * kModelDim * kGriffinLayers - hwy::AlignedFreeUniquePtr conv1d_cache; - size_t conv1d_cache_size = 0; - - // kModelDim * kGriffinLayers - hwy::AlignedFreeUniquePtr rglru_cache; - size_t rglru_cache_size = 0; + // Returns a deep copy of the KVCache. 
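  // The positional cache is copied element-wise; conv1d_cache and rglru_cache
  // are copied via CopyMat when allocated.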
+ KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size); + size_t griffin_layers = 0; + size_t griffin_conv1d_cols = 0; + // griffin_layers, griffin_conv1d_cols * config.model_dim + MatStorageT conv1d_cache; + MatStorageT rglru_cache; // griffin_layers, config.model_dim // Zero-initialize the Griffin recurrent block cache, i.e. the conv1d_cache // and rglru_cache. void ZeroGriffinCache(); - static KVCache Create(const ModelConfig& weights_config, - size_t prefill_tbatch_size); + size_t seq_len = 0; // = kSeqLen + prefill_tbatch_size - // Returns a deep copy of the KVCache. - KVCache Copy(const ModelConfig& weights_config, size_t prefill_tbatch_size); + // seq_len * kGemmaLayers * kKVHeads * kQKVDim * 2 + hwy::AlignedFreeUniquePtr kv_cache; }; } // namespace gcpp diff --git a/gemma/run.cc b/gemma/run.cc index 74a9c54..13c96c3 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -95,24 +95,25 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, size_t abs_pos = 0; // across turns size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t prompt_size = 0; + const ModelConfig& config = gemma.GetModelConfig(); std::mt19937 gen; InitGenerator(inference, gen); const bool have_image = !inference.image_file.path.empty(); Image image; - ImageTokens image_tokens; + const size_t pool_dim = config.vit_config.pool_dim; + ImageTokens image_tokens( + "image_tokens", + have_image ? Extents2D(config.vit_config.seq_len / (pool_dim * pool_dim), + config.model_dim) + : Extents2D(0, 0), + MatPadding::kOdd); if (have_image) { - size_t pool_dim = gemma.GetModelConfig().vit_config.pool_dim; - image_tokens = - ImageTokens(gemma.Env().ctx.allocator, - Extents2D(gemma.GetModelConfig().vit_config.seq_len / - (pool_dim * pool_dim), - gemma.GetModelConfig().model_dim)); - HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA || - gemma.GetModelConfig().wrapping == PromptWrapping::GEMMA_VLM); + HWY_ASSERT(config.wrapping == PromptWrapping::PALIGEMMA || + config.wrapping == PromptWrapping::GEMMA_VLM); HWY_ASSERT(image.ReadPPM(inference.image_file.path)); - const size_t image_size = gemma.GetModelConfig().vit_config.image_size; + const size_t image_size = config.vit_config.image_size; image.Resize(image_size, image_size); RuntimeConfig runtime_config = {.gen = &gen, .verbosity = inference.verbosity, @@ -138,7 +139,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, std::cerr << "." << std::flush; } return true; - } else if (gemma.GetModelConfig().IsEOS(token)) { + } else if (config.IsEOS(token)) { if (inference.verbosity >= 2) { std::cout << "\n[ End ]\n"; } @@ -191,8 +192,8 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, size_t prefix_end = 0; if (have_image) { prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(), - gemma.GetModelConfig().wrapping, abs_pos, - prompt_string, image_tokens.BatchSize()); + config.wrapping, abs_pos, prompt_string, + image_tokens.Rows()); runtime_config.image_tokens = &image_tokens; prompt_size = prompt.size(); // The end of the prefix for prefix-LM style attention in Paligemma. 
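For readers following the image-token changes above, a minimal usage sketch of the new ImageTokens type. The float element type is an assumption (the template argument is not spelled out above); only the constructor shape, Extents2D, MatPadding::kOdd and the Rows()/Row() accessors that appear in this patch are relied on.

  // Sketch only; mirrors the ReplGemma construction above.
  const ModelConfig& config = gemma.GetModelConfig();
  const bool have_image = !inference.image_file.path.empty();
  const size_t pool_dim = config.vit_config.pool_dim;
  const size_t rows = config.vit_config.seq_len / (pool_dim * pool_dim);
  // Element type assumed to be float, matching the former RowVectorBatch use.
  ImageTokens image_tokens(
      "image_tokens",
      have_image ? Extents2D(rows, config.model_dim) : Extents2D(0, 0),
      MatPadding::kOdd);
  // Each row is the embedding of one pooled image patch: prompt wrapping
  // reserves image_tokens.Rows() placeholder tokens, and EmbedMMToken copies
  // image_tokens.Row(i) into the corresponding activation row.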
@@ -203,8 +204,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, // runtime_config.prefill_tbatch_size = prompt_size; } else { prompt = WrapAndTokenize(gemma.Tokenizer(), gemma.ChatTemplate(), - gemma.GetModelConfig().wrapping, abs_pos, - prompt_string); + config.wrapping, abs_pos, prompt_string); prompt_size = prompt.size(); } @@ -228,8 +228,7 @@ void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, } // Prepare for the next turn. Works only for PaliGemma. - if (!inference.multiturn || - gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA) { + if (!inference.multiturn || config.wrapping == PromptWrapping::PALIGEMMA) { abs_pos = 0; // Start a new turn at position 0. InitGenerator(inference, gen); } else { @@ -254,8 +253,7 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading, MatMulEnv env(MakeMatMulEnv(threading)); if (inference.verbosity >= 2) env.print_best = true; const Gemma gemma(loader, env); - KVCache kv_cache = - KVCache::Create(gemma.GetModelConfig(), inference.prefill_tbatch_size); + KVCache kv_cache(gemma.GetModelConfig(), inference.prefill_tbatch_size); if (inference.verbosity >= 1) { std::string instructions = diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index d1014ef..fa49b2d 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -92,9 +92,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { const Extents2D B_extents(N, K); // already transposed const Extents2D C_extents(M, N); - RowVectorBatch c_slow_batch = - AllocateAlignedRows(allocator, C_extents); - RowVectorBatch c_batch = AllocateAlignedRows(allocator, C_extents); + MatStorageT c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd); + MatStorageT c_batch("c_batch", C_extents, MatPadding::kOdd); MatStorageT add_storage("add", Extents2D(), MatPadding::kPacked); if (add) { @@ -104,11 +103,9 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { MatStorageT a = GenerateMat(A_extents, pool); MatStorageT b_trans = GenerateTransposedMat(B_extents, pool); - const auto A = ConstMatFromWeights(a); - const auto B = ConstMatFromWeights(b_trans); const float* add_row = add ? add_storage.PackedScale1() : nullptr; - const RowPtr C = RowPtrFromBatch(allocator, c_batch); + const RowPtr C = RowPtrFromMat(allocator, c_batch); // Fewer reps for large batch sizes, which take longer. const size_t num_samples = M < 32 ? 20 : 12; @@ -118,7 +115,8 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Ensure usage conditions are set before autotuning. Both binding and // spinning may materially affect the choice of config. No harm in calling // BindB/C if there is a single package: they will be a no-op. 
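  // BindB still takes a ConstMat view, now constructed directly from the
  // transposed MatStorageT weights via the converting constructor added in
  // ops/matmul.h, while MatMul itself accepts the MatStorageT arguments
  // without a wrapper.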
- BindB(allocator, B_extents.rows, sizeof(TC), B, env.parallel); + BindB(allocator, B_extents.rows, sizeof(TC), ConstMat(b_trans), + env.parallel); BindC(allocator, A_extents.rows, C, env.parallel); Tristate use_spinning = Tristate::kDefault; @@ -133,7 +131,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Until enough samples collected *after* autotuning finished: while (times.size() < num_samples) { const double t0 = hwy::platform::Now(); - per_key = MatMul(A, B, add_row, env, C); + per_key = MatMul(a, b_trans, add_row, env, C); const double t1 = hwy::platform::Now(); double elapsed = t1 - t0; keep += hwy::ConvertScalarTo(C.Row(0)[hwy::Unpredictable1()]); diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 26771e8..52127a4 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -26,6 +26,7 @@ #include #include +#include "compression/compress.h" #include "compression/shared.h" #include "util/allocator.h" #include "util/test_util.h" @@ -999,7 +1000,6 @@ struct TestShortDotsT { const size_t N = hn::Lanes(d); const hn::ScalableTag df; // for CallDot - const Allocator& allocator = gcpp::ThreadingContext::Get().allocator; CompressWorkingSet work; std::mt19937 rng; rng.seed(12345); @@ -1010,22 +1010,22 @@ struct TestShortDotsT { // GenerateWellConditionedInputs calls DecompressAndZeroPad to `raw*`, // hence they require padding to one vector. const size_t padded_num = hwy::RoundUpTo(num, N); - const size_t packed_num = CompressedArrayElements(num); - RowVectorBatch raw_w(allocator, Extents2D(1, padded_num)); - RowVectorBatch raw_v(allocator, Extents2D(1, padded_num)); - RowVectorBatch weights(allocator, Extents2D(1, packed_num)); - const PackedSpan w(weights.Batch(0), packed_num); - RowVectorBatch vectors(allocator, Extents2D(1, num)); - const PackedSpan v(vectors.Batch(0), num); + MatStorageT raw_w("raw_w", padded_num); + MatStorageT raw_v("raw_v", padded_num); + MatStorageT weights("weights", padded_num); + const PackedSpan w = weights.Span(); + MatStorageT vectors("vectors", padded_num); + const PackedSpan v = vectors.Span(); - RowVectorBatch bufs(allocator, Extents2D(1, num)); - double* HWY_RESTRICT buf = bufs.Batch(0); + MatStorageT bufs("bufs", num); + double* HWY_RESTRICT buf = bufs.Packed(); for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) { - GenerateWellConditionedInputs(num, raw_w.All(), rng, w, work); - GenerateWellConditionedInputs(num, raw_v.All(), rng, v, work); + GenerateWellConditionedInputs(num, raw_w.Packed(), rng, w, work); + GenerateWellConditionedInputs(num, raw_v.Packed(), rng, v, work); - const float dot_exact = ExactDot(raw_w.All(), raw_v.All(), num, buf); + const float dot_exact = + ExactDot(raw_w.Packed(), raw_v.Packed(), num, buf); float dots[kVariants]; for (size_t variant = 0; variant < kVariants; ++variant) { // Here Packed is not always float, so we must not call kDouble. 
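As a companion to the test changes above, a minimal sketch of the single-row MatStorageT pattern. The float template argument and the exact semantics of the (name, count) constructor are assumptions; the hunk above only shows that this form replaces a 1 x padded_num RowVectorBatch and exposes Packed() and Span().

  // Assumption: MatStorageT is templated on the element type and the
  // (name, count) constructor allocates a single packed row of `count`
  // elements, here the rounded-up `padded_num` from the test above.
  MatStorageT<float> buf("buf", padded_num);
  float* HWY_RESTRICT data = buf.Packed();  // contiguous storage of the row
  const auto span = buf.Span();             // bounded view for Decompress*/Dot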
@@ -1106,7 +1106,6 @@ void TestAllDot() { threading_args.max_lps = kMaxWorkers - 1; ThreadingContext::SetArgs(threading_args); ThreadingContext& ctx = ThreadingContext::Get(); - const Allocator& allocator = ctx.allocator; { // ensure no profiler zones are active const hn::ScalableTag df; @@ -1118,16 +1117,17 @@ void TestAllDot() { constexpr size_t kReps = hn::AdjustedReps(40); const size_t num = 24 * 1024; - RowVectorBatch a(allocator, Extents2D(kMaxWorkers, num)); - RowVectorBatch b(allocator, Extents2D(kMaxWorkers, num)); - RowVectorBatch bufs(allocator, Extents2D(kMaxWorkers, num)); + MatStorageT a("a", Extents2D(kMaxWorkers, num), MatPadding::kOdd); + MatStorageT b("b", Extents2D(kMaxWorkers, num), MatPadding::kOdd); + MatStorageT bufs("bufs", Extents2D(kMaxWorkers, num), + MatPadding::kOdd); std::array all_stats; ctx.pools.Cluster(0, 0).Run( 0, kReps, [&](const uint32_t rep, size_t thread) { - float* HWY_RESTRICT pa = a.Batch(thread); - float* HWY_RESTRICT pb = b.Batch(thread); - double* HWY_RESTRICT buf = bufs.Batch(thread); + float* HWY_RESTRICT pa = a.Row(thread); + float* HWY_RESTRICT pb = b.Row(thread); + double* HWY_RESTRICT buf = bufs.Row(thread); const PackedSpan a_span(pa, num); DotStats& stats = all_stats[thread]; const double cond = diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index ce285ca..f02edef 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -693,7 +693,6 @@ class MMScaleDemoteAdd { // We manually unroll 2x for higher IPC in batch=1. size_t col_c = range_nc.begin(); if (HWY_LIKELY(range_nc.Num() >= 2 * ND)) { - HWY_UNROLL(1) for (; col_c <= range_nc.end() - 2 * ND; col_c += 2 * ND) { VD a0, a1; // unused if !kAdd if constexpr (kAdd) { @@ -860,9 +859,8 @@ class MMScaleDemoteAdd { class MMPerPackage { public: template - MMPerPackage(const ConstMat& A, const MMArgs& args, - const MMConfig& config, size_t pkg_idx, - const IndexRange& range_np) + MMPerPackage(const MatPtrT& A, const MMArgs& args, const MMConfig& config, + size_t pkg_idx, const IndexRange& range_np) : args_(args), pkg_idx_(pkg_idx), // May be overwritten with a view of A, if already BF16. @@ -1114,12 +1112,12 @@ class MMPerPackage { }); } - // Decompresses all `M x K` from `A` into `pkg_A`. Assumes `TA` is a seekable + // Decompresses all `M x K` from `A` into `A_`. Assumes `TA` is a seekable // type (i.e., not NUQ) so we can use pointer arithmetic. template - HWY_NOINLINE void DoDecompressA(const ConstMat& A, MMParA par_a) const { - const IndexRange all_M(0, A.extents.rows); - const IndexRange all_K(0, A.extents.cols); + HWY_NOINLINE void DoDecompressA(const MatPtrT& A, MMParA par_a) const { + const IndexRange all_M(0, A.Rows()); + const IndexRange all_K(0, A.Cols()); HWY_DASSERT(all_K.Num() == A_.Cols()); const hn::ScalableTag dbf; @@ -1131,10 +1129,9 @@ class MMPerPackage { const size_t col0 = range_K.begin(); const size_t cols = range_K.Num(); // otherwise, padding overwrites neighbors - HWY_DASSERT(cols % NBF == 0 || cols == A.extents.cols); + HWY_DASSERT(cols % NBF == 0 || cols == A.Cols()); for (size_t row_a : range_M) { - const PackedSpan from = - MakeSpan(A.ptr + A.Row(row_a) + col0, cols); + const PackedSpan from = MakeSpan(A.Row(row_a) + col0, cols); BF16* HWY_RESTRICT to = A_.Row(row_a) + col0; DecompressAndZeroPad(dbf, from, 0, to, cols); // Verify that we zero-padded. @@ -1174,18 +1171,14 @@ class MMPerPackage { // Autotuning wrapper for `DoDecompressA`. 
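  // When A is already BF16 and its column count is a multiple of the BF16
  // vector length (so no zero-padding is needed), this returns a view of A
  // instead of copying into the per-package buffer.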
template - HWY_INLINE RowPtrBF DecompressA(const ConstMat& A) const { + HWY_INLINE RowPtrBF DecompressA(const MatPtrT& A) const { const Allocator& allocator = args_.env->ctx.allocator; MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; // If already BF16, maybe return a view: if constexpr (hwy::IsSame()) { // Only if no zero-padding required. const size_t NBF = hn::Lanes(hn::ScalableTag()); - if (HWY_LIKELY(A.extents.cols % NBF == 0)) { - const BF16* pos = A.ptr + A.Row(0); - return RowPtrBF(allocator, const_cast(pos), A.extents.cols, - A.Stride()); - } + if (HWY_LIKELY(A.Cols() % NBF == 0)) return RowPtrFromMat(allocator, A); } if (HWY_LIKELY(autotune.Best())) { @@ -1196,7 +1189,7 @@ class MMPerPackage { // First call: generate candidates. if (HWY_UNLIKELY(!autotune.HasCandidates())) { std::vector candidates = {MMParA::kK1, MMParA::kK2, MMParA::kK4}; - if (A.extents.rows == 1) { + if (A.Rows() == 1) { candidates.push_back(MMParA::kNone); } else { candidates.push_back(MMParA::kM); @@ -1247,7 +1240,7 @@ class MMPerPackage { const MMArgs args_; // copy for locality const size_t pkg_idx_; - RowPtrBF A_; // points into A or storage. + RowPtrBF A_; // points into A or pkg_A. const IndexRange range_np_; // From MMConfig: @@ -1276,9 +1269,8 @@ struct MMImpl { // Called from `MatMul` from two places: either with the next autotune config, // or with the best config. template - static HWY_NOINLINE void DoMatMul(const ConstMat& A, - const ConstMat& B, const RowPtr& C, - const MMArgs& args, + static HWY_NOINLINE void DoMatMul(const MatPtrT& A, const ConstMat& B, + const RowPtr& C, const MMArgs& args, const MMConfig& config) { MMZone matmul_zone; matmul_zone.MaybeEnter("MM.DoMatMul", args); @@ -1313,7 +1305,7 @@ struct MMImpl { // // Uses considerable stack space: at least 40 KiB per thread. template -HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, +HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const ConstMat& B, const float* HWY_RESTRICT add, MatMulEnv& env, const RowPtr& C) { const Allocator& allocator = env.ctx.allocator; @@ -1340,7 +1332,7 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, MMPerKey& per_key = env.per_key[index]; MMAutoTune& tuner = per_key.autotune; - const MMArgs args(env, per_key, static_cast(A.scale) * B.scale, add, + const MMArgs args(env, per_key, static_cast(A.Scale()) * B.scale, add, env.storage.Partial()); if (HWY_LIKELY(tuner.Best())) { MMImpl::DoMatMul(A, B, C, args, *tuner.Best()); @@ -1396,6 +1388,13 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, return &per_key; } +template +HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B, + const float* HWY_RESTRICT add, MatMulEnv& env, + const RowPtr& C) { + return MatMul(A, ConstMat(B), add, env, C); +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/ops/matmul.h b/ops/matmul.h index b681fe5..8e50c60 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -21,6 +21,7 @@ #include #include +#include // std::unique_ptr #include // IWYU pragma: begin_exports @@ -207,24 +208,28 @@ class MMStorage { // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. static constexpr size_t kMaxKC = 8 * 1024; + // Internally threaded; must not be called concurrently with the same + // `ThreadingContext` (used via `parallel`). MMStorage(const Allocator& allocator, MMParallel& parallel) // Per-worker copies of `partial` would be wasteful. 
We instead allocate // one instance of the maximum matrix extents because threads write at // false-sharing-free granularity. - : partial_storage_( - AllocateAlignedRows(allocator, Extents2D(kMaxM, kMaxN))), + : partial_storage_("partial_storage", Extents2D(kMaxM, kMaxN), + MatPadding::kOdd), // Same stride independent of the actual C.Cols() so we can pre-bind. - partial_(allocator, partial_storage_.All(), kMaxN, - StrideForCyclicOffsets(kMaxN, allocator.Quantum())) { + partial_(allocator, partial_storage_.Row(0), kMaxN, + partial_storage_.Stride()) { // Per-package allocation so each can decompress A into its own copy. parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) { - pkg_A_[pkg_idx] = - AllocateAlignedRows(allocator, Extents2D(kMaxM, kMaxK)); + pkg_A_[pkg_idx].reset(new MatStorageT( + "pkg_A", Extents2D(kMaxM, kMaxK), MatPadding::kOdd)); if (allocator.ShouldBind()) { const size_t node = parallel.Node(pkg_idx); - if (!allocator.BindMemory(pkg_A_[pkg_idx].All(), - pkg_A_[pkg_idx].NumBytes(), node)) { + size_t bytes = pkg_A_[pkg_idx]->Rows() * pkg_A_[pkg_idx]->Stride() * + pkg_A_[pkg_idx]->ElementBytes(); + bytes = hwy::RoundDownTo(bytes, allocator.QuantumBytes()); + if (!allocator.BindMemory(pkg_A_[pkg_idx]->Row(0), bytes, node)) { HWY_WARN("Failed to bind memory for package %zu", pkg_idx); } } @@ -234,22 +239,20 @@ class MMStorage { BindC(allocator, kMaxM, partial_, parallel); } - // Returns per-package matrix view. Non-const so that `RowVectorBatch` is - // non-const, because `RowPtr` requires a non-const pointer. + // Returns per-package matrix view. RowPtrBF A(const Allocator& allocator, size_t pkg_idx, - const Extents2D& extents) { + const Extents2D& extents) const { HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.cols <= kMaxK); - const size_t stride = - StrideForCyclicOffsets(extents.cols, allocator.Quantum()); - return RowPtrBF(allocator, pkg_A_[pkg_idx].All(), extents.cols, stride); + return RowPtrBF(allocator, const_cast(pkg_A_[pkg_idx]->Row(0)), + extents.cols, pkg_A_[pkg_idx]->Stride()); } RowPtrD Partial() const { return partial_; } private: - RowVectorBatch pkg_A_[MMParallel::kMaxPackages]; - RowVectorBatch partial_storage_; + std::unique_ptr> pkg_A_[MMParallel::kMaxPackages]; + MatStorageT partial_storage_; RowPtrD partial_; }; @@ -608,6 +611,8 @@ struct MMPerKey { // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive // `MatMulEnv`. struct MatMulEnv { + // Internally threaded; must not be called concurrently with the same + // `ThreadingContext`. explicit MatMulEnv(ThreadingContext& ctx); ThreadingContext& ctx; @@ -679,8 +684,8 @@ struct MMZone { #endif // PROFILER_ENABLED // Used for the A and B arguments of `MatMul`, which are always const. -// Create via MakeConstMat. This differs from `RowPtr` in that it supports the -// `ofs` required for compressed T. +// This differs from `RowPtr` in supporting the `ofs` required for compressed T. +// TODO: remove after splitting W1/W2 and updating QDotK to RowPtr. template struct ConstMat { ConstMat() = default; @@ -689,6 +694,12 @@ struct ConstMat { HWY_DASSERT(ptr != nullptr); HWY_DASSERT(stride >= extents.cols); } + // Non-explicit so that we can pass `MatPtr` directly to MatMul. + ConstMat(const MatPtrT& m) + : ConstMat(const_cast(m.Row(0)), m.Extents(), m.Stride()) { + scale = m.Scale(); + } + size_t Row(size_t r) const { if constexpr (HWY_IS_DEBUG_BUILD) { if (r >= extents.rows) { @@ -727,31 +738,6 @@ struct ConstMat { size_t ofs; }; -// For deducing T. 
-template -ConstMat MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, - size_t stride) { - return ConstMat(ptr, extents, stride); -} - -// For A argument to MatMul (activations). -template -ConstMat ConstMatFromBatch(size_t batch_size, - const RowVectorBatch& row_vectors) { - HWY_DASSERT(batch_size <= row_vectors.BatchSize()); - return MakeConstMat(const_cast(row_vectors.Const()), - Extents2D(batch_size, row_vectors.Cols()), - row_vectors.Stride()); -} - -template -ConstMat ConstMatFromWeights(const MatPtrT& m) { - ConstMat mat = - MakeConstMat(const_cast(m.Row(0)), m.Extents(), m.Stride()); - mat.scale = m.Scale(); - return mat; -} - template void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC, const ConstMat& B, MMParallel& parallel) { diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index f245cf0..668a983 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -57,10 +57,10 @@ namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; // Returns 1-norm, used for estimating tolerable numerical differences. -double MaxRowAbsSum(const RowVectorBatch& a) { +double MaxRowAbsSum(const MatStorageT& a) { double max_row_abs_sum = 0.0; - for (size_t r = 0; r < a.BatchSize(); r++) { - const float* row = a.Batch(r); + for (size_t r = 0; r < a.Rows(); r++) { + const float* row = a.Row(r); double row_abs_sum = 0.0; for (size_t c = 0; c < a.Cols(); c++) { row_abs_sum += hwy::ScalarAbs(row[c]); @@ -71,11 +71,11 @@ double MaxRowAbsSum(const RowVectorBatch& a) { } // Returns the maximum absolute value of `a`. -float MaxAbs(const RowVectorBatch& a) { +float MaxAbs(const MatStorageT& a) { float max_abs = 0.0f; for (size_t c = 0; c < a.Cols(); c++) { - for (size_t r = 0; r < a.BatchSize(); r++) { - const float* row = a.Batch(r); + for (size_t r = 0; r < a.Rows(); r++) { + const float* row = a.Row(r); max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c])); } } @@ -84,33 +84,29 @@ float MaxAbs(const RowVectorBatch& a) { // B is already transposed. template -void AssertClose(const ConstMat& A, const ConstMat& B, +void AssertClose(const MatPtrT& A, const MatPtrT& B, const RowPtr& C_slow, const RowPtr& C, int line) { - const Allocator& allocator = ThreadingContext::Get().allocator; const hn::ScalableTag df; - const size_t cols = A.extents.cols; - const size_t B_rows = B.extents.rows; + const size_t cols = A.Cols(); + const size_t B_rows = B.Rows(); // Round up for DecompressAndZeroPad. 
- RowVectorBatch a_batch = - AllocateAlignedRows(allocator, A.extents); - RowVectorBatch b_trans_batch = - AllocateAlignedRows(allocator, B.extents); - RowVectorBatch c_batch = - AllocateAlignedRows(allocator, Extents2D(A.extents.rows, B_rows)); - RowVectorBatch c_slow_batch = - AllocateAlignedRows(allocator, Extents2D(A.extents.rows, B_rows)); - HWY_ASSERT(A.ofs == 0 && B.ofs == 0); - for (size_t m = 0; m < A.extents.rows; ++m) { - DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0, - a_batch.Batch(m), cols); - DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m), + MatStorageT a_batch("a_batch", A.Extents(), MatPadding::kOdd); + MatStorageT b_trans_batch("b_trans_batch", B.Extents(), + MatPadding::kOdd); + MatStorageT c_batch("c_batch", Extents2D(A.Rows(), B_rows), + MatPadding::kOdd); + MatStorageT c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows), + MatPadding::kOdd); + for (size_t m = 0; m < A.Rows(); ++m) { + DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols); + DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m), B_rows); DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0, - c_slow_batch.Batch(m), B_rows); + c_slow_batch.Row(m), B_rows); } for (size_t n = 0; n < B_rows; ++n) { - DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0, - b_trans_batch.Batch(n), cols); + DecompressAndZeroPad(df, MakeSpan(B.Row(n), cols), 0, b_trans_batch.Row(n), + cols); } // MatMul rounds inputs to BF16, so error is proportional to the max input @@ -130,10 +126,10 @@ void AssertClose(const ConstMat& A, const ConstMat& B, } const double max_rel = 1.0 + hwy::ConvertScalarTo(hwy::Epsilon()); - for (size_t r = 0; r < A.extents.rows; r++) { - const float* expected_row = c_slow_batch.Batch(r); - const float* actual_row = c_batch.Batch(r); - for (size_t c = 0; c < B.extents.rows; c++) { + for (size_t r = 0; r < A.Rows(); r++) { + const float* expected_row = c_slow_batch.Row(r); + const float* actual_row = c_batch.Row(r); + for (size_t c = 0; c < B.Rows(); c++) { const double expected_value = static_cast(expected_row[c]); const double actual_value = static_cast(actual_row[c]); const bool in_range = expected_value - tolerance <= actual_value && @@ -157,18 +153,17 @@ void AssertClose(const ConstMat& A, const ConstMat& B, // B is already transposed. template -HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, +HWY_INLINE void MatMulSlow(const MatPtrT A, const MatPtrT B, const float* HWY_RESTRICT add_row, MatMulEnv& env, const RowPtr& C) { // TA can be any Packed except NuqStream because it uses pointer // arithmetic, because it is the second argument to Dot, which does not // support a v_ofs. static_assert(sizeof(TA) >= sizeof(BF16), "A matrix must be BF16/f32"); - const float scale = A.scale * B.scale; + const float scale = A.Scale() * B.Scale(); const hn::ScalableTag df; // lane type is ignored - const PackedSpan b_span = - MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows); + const PackedSpan b_span = B.Span(); const IndexRange all_rows_c(0, A.Extents().rows); const IndexRange all_cols_c(0, C.Cols()); @@ -191,8 +186,8 @@ HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, for (size_t c : cols_c) { const float add = add_row ? 
add_row[c] : 0.0f; C_row[c] = hwy::ConvertScalarTo( - add + scale * Dot(df, b_span, c * B.Stride(), - A.ptr + A.Row(r), A.extents.cols)); + add + scale * Dot(df, b_span, c * B.Stride(), A.Row(r), + A.Cols())); } } }); @@ -225,26 +220,23 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatStorageT a(GenerateMat(A_extents, pool)); MatStorageT b_trans(GenerateTransposedMat(B_extents, pool)); - RowVectorBatch c_slow_batch = - AllocateAlignedRows(allocator, C_extents); - RowVectorBatch c_batch = AllocateAlignedRows(allocator, C_extents); + MatStorageT c_slow_batch("c_slow_batch", C_extents, MatPadding::kOdd); + MatStorageT c_batch("c_batch", C_extents, MatPadding::kOdd); MatStorageT add_storage = add ? GenerateMat(Extents2D(1, cols_bc), pool) : MatStorageT("add", Extents2D(), MatPadding::kPacked); add_storage.SetScale(1.0f); - const auto A = ConstMatFromWeights(a); - const auto B = ConstMatFromWeights(b_trans); const float* add_row = add ? add_storage.PackedScale1() : nullptr; - const RowPtr C_slow = RowPtrFromBatch(allocator, c_slow_batch); - const RowPtr C = RowPtrFromBatch(allocator, c_batch); + const RowPtr C_slow = RowPtrFromMat(allocator, c_slow_batch); + const RowPtr C = RowPtrFromMat(allocator, c_batch); - MatMulSlow(A, B, add_row, env, C_slow); + MatMulSlow(a, b_trans, add_row, env, C_slow); // A few reps to get coverage of the various autotuned code paths. for (size_t rep = 0; rep < 16; ++rep) { - MMPerKey* per_key = MatMul(A, B, add_row, env, C); - AssertClose(A, B, C_slow, C, line); + MMPerKey* per_key = MatMul(a, b_trans, add_row, env, C); + AssertClose(a, b_trans, C_slow, C, line); if (per_key->autotune.Best()) break; } } diff --git a/ops/ops-inl.h b/ops/ops-inl.h index acd2f5c..9308685 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -189,10 +189,11 @@ float RMSNormMul(const VT* HWY_RESTRICT x, size_t size) { } // namespace detail -template -HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x, - const WeightT* HWY_RESTRICT weight, - OutT* HWY_RESTRICT out, +// `x_ofs` is the offset within `x`, required for NuqStream. +template +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const XT* HWY_RESTRICT x, + const WT* HWY_RESTRICT weight, + size_t w_ofs, OT* HWY_RESTRICT out, const size_t size) { PROFILER_FUNC; @@ -203,17 +204,17 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x, const VF mul = hn::Set(df, detail::RMSNormMul(x, size)); - const auto packed_w = MakeSpan(weight, size); - const auto packed_v = MakeSpan(x, size); + const auto packed_x = MakeSpan(x, size); + const auto packed_w = MakeSpan(weight, w_ofs + size); const auto packed_out = MakeSpan(out, size); - HWY_DASSERT(size % (2 * MaxLanes(df)) == 0); + HWY_DASSERT(size % (2 * NF) == 0); for (size_t i = 0; i < size; i += 2 * NF) { - VF v0, v1, w0, w1; - Decompress2(df, packed_v, i, v0, v1); - Decompress2(df, packed_w, i, w0, w1); - const VF m0 = hn::Mul(mul, v0); - const VF m1 = hn::Mul(mul, v1); + VF x0, x1, w0, w1; + Decompress2(df, packed_x, i, x0, x1); + Decompress2(df, packed_w, w_ofs + i, w0, w1); + const VF m0 = hn::Mul(mul, x0); + const VF m1 = hn::Mul(mul, x1); // (1+weight) * m = m + weight*m = one FMA. const VF out0 = hn::MulAdd(m0, w0, m0); const VF out1 = hn::MulAdd(m1, w1, m1); @@ -222,10 +223,11 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(const VecT* HWY_RESTRICT x, } // Same as RMSNorm, but its HWY_RESTRICT forbids passing the same pointer. 
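// `w_ofs` is an element offset into the packed `weight` (needed when the
// weights are stored in a compressed format such as NuqStream); callers
// operating on plain arrays pass 0.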
-template -HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( - const WeightT* HWY_RESTRICT weight, VecT* HWY_RESTRICT inout, - const size_t size) { +template +HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(const WT* HWY_RESTRICT weight, + size_t w_ofs, + XT* HWY_RESTRICT inout, + const size_t size) { PROFILER_FUNC; namespace hn = hwy::HWY_NAMESPACE; @@ -235,72 +237,112 @@ HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace( const VF mul = hn::Set(df, detail::RMSNormMul(inout, size)); - const auto packed_w = MakeSpan(weight, size); - const auto packed_v = MakeSpan(inout, size); + const auto packed_w = MakeSpan(weight, w_ofs + size); + const auto packed_x = MakeSpan(inout, size); - HWY_DASSERT(size % (2 * MaxLanes(df)) == 0); + HWY_DASSERT(size % (2 * NF) == 0); for (size_t i = 0; i < size; i += 2 * NF) { - VF v0, v1, w0, w1; - Decompress2(df, MakeConst(packed_v), i, v0, v1); - Decompress2(df, packed_w, i, w0, w1); - const VF m0 = hn::Mul(mul, v0); - const VF m1 = hn::Mul(mul, v1); + VF x0, x1, w0, w1; + Decompress2(df, packed_x, i, x0, x1); + Decompress2(df, packed_w, w_ofs + i, w0, w1); + const VF m0 = hn::Mul(mul, x0); + const VF m1 = hn::Mul(mul, x1); // (1+weight) * m = m + weight*m = one FMA. const VF out0 = hn::MulAdd(m0, w0, m0); const VF out1 = hn::MulAdd(m1, w1, m1); - Compress2(df, out0, out1, packed_v, i); + Compress2(df, out0, out1, packed_x, i); } } // Computes mean mu and mean of squares mu2 of a vector. Used in LayerNorm. -template -HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, T& mu, - T& mu2) { +template +HWY_NOINLINE void ComputeMoments(const XT* HWY_RESTRICT x, size_t size, + double& mu, double& mu2) { HWY_ASSERT(size > 0); - double sum = 0.0; - double sum2 = 0.0; - for (size_t i = 0; i < size; ++i) { - const float f = hwy::ConvertScalarTo(a[i]); - sum += f; - sum2 += f * f; - } - mu = sum / size; - mu2 = sum2 / size; + const hn::ScalableTag df; + + // Use the existing Sum and Dot kernels for simplicity. The second pass + // is likely not too expensive because it will be in L1. + const double sum = Sum(df, x, size); + // We only have one array, so calling `DecompressAndCall` instead of `Dot`` + // avoids loading the 'second' vector again. + const double sum2 = + DecompressAndCall(df, MakeSpan(x, size), DotKernelDouble()); + + const double inv_size = 1.0 / static_cast(size); + mu = sum * inv_size; + mu2 = sum2 * inv_size; } // Compare py/flax/linen/normalization.py. // out = (x - mean) * scale * rsqrt(var + epsilon) + bias -template -HWY_NOINLINE void ScalarLayerNorm(const VecT* x, - const WeightT* HWY_RESTRICT scale, - const WeightT* HWY_RESTRICT bias, - OutT* out, - size_t size) { - constexpr float kEps = 1e-6f; - VecT mu, mu2; - ScalarMus(x, size, mu, mu2); - VecT var = mu2 - mu * mu; - VecT zero = 0.0f; - var = HWY_MAX(var, zero); - var = 1.0f / sqrtf(var + kEps); - for (size_t j = 0; j < size; j++) { - const float v = hwy::ConvertScalarTo(x[j]); - const float s = hwy::ConvertScalarTo(scale[j]); - const float b = hwy::ConvertScalarTo(bias[j]); - out[j] = hwy::ConvertScalarTo((v - mu) * s * var + b); - } -} - -template -HWY_NOINLINE HWY_MAYBE_UNUSED void LayerNorm(const VecT* x, - const WeightT* HWY_RESTRICT weight, - const WeightT* HWY_RESTRICT bias, - OutT* out, - const size_t size) { +// x and out may be the same. +template +HWY_NOINLINE void LayerNorm(const XT* x, const WT* HWY_RESTRICT scale, + const WT* HWY_RESTRICT bias, OT* out, size_t size) { PROFILER_FUNC; - // For now we only delegate to the scalar version. 
- // TODO: implement vectorized version. - ScalarLayerNorm(x, weight, bias, out, size); + + namespace hn = hwy::HWY_NAMESPACE; + const hn::ScalableTag df; + using VF = hn::Vec; + const size_t NF = hn::Lanes(df); + + double mu, mu2; + ComputeMoments(x, size, mu, mu2); + double var = mu2 - mu * mu; + var = HWY_MAX(var, 0.0); + var = 1.0 / sqrt(var + 1E-6); + const VF vmu = hn::Set(df, static_cast(mu)); + const VF vvar = hn::Set(df, static_cast(var)); + const VF* HWY_RESTRICT pmu = &vmu; + const VF* HWY_RESTRICT pvar = &vvar; + + const auto packed_x = MakeSpan(x, size); + const auto packed_scale = MakeSpan(scale, size); + const auto packed_bias = MakeSpan(bias, size); + const auto packed_out = MakeSpan(out, size); + + // Loop body for one vector, called from main loop and remainder loop. + const auto norm = [pmu, pvar](VF x, VF s, VF add) HWY_ATTR -> VF { + const VF centered = hn::Sub(x, *pmu); + const VF mul = hn::Mul(s, *pvar); + return hn::MulAdd(centered, mul, add); + }; + + size_t i = 0; + if (size >= 2 * NF) { + for (; i <= size - 2 * NF; i += 2 * NF) { + VF x0, x1, s0, s1, add0, add1; + Decompress2(df, packed_x, i, x0, x1); + Decompress2(df, packed_scale, i, s0, s1); + Decompress2(df, packed_bias, i, add0, add1); + const VF n0 = norm(x0, s0, add0); + const VF n1 = norm(x1, s1, add1); + Compress2(df, n0, n1, packed_out, i); + } + } + + const size_t remaining = size - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf_x[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf_scale[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf_bias[2 * hn::MaxLanes(df)]; + HWY_ALIGN OT buf_out[2 * hn::MaxLanes(df)]; + DecompressAndZeroPad(df, packed_x, i, buf_x, remaining); + DecompressAndZeroPad(df, packed_scale, i, buf_scale, remaining); + DecompressAndZeroPad(df, packed_bias, i, buf_bias, remaining); + const VF x0 = hn::Load(df, buf_x); + const VF x1 = hn::Load(df, buf_x + NF); + const VF s0 = hn::Load(df, buf_scale); + const VF s1 = hn::Load(df, buf_scale + NF); + const VF add0 = hn::Load(df, buf_bias); + const VF add1 = hn::Load(df, buf_bias + NF); + const VF n0 = norm(x0, s0, add0); + const VF n1 = norm(x1, s1, add1); + Compress2(df, n0, n1, MakeSpan(buf_out, 2 * NF), 0); + hwy::CopyBytes(buf_out, out + i, remaining * sizeof(OT)); + } } static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings( @@ -447,39 +489,56 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom( } // Simple loops unless/until batch sizes are large enough to parallelize. -template -void RMSNormBatched(size_t num_tokens, const float* activations, - const WeightT* weights, OutT* out, const size_t model_dim) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - RMSNorm(activations + token_idx * model_dim, weights, - out + token_idx * model_dim, model_dim); - } +template +void RMSNormBatched(const MatPtrT& activations, const MatPtr& weights, + MatPtrT& out) { + HWY_DASSERT(weights.Rows() == 1); + HWY_DASSERT(weights.Cols() == activations.Cols()); + HWY_DASSERT(activations.SameShape(out)); + + CallUpcasted(&weights, [&](const auto* weights_t) { + for (size_t token_idx = 0; token_idx < activations.Rows(); ++token_idx) { + RMSNorm(activations.Row(token_idx), weights_t->PackedScale1(), 0, + out.Row(token_idx), activations.Cols()); + } + }); } -// TODO: pass RowVectorBatch argument. 
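// The batched wrappers below take whole matrices: the row count and width come
// from the MatPtrT arguments, and CallUpcasted/CallUpcastedSame dispatch on the
// stored element type of the weights (and bias) before the per-row loops.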
-template -void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights, - InOutT* inout, const size_t model_dim) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - RMSNormInplace(weights, inout + token_idx * model_dim, model_dim); - } +template +void RMSNormInplaceBatched(const MatPtr& weights, MatPtrT& inout) { + HWY_DASSERT(weights.Rows() == 1); + HWY_DASSERT(weights.Cols() == inout.Cols()); + + CallUpcasted(&weights, [&](const auto* weights_t) { + for (size_t token_idx = 0; token_idx < inout.Rows(); ++token_idx) { + RMSNormInplace(weights_t->PackedScale1(), 0, inout.Row(token_idx), + inout.Cols()); + } + }); } -template -void LayerNormBatched(size_t num_tokens, const VecT* x, - const WeightT* HWY_RESTRICT weight, - const WeightT* HWY_RESTRICT bias, OutT* out, - const size_t size) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - LayerNorm(x + token_idx * size, weight, bias, out + token_idx * size, size); - } +// x and out may be the same. +template +void LayerNormBatched(const MatPtrT& x, const MatPtr& weight, + const MatPtr& bias, MatPtrT& out) { + HWY_DASSERT(weight.Cols() == bias.Cols()); + HWY_DASSERT(weight.Cols() == x.Cols()); + HWY_DASSERT(x.SameShape(out)); + + CallUpcastedSame( + &weight, &bias, [&](const auto* weight_t, const auto* bias_t) { + for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) { + LayerNorm(x.Row(token_idx), weight_t->PackedScale1(), + bias_t->PackedScale1(), out.Row(token_idx), x.Cols()); + } + }); } -static HWY_INLINE void AddFromBatched(size_t num_tokens, const float* other, - float* x, const size_t model_dim) { - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - AddFrom(other + token_idx * model_dim, x + token_idx * model_dim, - model_dim); +static HWY_INLINE void AddFromBatched(const MatPtrT& other, + MatPtrT& x) { + HWY_DASSERT(x.SameShape(other)); + for (size_t token_idx = 0; token_idx < x.Rows(); ++token_idx) { + AddFrom(other.Row(token_idx), x.Row(token_idx), x.Cols()); } } @@ -743,8 +802,8 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector TopK( HWY_ASSERT(k != 0); HWY_ASSERT(k <= vocab_size); std::vector packed_token_probs; - for (int32_t i = 0; i < vocab_size; ++i) { - if (accept_token && !accept_token(StaticCast(i), probabilities[i])) { + for (int32_t i = 0; i < static_cast(vocab_size); ++i) { + if (accept_token && !accept_token(i, probabilities[i])) { continue; } packed_token_probs.push_back(PackTokenAndProb(i, probabilities[i])); @@ -756,7 +815,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED std::vector TopK( std::vector token_probs; token_probs.reserve(k); - for (int32_t i = 0; i < k; ++i) { + for (int32_t i = 0; i < static_cast(k); ++i) { token_probs.push_back(UnpackTokenAndProb(packed_token_probs[i])); } return token_probs; @@ -770,7 +829,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK( TopK(probabilities, vocab_size, k, accept_token); std::vector topk_indices(k); std::vector topk_probs(k); - for (int i = 0; i < k; ++i) { + for (size_t i = 0; i < k; ++i) { topk_indices[i] = token_probs[i].token; topk_probs[i] = token_probs[i].prob; } @@ -788,7 +847,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( TopK(logits, vocab_size, k, accept_token); std::vector topk_indices(k); std::vector topk_logits(k); - for (int i = 0; i < token_logits.size(); ++i) { + for (size_t i = 0; i < token_logits.size(); ++i) { topk_indices[i] = token_logits[i].token; topk_logits[i] = token_logits[i].prob; } @@ -807,20 +866,20 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb 
FusedSoftmaxAndSampleTopK( // Input has 4096 (64*64) rows, output has 256 (16*16) rows // Each output row is the average of a 4x4 block of input rows template -RowVectorBatch AvgPool4x4(RowVectorBatch& input) { - const Allocator& allocator = ThreadingContext::Get().allocator; +MatStorageT AvgPool4x4(MatStorageT& input) { const Extents2D extents = input.Extents(); // Input validation HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows // Create output with 256 rows and same number of columns const size_t out_rows = 256; // 16 * 16 = 256 output rows - RowVectorBatch result(allocator, Extents2D(out_rows, extents.cols)); + MatStorageT result("pool4x4", Extents2D(out_rows, extents.cols), + MatPadding::kOdd); const size_t input_dim = 64; // Input is 64×64 const size_t output_dim = 16; // Output is 16×16 for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) { for (size_t out_col_idx = 0; out_col_idx < output_dim; ++out_col_idx) { size_t out_idx = out_row_idx * output_dim + out_col_idx; - T* output_row = result.Batch(out_idx); + T* output_row = result.Row(out_idx); // Initialize output row to zeros std::fill(output_row, output_row + extents.cols, 0); // Average 16 row vectors from a 4x4 block @@ -829,9 +888,9 @@ RowVectorBatch AvgPool4x4(RowVectorBatch& input) { size_t in_row_idx = out_row_idx * 4 + i; size_t in_col_idx = out_col_idx * 4 + j; size_t in_idx = in_row_idx * input_dim + in_col_idx; - const T* input_row = input.Batch(in_idx); + const T* input_row = input.Row(in_idx); // Add each input row to the output - // TODO(philculliton): use AddFrom in ops-inl for a vectorized loop. + // TODO(philculliton): use AddFrom in `ops-inl` for a vectorized loop. for (size_t col = 0; col < extents.cols; ++col) { output_row[col] += input_row[col]; } diff --git a/ops/ops.h b/ops/ops.h index 4e733cd..19f6daf 100644 --- a/ops/ops.h +++ b/ops/ops.h @@ -20,23 +20,22 @@ #include -#include "util/allocator.h" #include "util/mat.h" #include "hwy/base.h" namespace gcpp { -static inline HWY_MAYBE_UNUSED RowVectorBatch CreateInvTimescale( +static inline HWY_MAYBE_UNUSED MatStorageT CreateInvTimescale( const Allocator& allocator, size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) { const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim; - RowVectorBatch inv_timescale(allocator, Extents2D(1, rope_dim / 2)); + MatStorageT inv_timescale("inv_timescale", rope_dim / 2); for (size_t dim = 0; dim < rope_dim / 2; ++dim) { const double freq_exponents = static_cast(2 * dim) / static_cast(rope_dim); // Replacing with expf(ln(1E4) * freq_exponents) changes results // noticeably. 
- inv_timescale.Batch(0)[dim] = + inv_timescale.Packed()[dim] = static_cast(1.0 / std::pow(base_frequency, freq_exponents)); } return inv_timescale; diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 53913ff..cf88fd8 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -34,7 +34,7 @@ #include "gemma/common.h" // ChooseQueryScale #include "util/allocator.h" #include "util/basics.h" // BF16 -#include "util/mat.h" // RowVectorBatch +#include "util/mat.h" // MatStorageT #include "util/test_util.h" #include "util/threading_context.h" #include "hwy/tests/hwy_gtest.h" @@ -391,7 +391,7 @@ void TestRopeAndMulBy() { ModelConfig config(Model::GEMMA2_9B, Type::kSFP, ChooseWrapping(Model::GEMMA2_9B)); int dim_qkv = config.layer_configs[0].qkv_dim; - RowVectorBatch x(allocator, Extents2D(1, dim_qkv)); + MatStorageT x("x", dim_qkv); std::mt19937 gen; gen.seed(0x12345678); @@ -399,43 +399,43 @@ void TestRopeAndMulBy() { auto random_float = [&r, &gen] { return r(gen); }; for (int i = 0; i < dim_qkv; ++i) { - x.All()[i] = random_float(); + x.Packed()[i] = random_float(); } const float qmul = ChooseQueryScale(config); const float kmul = 1.0; - std::vector qexpected(dim_qkv); - std::vector qactual(dim_qkv); - std::vector kexpected(dim_qkv); - std::vector kactual(dim_qkv); - RowVectorBatch inv_timescale = CreateInvTimescale( + MatStorageT qexpected("qexpected", dim_qkv); + MatStorageT qactual("qactual", dim_qkv); + MatStorageT kexpected("kexpected", dim_qkv); + MatStorageT kactual("kactual", dim_qkv); + MatStorageT inv_timescale = CreateInvTimescale( allocator, config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk == PostQKType::HalfRope); // Assert VectorizedRope computation is same as regular rope at different pos. for (int pos = 1; pos < 500; pos++) { // Rope'd Q embeddings - hwy::CopyBytes(x.Const(), qactual.data(), dim_qkv); - hwy::CopyBytes(x.Const(), qexpected.data(), dim_qkv); - ScalarRopeAndMulBy(qmul, qexpected.data(), dim_qkv, inv_timescale.Const(), - pos); - RopeAndMulBy(qmul, qactual.data(), dim_qkv, inv_timescale.Const(), pos); + CopyMat(x, qactual); + CopyMat(x, qexpected); + ScalarRopeAndMulBy(qmul, qexpected.Packed(), dim_qkv, + inv_timescale.Packed(), pos); + RopeAndMulBy(qmul, qactual.Packed(), dim_qkv, inv_timescale.Packed(), pos); for (int i = 0; i < dim_qkv; ++i) { - EXPECT_NEAR(qactual[i], qexpected[i], 1e-4) - << "qIndex:" << i << "qInput:" << qactual[i]; + EXPECT_NEAR(qactual.Packed()[i], qexpected.Packed()[i], 1e-4) + << "qIndex:" << i << "qInput:" << qactual.Packed()[i]; } // Rope'd K embeddings - hwy::CopyBytes(x.Const(), kactual.data(), dim_qkv); - hwy::CopyBytes(x.Const(), kexpected.data(), dim_qkv); - ScalarRopeAndMulBy(kmul, kexpected.data(), dim_qkv, inv_timescale.Const(), - pos); - RopeAndMulBy(kmul, kactual.data(), dim_qkv, inv_timescale.Const(), pos); + CopyMat(x, kactual); + CopyMat(x, kexpected); + ScalarRopeAndMulBy(kmul, kexpected.Packed(), dim_qkv, + inv_timescale.Packed(), pos); + RopeAndMulBy(kmul, kactual.Packed(), dim_qkv, inv_timescale.Packed(), pos); for (int i = 0; i < dim_qkv; ++i) { - EXPECT_NEAR(kactual[i], kexpected[i], 1e-4) - << "kIndex:" << i << "kInput:" << kactual[i]; + EXPECT_NEAR(kactual.Packed()[i], kexpected.Packed()[i], 1e-4) + << "kIndex:" << i << "kInput:" << kactual.Packed()[i]; } } } @@ -451,10 +451,9 @@ HWY_NOINLINE float ScalarSquaredL2(const T* HWY_RESTRICT a, size_t size) { } // Supports bf16 and f32 inputs/outputs, which can be in-place. 
-template -HWY_NOINLINE void ScalarRMSNorm(const VecT* x, - const WeightT* HWY_RESTRICT weight, OutT* out, - size_t size) { +template +HWY_NOINLINE void ScalarRMSNorm(const XT* x, const WT* HWY_RESTRICT weight, + OT* out, size_t size) { constexpr float kEps = 1e-6f; float ss = ScalarSquaredL2(x, size); ss = 1.0f / sqrtf(ss / StaticCast(size) + kEps); @@ -462,32 +461,32 @@ HWY_NOINLINE void ScalarRMSNorm(const VecT* x, const float v = hwy::ConvertScalarTo(x[j]); const float w = hwy::ConvertScalarTo(weight[j]); // Note 1.0f centering here - out[j] = hwy::ConvertScalarTo((1.0f + w) * (ss * v)); + out[j] = hwy::ConvertScalarTo((1.0f + w) * (ss * v)); } } -template +template void TestRMSNorm(hwy::RandomState& rng) { constexpr size_t kSize = 128; - HWY_ALIGN VecT vec[kSize]; - HWY_ALIGN WeightT weight[kSize]; - HWY_ALIGN OutT expected[kSize]; - HWY_ALIGN OutT actual[kSize]; + HWY_ALIGN XT vec[kSize]; + HWY_ALIGN WT weight[kSize]; + HWY_ALIGN OT expected[kSize]; + HWY_ALIGN OT actual[kSize]; for (size_t i = 0; i < kSize; ++i) { - vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); - weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); } ScalarRMSNorm(vec, weight, expected, kSize); - RMSNorm(vec, weight, actual, kSize); + RMSNorm(vec, weight, 0, actual, kSize); for (size_t i = 0; i < kSize; i++) { const float e = hwy::ConvertScalarTo(expected[i]); const float a = hwy::ConvertScalarTo(actual[i]); if (!IsNear(e, a, 1e-5f)) { - HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(), - TypeName(), TypeName(), i, e, a); + HWY_ABORT("RMSNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(), + TypeName(), TypeName(), i, e, a); } } } @@ -526,24 +525,64 @@ void TestLayerNormSimple() { } } -// Note: there is no vectorized implementation of LayerNorm yet. So this test -// currently only checks that the scalar version can be called for the below -// combinations of float/BF16 inputs and outputs. -template +// Computes mean mu and mean of squares mu2 of a vector. Used in +// ScalarLayerNorm. +template +HWY_NOINLINE void ScalarMus(const T* HWY_RESTRICT a, size_t size, double& mu, + double& mu2) { + HWY_ASSERT(size > 0); + double sum = 0.0; + double sum2 = 0.0; + for (size_t i = 0; i < size; ++i) { + const float f = hwy::ConvertScalarTo(a[i]); + sum += f; + sum2 += f * f; + } + mu = sum / size; + mu2 = sum2 / size; +} + +// Compare py/flax/linen/normalization.py. 
+// out = (x - mean) * scale * rsqrt(var + epsilon) + bias +template +HWY_NOINLINE void ScalarLayerNorm(const XT* x, const WT* HWY_RESTRICT scale, + const WT* HWY_RESTRICT bias, OT* out, + size_t size) { + constexpr double kEps = 1e-6; + double mu, mu2; + ScalarMus(x, size, mu, mu2); + double var = mu2 - mu * mu; + constexpr double kZero = 0.0; + var = HWY_MAX(var, kZero); + var = 1.0 / sqrt(var + kEps); + for (size_t j = 0; j < size; j++) { + const float v = hwy::ConvertScalarTo(x[j]); + const float s = hwy::ConvertScalarTo(scale[j]); + const float b = hwy::ConvertScalarTo(bias[j]); + out[j] = hwy::ConvertScalarTo((v - mu) * s * var + b); + } +} + +template void TestLayerNorm(hwy::RandomState& rng) { constexpr size_t kSize = 128; - VecT vec[kSize]; - WeightT weight[kSize]; - WeightT bias[kSize]; - OutT expected[kSize]; - OutT actual[kSize]; + XT vec[kSize]; + WT weight[kSize]; + WT bias[kSize]; + OT expected[kSize]; + OT actual[kSize]; for (size_t i = 0; i < kSize; ++i) { - vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); - weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); - bias[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + vec[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + weight[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); + bias[i] = hwy::ConvertScalarTo(RandomGaussian(rng)); } + double expected_mu, expected_mu2; + ScalarMus(vec, kSize, expected_mu, expected_mu2); + double actual_mu, actual_mu2; + ComputeMoments(vec, kSize, actual_mu, actual_mu2); + ScalarLayerNorm(vec, weight, bias, expected, kSize); LayerNorm(vec, weight, bias, actual, kSize); @@ -551,8 +590,8 @@ void TestLayerNorm(hwy::RandomState& rng) { const float e = hwy::ConvertScalarTo(expected[i]); const float a = hwy::ConvertScalarTo(actual[i]); if (!IsNear(e, a, 1e-5f)) { - HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(), - TypeName(), TypeName(), i, e, a); + HWY_ABORT("LayerNorm %s %s %s mismatch at %zu: %E %E\n", TypeName(), + TypeName(), TypeName(), i, e, a); } } } diff --git a/python/configs.cc b/python/configs.cc index 96fe736..0bf69aa 100644 --- a/python/configs.cc +++ b/python/configs.cc @@ -55,8 +55,7 @@ PYBIND11_MODULE(configs, py_module) { .value("kSFP", Type::kSFP) .value("kNUQ", Type::kNUQ) .value("kF64", Type::kF64) - .value("kC64", Type::kC64) - .value("kU128", Type::kU128); + .value("kC64", Type::kC64); enum_(py_module, "LayerAttentionType") .value("kGemma", LayerAttentionType::kGemma) diff --git a/python/gemma_py.cc b/python/gemma_py.cc index 990db58..23b9b99 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -168,9 +168,9 @@ class GemmaModel { void SetImage(const py::array_t& image) { const gcpp::Gemma& gemma = *gemma_.GetGemma(); - const gcpp::Allocator& allocator = gemma_.Env().ctx.allocator; - if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA && - gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) { + const gcpp::ModelConfig& config = gemma.GetModelConfig(); + if (config.wrapping != gcpp::PromptWrapping::PALIGEMMA && + config.wrapping != gcpp::PromptWrapping::GEMMA_VLM) { throw std::invalid_argument("Not a PaliGemma model."); } py::buffer_info buffer = image.request(); @@ -182,14 +182,15 @@ class GemmaModel { float* ptr = static_cast(buffer.ptr); gcpp::Image c_image; c_image.Set(height, width, ptr); - const size_t image_size = gemma.GetModelConfig().vit_config.image_size; + const size_t image_size = config.vit_config.image_size; c_image.Resize(image_size, image_size); - image_tokens_ = gcpp::ImageTokens( - allocator, 
gcpp::Extents2D(gemma.GetModelConfig().vit_config.seq_len, - gemma.GetModelConfig().model_dim)); + image_tokens_.reset(new gcpp::ImageTokens( + "image_tokens", + gcpp::Extents2D(config.vit_config.seq_len, config.model_dim), + gcpp::MatPadding::kOdd)); gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(), .verbosity = 0}; - gemma.GenerateImageTokens(runtime_config, c_image, image_tokens_); + gemma.GenerateImageTokens(runtime_config, c_image, *image_tokens_); } // Generates a response to the given prompt, using the last set image. @@ -197,9 +198,7 @@ class GemmaModel { std::pair> GenerateWithImage( std::string prompt, size_t max_generated_tokens, float temperature, float seed, gcpp::AcceptFunc accept, std::vector prompt_tokens) { - if (image_tokens_.Cols() == 0) { - throw std::invalid_argument("No image set."); - } + if (!image_tokens_) throw std::invalid_argument("No image set."); const gcpp::Gemma& model = *gemma_.GetGemma(); gemma_.MutableGen().seed(seed); gcpp::RuntimeConfig& config = gemma_.MutableConfig(); @@ -207,7 +206,7 @@ class GemmaModel { config.temperature = temperature; config.verbosity = 0; config.accept_token = accept; - config.image_tokens = &image_tokens_; + config.image_tokens = image_tokens_.get(); std::vector tokens; if (!prompt_tokens.empty()) { if (!prompt.empty()) { @@ -219,7 +218,7 @@ class GemmaModel { } else { tokens = gemma_.WrapAndTokenize(prompt); } - tokens.insert(tokens.begin(), image_tokens_.BatchSize(), 0); + tokens.insert(tokens.begin(), image_tokens_->Rows(), 0); size_t num_tokens = tokens.size(); size_t prefix_end = num_tokens; config.prefill_tbatch_size = num_tokens; @@ -252,7 +251,7 @@ class GemmaModel { private: gcpp::GemmaEnv gemma_; - gcpp::ImageTokens image_tokens_; + std::unique_ptr image_tokens_; float last_prob_; }; diff --git a/util/mat.cc b/util/mat.cc index e44e83b..5950596 100644 --- a/util/mat.cc +++ b/util/mat.cc @@ -117,11 +117,11 @@ static size_t Stride(const Allocator& allocator, const MatPtr& mat, } } -void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) { - if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked; +void MatOwner::AllocateFor(MatPtr& mat, const MatPadding padding) { + const bool is_nuq = mat.GetType() == Type::kNUQ; const Allocator& allocator = ThreadingContext::Get().allocator; - const size_t stride = Stride(allocator, mat, padding); - const size_t num = mat.Rows() * stride; + const size_t stride = is_nuq ? mat.Cols() : Stride(allocator, mat, padding); + const size_t num = is_nuq ? mat.PackedBytes() : mat.Rows() * stride; // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding` // might not be enough, hence add extra. `MatT` is at least one byte, which // is half of BF16, hence adding `VectorBytes` *elements* is enough. diff --git a/util/mat.h b/util/mat.h index c5e7a54..7d9113a 100644 --- a/util/mat.h +++ b/util/mat.h @@ -28,7 +28,7 @@ #include "compression/shared.h" // Type #include "gemma/tensor_info.h" #include "io/fields.h" -#include "util/allocator.h" +#include "util/allocator.h" // AlignedPtr2 #include "util/basics.h" // Extents2D // IWYU pragma: end_exports #include "hwy/base.h" @@ -47,7 +47,7 @@ class MatPtr : public IFields { // `name`: see `SetName`. Note that `stride` is initially `cols` and only // differs after deserializing, or calling `SetPtr`. 
MatPtr(const char* name, Type type, Extents2D extents) - : rows_(static_cast(extents.rows)), + : private_rows_(static_cast(extents.rows)), cols_(static_cast(extents.cols)) { SetName(name); SetType(type); @@ -74,7 +74,7 @@ class MatPtr : public IFields { bool HasPtr() const { return ptr_ != nullptr; } // A single row counts as packed because there is no padding between rows. - bool IsPacked() const { return (stride_ == cols_) || (rows_ == 1); } + bool IsPacked() const { return (stride_ == cols_) || (Rows() == 1); } const void* Packed() const { HWY_DASSERT_M(IsPacked(), name_.c_str()); @@ -96,17 +96,17 @@ class MatPtr : public IFields { // Works for any kind of padding. template T* MutableRowT(size_t row) const { - HWY_DASSERT(row < rows_); + HWY_DASSERT(row < Rows()); return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_; } template T* RowT(size_t row) { - HWY_DASSERT(row < rows_); + HWY_DASSERT(row < Rows()); return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_; } template const T* RowT(size_t row) const { - HWY_DASSERT(row < rows_); + HWY_DASSERT(row < Rows()); return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_; } @@ -118,10 +118,22 @@ class MatPtr : public IFields { HWY_DASSERT(0 != element_bytes_ && element_bytes_ <= 16); } - bool IsEmpty() const { return rows_ == 0 || cols_ == 0; } - size_t Rows() const { return rows_; } + size_t Rows() const { + return override_rows_ == 0 ? private_rows_ : override_rows_; + } size_t Cols() const { return cols_; } - Extents2D Extents() const { return Extents2D(rows_, cols_); } + Extents2D Extents() const { return Extents2D(Rows(), cols_); } + bool IsEmpty() const { return Rows() == 0 || cols_ == 0; } + bool SameShape(const MatPtr& other) const { + return Rows() == other.Rows() && cols_ == other.cols_; + } + // Future calls to `Rows()` during this class' lifetime (not serialized) + // will return this value. Used to set the actual number of rows for + // activations preallocated according to the batch size. + void OverrideRows(size_t rows) { + HWY_ASSERT(rows <= private_rows_); + override_rows_ = static_cast(rows); + } // Offset by which to advance pointers to the next row. size_t Stride() const { return stride_; } @@ -150,7 +162,7 @@ class MatPtr : public IFields { visitor(type_); visitor(element_bytes_); visitor(num_elements_); - visitor(rows_); + visitor(private_rows_); visitor(cols_); visitor(scale_); visitor(stride_); @@ -164,11 +176,11 @@ class MatPtr : public IFields { // padding, which is anyway not supported for NUQ because `compress-inl.h` // assumes a contiguous stream for its group indexing. static size_t ComputeNumElements(Type type, Extents2D extents) { - const size_t num_elements = extents.Area(); + size_t num_elements = extents.Area(); if (type == Type::kNUQ) { // `CompressedArrayElements` is a wrapper function that has the same // effect, but that requires a template argument, not `type`. - return NuqStream::PackedEnd(num_elements); + num_elements = NuqStream::PackedEnd(num_elements); } return num_elements; } @@ -184,9 +196,10 @@ class MatPtr : public IFields { // Number of elements to store (including NUQ tables but not padding). // This a function of `type_` and `Extents()` and stored for compatibility. uint32_t num_elements_ = 0; - uint32_t rows_ = 0; + uint32_t private_rows_ = 0; // Only access via Rows()! See OverrideRows(). uint32_t cols_ = 0; - float scale_ = 1.0f; // multiplier for each value, for MatMul. + + uint32_t override_rows_ = 0; // not serialized // Non-owning pointer, must not be freed. 
The underlying memory must outlive // this object. @@ -194,6 +207,8 @@ class MatPtr : public IFields { // Offset by which to advance pointers to the next row, >= `cols_`. uint32_t stride_; + + float scale_ = 1.0f; // multiplier for each value, for MatMul. }; // Non-type erased version of `MatPtr`. Although `MatPtr` also provides @@ -202,6 +217,8 @@ class MatPtr : public IFields { template class MatPtrT : public MatPtr { public: + using T = MatT; + // Called by `MatStorageT`. MatPtrT(const char* name, Extents2D extents) : MatPtr(name, TypeEnum(), extents) {} @@ -253,26 +270,67 @@ class MatPtrT : public MatPtr { }; // Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT`, plus the -// optional `args`. Currently unused but may be used after we move toward -// type-erased `WeightsPtrs`. +// optional `args`. This supports all types used as weights, which excludes +// `kC64` and `kF64` (used only in `backprop/`). template -decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func, +decltype(auto) CallUpcasted(const MatPtr* base, const Func& func, Args&&... args) { - HWY_ASSERT(base != nullptr); - if (type == Type::kF32) { - return func(dynamic_cast*>(base), + if (base->GetType() == Type::kF32) { + return func(dynamic_cast*>(base), std::forward(args)...); - } else if (type == Type::kBF16) { - return func(dynamic_cast*>(base), + } else if (base->GetType() == Type::kBF16) { + return func(dynamic_cast*>(base), std::forward(args)...); - } else if (type == Type::kSFP) { - return func(dynamic_cast*>(base), + } else if (base->GetType() == Type::kSFP) { + return func(dynamic_cast*>(base), std::forward(args)...); - } else if (type == Type::kNUQ) { - return func(dynamic_cast*>(base), + } else if (base->GetType() == Type::kNUQ) { + return func(dynamic_cast*>(base), std::forward(args)...); } else { - HWY_ABORT("Type %d unknown.", static_cast(type)); + HWY_ABORT("Unhandled type %s.", TypeName(base->GetType())); + } +} + +// Calls `func(base1, base2, args...)`. +template +decltype(auto) CallUpcastedSame(const MatPtr* base1, const MatPtr* base2, + const Func& func, Args&&... args) { + HWY_ASSERT(base1->GetType() == base2->GetType()); + if (base1->GetType() == Type::kF32) { + return func(dynamic_cast*>(base1), + dynamic_cast*>(base2), + std::forward(args)...); + } else if (base1->GetType() == Type::kBF16) { + return func(dynamic_cast*>(base1), + dynamic_cast*>(base2), + std::forward(args)...); + } else if (base1->GetType() == Type::kSFP) { + return func(dynamic_cast*>(base1), + dynamic_cast*>(base2), + std::forward(args)...); + } else if (base1->GetType() == Type::kNUQ) { + return func(dynamic_cast*>(base1), + dynamic_cast*>(base2), + std::forward(args)...); + } else { + HWY_ABORT("Unhandled type %s.", TypeName(base1->GetType())); + } +} + +// Like CallUpcasted, but only for activation types: kBF16 and kF32. +template +decltype(auto) CallUpcastedActivation(const MatPtr* base, const Func& func, + Args&&... 
args) { + HWY_ASSERT(base != nullptr); + if (base->GetType() == Type::kF32) { + return func(dynamic_cast*>(base), + std::forward(args)...); + } else if (base->GetType() == Type::kBF16) { + return func(dynamic_cast*>(base), + std::forward(args)...); + } else { + HWY_ABORT("Unhandled type %s.", TypeName(base->GetType())); } } @@ -362,8 +420,11 @@ class MatStorageT : public MatPtrT { public: MatStorageT(const char* name, Extents2D extents, MatPadding padding) : MatPtrT(name, extents) { - owner_.AllocateFor(*this, padding); + if (extents.Area() != 0) owner_.AllocateFor(*this, padding); } + // Shorthand for 1D tensors: packing does not help, hence `kPacked`. + MatStorageT(const char* name, size_t cols) + : MatStorageT(name, Extents2D(1, cols), MatPadding::kPacked) {} ~MatStorageT() = default; // Allow move for backprop/activations. @@ -467,81 +528,14 @@ using RowPtrBF = RowPtr; using RowPtrF = RowPtr; using RowPtrD = RowPtr; -// Owns dynamically-allocated aligned memory for a batch of row vectors. -// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns -// the memory. Unlike `MatPtr`, this lacks metadata. -// TODO: replace with `MatStorageT`. +// TODO: remove allocator arg once kCyclic is removed. template -class RowVectorBatch { - public: - // Default ctor for Activations ctor. - RowVectorBatch() = default; - // Main ctor, called from Activations::Allocate. If `stride` = 0, the default, - // we default to tightly packed rows (`stride = cols`). - // WARNING: not all call sites support `stride` != cols. - // TODO: once they do, remove stride and behave like AllocateAlignedRows here. - RowVectorBatch(const Allocator& allocator, Extents2D extents, - size_t stride = 0) - : extents_(extents) { - if (stride == 0) { - stride_ = extents_.cols; - } else { - HWY_ASSERT(stride >= extents_.cols); - stride_ = stride; - } - // Allow binding the entire matrix. - const size_t padded = hwy::RoundUpTo(extents_.rows * stride_, - allocator.QuantumBytes() / sizeof(T)); - mem_ = allocator.Alloc(padded); - } - - // Move-only - RowVectorBatch(RowVectorBatch&) noexcept = delete; - RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; - RowVectorBatch(RowVectorBatch&&) noexcept = default; - RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; - - size_t BatchSize() const { return extents_.rows; } - size_t Cols() const { return extents_.cols; } - size_t Stride() const { return stride_; } - Extents2D Extents() const { return extents_; } - - // Returns the given row vector of length `Cols()`. - T* Batch(size_t batch_idx) { - HWY_DASSERT(batch_idx < BatchSize()); - return mem_.get() + batch_idx * stride_; - } - const T* Batch(size_t batch_idx) const { - HWY_DASSERT(batch_idx < BatchSize()); - return mem_.get() + batch_idx * stride_; - } - - // For MatMul or other operations that process the entire batch at once. - // TODO: remove once we only use Mat. 
-  T* All() { return mem_.get(); }
-  const T* Const() const { return mem_.get(); }
-  size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
-
- private:
-  AlignedPtr2<T[]> mem_;
-  Extents2D extents_;
-  size_t stride_;
-};
-
-template <typename T>
-RowPtr<T> RowPtrFromBatch(const Allocator& allocator,
-                          RowVectorBatch<T>& row_vectors) {
-  return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
-                   row_vectors.Stride());
-}
-
-template <typename T>
-RowVectorBatch<T> AllocateAlignedRows(const Allocator& allocator,
-                                      Extents2D extents) {
-  return RowVectorBatch<T>(
-      allocator, extents,
-      StrideForCyclicOffsets(extents.cols,
-                             allocator.QuantumBytes() / sizeof(T)));
+RowPtr<T> RowPtrFromMat(const Allocator& allocator,
+                        const MatPtrT<T>& row_vectors) {
+  // RowPtr is non-const for MatMul C, but is also used for A which is const.
+  // Callers are responsible for checking their usage of RowPtr.
+  return RowPtr<T>(allocator, const_cast<T*>(row_vectors.Row(0)),
+                   row_vectors.Cols(), row_vectors.Stride());
 }
 
 }  // namespace gcpp
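
Note on the batched norm helpers (ops-inl.h): RMSNormInplaceBatched now takes the weight row as a MatPtr and normalizes each activation row in place, broadcasting the single weight row. A minimal standalone sketch of that per-row computation, using plain float buffers in place of MatPtrT and illustrative names (not the repo's API), with the same 1.0f + weight centering as ScalarRMSNorm:

// Sketch: per-row RMSNorm over a row-major (rows x cols) buffer, with a single
// weight row broadcast to every activation row.
#include <cmath>
#include <cstddef>

void RmsNormRowsSketch(const float* weights, float* inout, size_t rows,
                       size_t cols) {
  constexpr float kEps = 1e-6f;
  for (size_t r = 0; r < rows; ++r) {
    float* row = inout + r * cols;
    float ss = 0.0f;
    for (size_t c = 0; c < cols; ++c) ss += row[c] * row[c];
    const float mul = 1.0f / std::sqrt(ss / static_cast<float>(cols) + kEps);
    for (size_t c = 0; c < cols; ++c) {
      // Note the 1.0f + weight centering, matching ScalarRMSNorm in the diff.
      row[c] = (1.0f + weights[c]) * (mul * row[c]);
    }
  }
}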
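
Note on AvgPool4x4 (ops-inl.h): the rewrite only swaps RowVectorBatch for MatStorageT; the pooling itself still averages each 4x4 block of the 64x64 grid of input rows into one of 256 output rows. A standalone sketch with std::vector standing in for MatStorageT (illustrative names; assumes in.size() == 4096 * cols):

// Sketch: input is 4096 row vectors laid out as a 64x64 grid; each output row
// is the mean of a 4x4 block of input rows, giving 256 output rows.
#include <cstddef>
#include <vector>

std::vector<float> AvgPool4x4RowsSketch(const std::vector<float>& in,
                                        size_t cols) {
  const size_t in_dim = 64, out_dim = 16;  // 64x64 -> 16x16 grid of rows
  std::vector<float> out(out_dim * out_dim * cols, 0.0f);
  for (size_t oy = 0; oy < out_dim; ++oy) {
    for (size_t ox = 0; ox < out_dim; ++ox) {
      float* out_row = &out[(oy * out_dim + ox) * cols];
      for (size_t dy = 0; dy < 4; ++dy) {
        for (size_t dx = 0; dx < 4; ++dx) {
          const size_t in_idx = (oy * 4 + dy) * in_dim + (ox * 4 + dx);
          const float* in_row = &in[in_idx * cols];
          for (size_t c = 0; c < cols; ++c) out_row[c] += in_row[c];
        }
      }
      for (size_t c = 0; c < cols; ++c) out_row[c] *= 1.0f / 16.0f;
    }
  }
  return out;
}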
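
Note on CreateInvTimescale (ops/ops.h): the returned table is 1 / base^(2d / rope_dim) for d in [0, rope_dim / 2). The sketch below recomputes it and applies one RoPE rotation; the (d, d + half) pairing is an assumption for illustration and may differ from the repo's Rope kernel.

// Sketch: inverse-timescale table plus a scalar RoPE rotation at position pos.
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> InvTimescaleSketch(size_t rope_dim, double base = 10000.0) {
  std::vector<float> inv(rope_dim / 2);
  for (size_t d = 0; d < rope_dim / 2; ++d) {
    const double exponent = static_cast<double>(2 * d) / rope_dim;
    inv[d] = static_cast<float>(1.0 / std::pow(base, exponent));
  }
  return inv;
}

void RopeSketch(float* x, size_t rope_dim, const std::vector<float>& inv,
                int pos) {
  const size_t half = rope_dim / 2;
  for (size_t d = 0; d < half; ++d) {
    const float theta = static_cast<float>(pos) * inv[d];
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[d], x1 = x[d + half];
    x[d] = x0 * c - x1 * s;
    x[d + half] = x0 * s + x1 * c;
  }
}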
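
Note on the LayerNorm test (ops_test.cc): the new ScalarLayerNorm reference computes the moments in double, clamps the variance at zero, and applies out = (x - mean) * scale * rsqrt(var + eps) + bias, matching the flax normalization it cites. A self-contained restatement for float inputs (illustrative name):

// Sketch: scalar LayerNorm reference using E[x^2] - E[x]^2 for the variance.
#include <algorithm>
#include <cmath>
#include <cstddef>

void LayerNormRefSketch(const float* x, const float* scale, const float* bias,
                        float* out, size_t size) {
  constexpr double kEps = 1e-6;
  double sum = 0.0, sum2 = 0.0;
  for (size_t i = 0; i < size; ++i) {
    sum += x[i];
    sum2 += static_cast<double>(x[i]) * x[i];
  }
  const double mu = sum / size;
  const double var = std::max(sum2 / size - mu * mu, 0.0);
  const double rsqrt = 1.0 / std::sqrt(var + kEps);
  for (size_t i = 0; i < size; ++i) {
    out[i] = static_cast<float>((x[i] - mu) * scale[i] * rsqrt + bias[i]);
  }
}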
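
Note on MatPtr::OverrideRows (util/mat.h): Rows() now returns an override when one is set, so activations allocated for the maximum batch size can be narrowed to the current batch without reallocating. A simplified stand-in class sketching that contract (hypothetical names, not the repo's class):

// Sketch: capacity rows are fixed at allocation; the effective row count can
// be narrowed per batch via OverrideRows, never exceeding capacity.
#include <cassert>
#include <cstddef>
#include <vector>

class RowsViewSketch {
 public:
  RowsViewSketch(size_t capacity_rows, size_t cols)
      : capacity_rows_(capacity_rows), cols_(cols),
        data_(capacity_rows * cols) {}

  // Effective rows: the override if set, otherwise the allocated capacity.
  size_t Rows() const {
    return override_rows_ == 0 ? capacity_rows_ : override_rows_;
  }
  size_t Cols() const { return cols_; }

  // Must not exceed capacity; 0 means "no override" in this sketch.
  void OverrideRows(size_t rows) {
    assert(rows != 0 && rows <= capacity_rows_);
    override_rows_ = rows;
  }

  float* Row(size_t r) {
    assert(r < Rows());
    return data_.data() + r * cols_;
  }

 private:
  size_t capacity_rows_;
  size_t cols_;
  size_t override_rows_ = 0;  // 0 = no override, use capacity
  std::vector<float> data_;
};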
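
Note on CallUpcasted / CallUpcastedSame (util/mat.h): the helpers map a runtime type tag back to the concrete MatPtrT and hand it to a generic lambda via dynamic_cast. A minimal sketch of the same dispatch pattern with simplified stand-in types (float/double tags instead of the weight types; not the repo's classes):

// Sketch: tag-based upcast dispatch to a generic lambda.
#include <cstdio>
#include <cstdlib>

enum class Tag { kF32, kF64 };

struct Base {
  explicit Base(Tag t) : tag(t) {}
  virtual ~Base() = default;
  Tag tag;
};

template <typename T>
struct Typed : Base {
  Typed(Tag t, T v) : Base(t), value(v) {}
  T value;
};

template <class Func>
decltype(auto) CallUpcastedSketch(const Base* base, const Func& func) {
  if (base->tag == Tag::kF32) {
    return func(dynamic_cast<const Typed<float>*>(base));
  } else if (base->tag == Tag::kF64) {
    return func(dynamic_cast<const Typed<double>*>(base));
  } else {
    std::abort();  // unhandled tag, mirroring HWY_ABORT in the diff
  }
}

int main() {
  Typed<double> t(Tag::kF64, 3.5);
  CallUpcastedSketch(&t, [](const auto* typed) {
    std::printf("value = %f\n", static_cast<double>(typed->value));
  });
  return 0;
}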