From 2ee1fac74ce1a32aff5f84e108fd3ec5d679181c Mon Sep 17 00:00:00 2001
From: Krzysztof Rymski
Date: Wed, 7 Jan 2026 01:21:02 -0800
Subject: [PATCH 1/7] Internal changes

PiperOrigin-RevId: 853138600
---
 BUILD.bazel              | 2 ++
 gemma/flash_attention.cc | 4 ++++
 util/test_util.h         | 2 +-
 3 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 491939a..130f18f 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -608,7 +608,9 @@ cc_library(
     ],
     deps = [
         ":activations",
+        ":basics",
         ":configs",
+        ":kv_cache",
         ":mat",
         ":matmul",
         ":matmul_env",
diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index 2129a63..ba985fc 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -22,10 +22,14 @@
 #include
 #include
 #include
+#include
 #include

 #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
 #include "gemma/flash_structs.h"
+#include "gemma/kv_cache.h"
+#include "gemma/query.h"
+#include "util/basics.h"
 #include "util/threading_context.h"
 #include "util/zones.h"
 #include "hwy/base.h"
diff --git a/util/test_util.h b/util/test_util.h
index f0c37f9..19342e4 100644
--- a/util/test_util.h
+++ b/util/test_util.h
@@ -115,7 +115,7 @@ template <typename T>
 void PrintMatPtr(MatPtrT<T> mat) {
   for (int i = 0; i < mat.Rows(); ++i) {
     for (int j = 0; j < mat.Cols(); ++j) {
-      std::cerr << mat.Row(i)[j] << " ,";
+      std::cerr << hwy::ConvertScalarTo<float>(mat.Row(i)[j]) << " ,";
     }
     std::cerr << std::endl;
   }

From aeade052c600ae55dcd038b2802439f5947ff964 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Wed, 7 Jan 2026 10:32:44 -0800
Subject: [PATCH 2/7] Move AssertClose to test_util, add U16

PiperOrigin-RevId: 853321311
---
 compression/test_util-inl.h | 124 ++++++++++++++++++++++++++++++++++++
 compression/types.h         |  20 +++++-
 ops/matmul_test.cc          | 120 +---------------------------------
 paligemma/BUILD.bazel       |   1 -
 4 files changed, 143 insertions(+), 122 deletions(-)

diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h
index f2c8b8c..99b34b5 100644
--- a/compression/test_util-inl.h
+++ b/compression/test_util-inl.h
@@ -17,6 +17,10 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_

+#include <stddef.h>
+
+#include <vector>
+
 // IWYU pragma: begin_exports
 #include "compression/distortion.h"
 #include "util/mat.h"
@@ -153,6 +157,126 @@ MatStorageT GenerateTransposedMat(const Extents2D extents,
   return compressed;
 }

+// Returns 1-norm, used for estimating tolerable numerical differences.
+inline double MaxRowAbsSum(const MatStorageT<float>& a) {
+  double max_row_abs_sum = 0.0;
+  for (size_t r = 0; r < a.Rows(); r++) {
+    const float* row = a.Row(r);
+    double row_abs_sum = 0.0;
+    for (size_t c = 0; c < a.Cols(); c++) {
+      row_abs_sum += hwy::ScalarAbs(row[c]);
+    }
+    max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
+  }
+  return max_row_abs_sum;
+}
+
+// Returns the maximum absolute value of `a`.
+inline float MaxAbs(const MatStorageT<float>& a) {
+  float max_abs = 0.0f;
+  for (size_t c = 0; c < a.Cols(); c++) {
+    for (size_t r = 0; r < a.Rows(); r++) {
+      const float* row = a.Row(r);
+      max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
+    }
+  }
+  return max_abs;
+}
+
+// B is already transposed.
+template <typename TA, typename TB, typename TC>
+void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
+                 const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C,
+                 const Allocator& allocator,
+                 std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs,
+                 int line) {
+  const hn::ScalableTag<float> df;
+  const size_t cols = A.Cols();
+  const size_t B_rows = B.Rows();
+  // Round up for DecompressAndZeroPad.
+  MatStorageT<float> a_batch("a_batch", A.Extents(), allocator,
+                             MatPadding::kOdd);
+  MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(), allocator,
+                                   MatPadding::kOdd);
+  MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows), allocator,
+                             MatPadding::kOdd);
+  c_batch.AllocateAndAttachRowPtrs(row_ptrs);
+  MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
+                                  allocator, MatPadding::kOdd);
+  for (size_t m = 0; m < A.Rows(); ++m) {
+    DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
+    DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
+                         B_rows);
+    DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
+                         c_slow_batch.Row(m), B_rows);
+  }
+  for (size_t n = 0; n < B_rows; ++n) {
+    DecompressAndZeroPad(df, MakeSpan(B.Row(n), cols), 0, b_trans_batch.Row(n),
+                         cols);
+  }
+
+  // MatMul rounds inputs to BF16, so error is proportional to the max input
+  // magnitude, but also to f32 accumulation of rows in A and B.
+  const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
+  const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
+  const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
+  const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
+  // Dot() uses double-precision summation.
+  double tolerance = 20 * norm * eps_f32;
+  // If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the
+  // F32 to BF16, so add extra tolerance.
+  if (IsF32<TA>() || IsF32<TB>()) {
+    tolerance += 2 * max_abs * eps_bf16;
+  }
+
+  if (tolerance > 500.0) {
+    HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
+  }
+  const double rel_tolerance =
+      1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
+
+  double max_rel = 0.0;
+  size_t worst_r = 0;
+  size_t worst_c = 0;
+  double worst_actual = 0.0;
+  double worst_expected = 0.0;
+  size_t num_outside = 0;
+  for (size_t r = 0; r < A.Rows(); r++) {
+    const float* expected_row = c_slow_batch.Row(r);
+    const float* actual_row = c_batch.Row(r);
+    for (size_t c = 0; c < B.Rows(); c++) {
+      const double expected_value = static_cast<double>(expected_row[c]);
+      const double actual_value = static_cast<double>(actual_row[c]);
+      const bool in_range = expected_value - tolerance <= actual_value &&
+                            actual_value <= expected_value + tolerance;
+
+      if (!in_range) {
+        const double max = HWY_MAX(expected_value, actual_value);
+        const double min = HWY_MIN(expected_value, actual_value);
+        const double rel = max / HWY_MAX(min, 1E-6);
+        if (rel > max_rel) {
+          worst_expected = expected_value;
+          worst_actual = actual_value;
+          worst_r = r;
+          worst_c = c;
+          max_rel = rel;
+          ++num_outside;
+        }
+      }
+    }
+  }
+
+  if (max_rel > rel_tolerance) {
+    hwy::Abort(__FILE__, line,
+               "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
+               "tolerance %f rel %E max_rel %E num_outside %zu\n",
+               worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
+               tolerance, max_rel, rel_tolerance, num_outside);
+  }
+  HWY_ASSERT(hn::AllFalse(
+      df, hn::IsEitherNaN(hn::Set(df, norm), hn::Set(df, max_abs))));
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
diff --git a/compression/types.h b/compression/types.h
index 8f11591..6e6129d 100644
--- a/compression/types.h
+++ b/compression/types.h
@@ -218,12 +218,23 @@ constexpr bool SupportsPointerArithmetic() {
 }

 // Tensor types for loading weights. Not all of these are supported weight
-// types, some are only used for `Activations`.
-enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kU32, kU64, kI8 };
+// types, some are only used for `Activations`. Append-only.
+enum class Type {
+  kUnknown,
+  kF32,
+  kBF16,
+  kSFP,
+  kNUQ,
+  kF64,
+  kU32,
+  kU64,
+  kI8,
+  kU16
+};
 // These are used in `ModelConfig.Specifier`, hence the strings will not
 // change, though new ones may be added.
 static constexpr const char* kTypeStrings[] = {
-    "unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8"};
+    "unknown", "f32", "bf16", "sfp", "nuq", "f64", "u32", "u64", "i8", "u16"};
 static constexpr size_t kNumTypes =
     sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
 static constexpr size_t kTypeBits[] = {
@@ -236,6 +247,7 @@ static constexpr size_t kTypeBits[] = {
     8 * sizeof(uint32_t),
     8 * sizeof(uint64_t),
     8 * sizeof(I8Stream),
+    8 * sizeof(uint16_t),
 };

 static inline bool EnumValid(Type type) {
@@ -262,6 +274,8 @@ Type TypeEnum() {
     return Type::kU64;
   } else if constexpr (hwy::IsSame<Packed, I8Stream>()) {
     return Type::kI8;
+  } else if constexpr (hwy::IsSame<Packed, uint16_t>()) {
+    return Type::kU16;
   } else {
     HWY_DASSERT(false);
     return Type::kUnknown;
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index 4787122..a7a9862 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -58,122 +58,6 @@ extern int64_t first_target;
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;

-// Returns 1-norm, used for estimating tolerable numerical differences.
-double MaxRowAbsSum(const MatStorageT<float>& a) {
-  double max_row_abs_sum = 0.0;
-  for (size_t r = 0; r < a.Rows(); r++) {
-    const float* row = a.Row(r);
-    double row_abs_sum = 0.0;
-    for (size_t c = 0; c < a.Cols(); c++) {
-      row_abs_sum += hwy::ScalarAbs(row[c]);
-    }
-    max_row_abs_sum = HWY_MAX(max_row_abs_sum, row_abs_sum);
-  }
-  return max_row_abs_sum;
-}
-
-// Returns the maximum absolute value of `a`.
-float MaxAbs(const MatStorageT<float>& a) {
-  float max_abs = 0.0f;
-  for (size_t c = 0; c < a.Cols(); c++) {
-    for (size_t r = 0; r < a.Rows(); r++) {
-      const float* row = a.Row(r);
-      max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
-    }
-  }
-  return max_abs;
-}
-
-// B is already transposed.
-template <typename TA, typename TB, typename TC>
-void AssertClose(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
-                 const MatPtrT<TC>& C_slow, const MatPtrT<TC>& C,
-                 MatMulEnv& env, int line) {
-  const hn::ScalableTag<float> df;
-  const size_t cols = A.Cols();
-  const size_t B_rows = B.Rows();
-  // Round up for DecompressAndZeroPad.
-  MatStorageT<float> a_batch("a_batch", A.Extents(), env.ctx.allocator,
-                             MatPadding::kOdd);
-  MatStorageT<float> b_trans_batch("b_trans_batch", B.Extents(),
-                                   env.ctx.allocator, MatPadding::kOdd);
-  MatStorageT<float> c_batch("c_batch", Extents2D(A.Rows(), B_rows),
-                             env.ctx.allocator, MatPadding::kOdd);
-  c_batch.AllocateAndAttachRowPtrs(env.row_ptrs);
-  MatStorageT<float> c_slow_batch("c_slow_batch", Extents2D(A.Rows(), B_rows),
-                                  env.ctx.allocator, MatPadding::kOdd);
-  for (size_t m = 0; m < A.Rows(); ++m) {
-    DecompressAndZeroPad(df, MakeSpan(A.Row(m), cols), 0, a_batch.Row(m), cols);
-    DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Row(m),
-                         B_rows);
-    DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
-                         c_slow_batch.Row(m), B_rows);
-  }
-  for (size_t n = 0; n < B_rows; ++n) {
-    DecompressAndZeroPad(df, MakeSpan(B.Row(n), cols), 0, b_trans_batch.Row(n),
-                         cols);
-  }
-
-  // MatMul rounds inputs to BF16, so error is proportional to the max input
-  // magnitude, but also to f32 accumulation of rows in A and B.
-  const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
-  const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
-  const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
-  const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
-  // Dot() uses double-precision summation.
-  double tolerance = 20 * norm * eps_f32;
-  // If either is F32, Dot() promotes F32 or even F64, but MatMul demotes the
-  // F32 to BF16, so add extra tolerance.
-  if (IsF32<TA>() || IsF32<TB>()) {
-    tolerance += 2 * max_abs * eps_bf16;
-  }
-
-  if (tolerance > 500.0) {
-    HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
-  }
-  const double rel_tolerance =
-      1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
-
-  double max_rel = 0.0;
-  size_t worst_r = 0;
-  size_t worst_c = 0;
-  double worst_actual = 0.0;
-  double worst_expected = 0.0;
-  size_t num_outside = 0;
-  for (size_t r = 0; r < A.Rows(); r++) {
-    const float* expected_row = c_slow_batch.Row(r);
-    const float* actual_row = c_batch.Row(r);
-    for (size_t c = 0; c < B.Rows(); c++) {
-      const double expected_value = static_cast<double>(expected_row[c]);
-      const double actual_value = static_cast<double>(actual_row[c]);
-      const bool in_range = expected_value - tolerance <= actual_value &&
-                            actual_value <= expected_value + tolerance;
-
-      if (!in_range) {
-        const double max = HWY_MAX(expected_value, actual_value);
-        const double min = HWY_MIN(expected_value, actual_value);
-        const double rel = max / HWY_MAX(min, 1E-6);
-        if (rel > max_rel) {
-          worst_expected = expected_value;
-          worst_actual = actual_value;
-          worst_r = r;
-          worst_c = c;
-          max_rel = rel;
-          ++num_outside;
-        }
-      }
-    }
-  }
-
-  if (max_rel > rel_tolerance) {
-    hwy::Abort(__FILE__, line,
-               "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
-               "tolerance %f rel %E max_rel %E num_outside %zu\n",
-               worst_r, worst_c, worst_expected, worst_actual, norm, max_abs,
-               tolerance, max_rel, rel_tolerance, num_outside);
-  }
-}
-
 // B is already transposed.
 template <typename TA, typename TB, typename TC>
 HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
@@ -257,7 +141,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
   MMOptions options;
   for (size_t rep = 0; rep < 16; ++rep) {
     MMPerKey* per_key = MatMulStatic(A, BT, add_row, env, C, options);
-    AssertClose(A, BT, C_slow, C, env, line);
+    AssertClose(A, BT, C_slow, C, env.ctx.allocator, env.row_ptrs, line);

     // Check before TwoMatMulStatic(), which can invalidate per_key.
     const bool autotune_done = !!per_key->autotune.Best();
@@ -295,7 +179,7 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,

   // TwoMatMulStatic() does not support adding a bias vector.
   if (!add) {
-    AssertClose(A, BT, C, C2, env, line);
+    AssertClose(A, BT, C, C2, env.ctx.allocator, env.row_ptrs, line);
   }
 }

diff --git a/paligemma/BUILD.bazel b/paligemma/BUILD.bazel
index cc6c6e1..b749e05 100644
--- a/paligemma/BUILD.bazel
+++ b/paligemma/BUILD.bazel
@@ -65,7 +65,6 @@ cc_test(
         "//:benchmark_helper",
        "//:configs",
         "//:gemma_lib",
-        "//io",
         "@highway//:hwy_test_util",
     ],
 )

From 384c3901818ccbd0877a62d7b1bc4922a70aa668 Mon Sep 17 00:00:00 2001
From: Balazs Racz
Date: Thu, 8 Jan 2026 04:28:32 -0800
Subject: [PATCH 3/7] Allow overriding hardcoded max_seq_len by cmdline
 argument seq_len.

Adds a SetMaxSeqLen method to ModelConfig to handle updating both
max_seq_len and global attention window sizes. The Gemma constructor now
checks if the provided inference seq_len exceeds the model's max_seq_len
and, if so, emits a warning and updates the config.
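To make the effect concrete, here is a hedged sketch of SetMaxSeqLen's
semantics. The window sizes are hypothetical values, not from any shipped
config, and attention_window_sizes is assumed to be a std::vector-like
public field, as the loop in this patch suggests:

    gcpp::ModelConfig config;
    config.max_seq_len = 4096;
    config.attention_window_sizes = {512, 4096, 512, 4096};
    config.SetMaxSeqLen(8192);
    // attention_window_sizes == {512, 8192, 512, 8192}, max_seq_len == 8192.
    // Only global layers (window == old maximum) grow; sliding-window
    // layers keep their 512 window.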
This prevents clipping context to the hard-coded maximum.

PiperOrigin-RevId: 853676074
---
 gemma/configs.h      | 14 +++++++++++++-
 gemma/gemma.cc       | 17 +++++++++++++----
 gemma/model_store.cc |  2 +-
 gemma/model_store.h  |  7 +++++++
 4 files changed, 34 insertions(+), 6 deletions(-)

diff --git a/gemma/configs.h b/gemma/configs.h
index b727480..f1bd0c5 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -428,6 +428,18 @@ struct ModelConfig : public IFields {
   // The third ctor also expects a string returned by this.
   std::string Specifier() const;

+  // Overwrites `max_seq_len` with `new_max_seq_len` and updates all global
+  // layers' attention window sizes to `new_max_seq_len`. This function must be
+  // called before instantiating the KVCache object.
+  void SetMaxSeqLen(size_t new_max_seq_len) {
+    for (size_t i = 0; i < attention_window_sizes.size(); ++i) {
+      if (attention_window_sizes[i] == max_seq_len) {
+        attention_window_sizes[i] = new_max_seq_len;
+      }
+    }
+    max_seq_len = new_max_seq_len;
+  }
+
   void AddLayerConfig(const LayerConfig& layer_config) {
     layer_configs.push_back(layer_config);
     HWY_ASSERT(layer_configs.size() <= num_layers);
@@ -516,7 +528,7 @@ ModelConfig GetVitConfig(const ModelConfig& config);

 enum DeducedLayerTypes {
   kDeducedViT = 2,
-  kDeduced448 = 4, // For ViT, 448x448 resolution instead of 224x224.
+  kDeduced448 = 4,  // For ViT, 448x448 resolution instead of 224x224.
   kDeducedKqNorm = 8,
 };

diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 0ce6ab3..5a48d00 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -18,6 +18,9 @@

 #include "gemma/gemma.h"

+#include
+#include
+#include
 #include

 #include "compression/types.h"  // GEMMA_DISABLED_TARGETS
@@ -556,10 +559,10 @@ static size_t PrefillTBatchOrQBatch(const ModelConfig& config,
 }

 static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
-                                          const RuntimeConfig& runtime_config,
-                                          QBatch& qbatch,
-                                          hwy::BitSet4096<>& non_eos,
-                                          size_t qi) {
+                                           const RuntimeConfig& runtime_config,
+                                           QBatch& qbatch,
+                                           hwy::BitSet4096<>& non_eos,
+                                           size_t qi) {
   const size_t last_pos_in_prompt = qbatch.Pos(qi) - qbatch.InitialPos(qi);

   const size_t pos = qbatch.Pos(qi);  // during prefill, pos is still correct.
@@ -745,6 +748,12 @@ Gemma::Gemma(const GemmaArgs& args, ThreadingContext& ctx)
       chat_template_(model_.Tokenizer(), model_.Config().model),
       inference_(args.inference),
       aes_ctr_engine_(args.inference.deterministic) {
+  if (args.inference.seq_len > model_.Config().max_seq_len) {
+    HWY_WARN(
+        "Overriding model's max_seq_len=%u with user provided seq_len=%zu.",
+        model_.Config().max_seq_len, args.inference.seq_len);
+    model_.MutableConfig().SetMaxSeqLen(args.inference.seq_len);
+  }
   // Negligible CPU time in the ctor body (except ReadFromBlobs).
   weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, args.loader,
                                              args.inference, mat_owners_, ctx);
diff --git a/gemma/model_store.cc b/gemma/model_store.cc
index 204dee9..76f0c75 100644
--- a/gemma/model_store.cc
+++ b/gemma/model_store.cc
@@ -387,7 +387,7 @@ void ModelStore::CreateMatPtrs(BlobReader& reader) {
 }

 ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
-                      Tristate wrapping)
+                       Tristate wrapping)
     : config_(ReadOrDeduceConfig(reader, wrapping)),
       tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
   if (!ReadMatPtrs(reader)) {  // Pre-2025 format.
diff --git a/gemma/model_store.h b/gemma/model_store.h
index b4d63ad..506fb77 100644
--- a/gemma/model_store.h
+++ b/gemma/model_store.h
@@ -39,6 +39,8 @@

 namespace gcpp {

+class Gemma;
+
 // Reads and holds the model config, tokenizer and all `MatPtr`: everything
 // except the tensor data, which are read/written by `weights.cc`.
 //
@@ -60,6 +62,11 @@ class ModelStore {
     return config_;
   }

+  ModelConfig& MutableConfig() {
+    HWY_ASSERT(config_.model != Model::UNKNOWN);
+    return config_;
+  }
+
   const GemmaTokenizer& Tokenizer() const { return tokenizer_; }

   // Returns nullptr if `name` is not available for loading, otherwise the

From 42e9cf557d645ba49d7e8104222786a7b91e9108 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Thu, 8 Jan 2026 05:25:54 -0800
Subject: [PATCH 4/7] Internal change / remove unused PrintSpeed

PiperOrigin-RevId: 853694463
---
 ops/bench_matmul.cc | 4 ++--
 ops/matmul_test.cc  | 8 --------
 util/zones.h        | 2 +-
 3 files changed, 3 insertions(+), 11 deletions(-)

diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc
index 221405c..67c702f 100644
--- a/ops/bench_matmul.cc
+++ b/ops/bench_matmul.cc
@@ -53,7 +53,7 @@ extern int64_t first_target;
 namespace HWY_NAMESPACE {

 void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
-                std::vector<double>& times, MMPerKey* per_key) {
+                std::vector<double>& times) {
   std::sort(times.begin(), times.end());
   // bench_dnn reports the best and average, but the median seems more
   // consistent and resistant to outliers.
@@ -134,7 +134,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   }
   hwy::PreventElision(keep);
   env.ctx.pools.MaybeStopSpinning(use_spinning);
-  PrintSpeed(A_extents, B_extents, times, per_key);
+  PrintSpeed(A_extents, B_extents, times);
 }

 using F32 = float;
diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc
index a7a9862..5ae9e36 100644
--- a/ops/matmul_test.cc
+++ b/ops/matmul_test.cc
@@ -95,14 +95,6 @@ HWY_INLINE void MatMulSlow(const MatPtrT<TA> A, const MatPtrT<TB> B,
   });
 }

-void PrintSpeed(const char* algo, const Extents2D& A_extents,
-                const Extents2D& B_extents, double elapsed) {
-  const size_t num_b = B_extents.Area();
-  // 2x because of FMA.
-  fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo,
-          elapsed, 2 * 1E-9 * A_extents.rows * num_b / elapsed);
-}
-
 template <typename TA, typename TB = TA, typename TC = float>
 void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                 MatMulEnv& env, int line) {
diff --git a/util/zones.h b/util/zones.h
index f324086..ac96ad0 100644
--- a/util/zones.h
+++ b/util/zones.h
@@ -10,7 +10,7 @@
 namespace gcpp {

 // Zones for the profiler.
-enum class Zones { // Keep sorted
+enum class Zones {  // Keep sorted
   kFlashAttentionFlashAttention,
   kFlashAttentionInclusive,
   kFlashAttentionRmsNormAndPositionalEncoding,

From 95592a574e2ee38d5436039030228c9d8d2b31d6 Mon Sep 17 00:00:00 2001
From: "The gemma.cpp Authors"
Date: Thu, 8 Jan 2026 13:29:15 -0800
Subject: [PATCH 5/7] Build fix for Arm SVE (explicit namespace qualification)

PiperOrigin-RevId: 853864585
---
 compression/int-inl.h | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/compression/int-inl.h b/compression/int-inl.h
index 81f1caf..7500030 100644
--- a/compression/int-inl.h
+++ b/compression/int-inl.h
@@ -293,11 +293,11 @@ class IntCodec {
     for (; i + 2 * N <= g_num; i += 2 * N) {
       const VI8 val0 = hn::DemoteTo(
           di8,
-          hn::DemoteTo(di16, NearestInt(hn::MulAdd(
+          hn::DemoteTo(di16, hn::NearestInt(hn::MulAdd(
                                  mul, hn::LoadU(df, raw + i + 0 * N), add))));
       const VI8 val1 = hn::DemoteTo(
           di8,
-          hn::DemoteTo(di16, NearestInt(hn::MulAdd(
+          hn::DemoteTo(di16, hn::NearestInt(hn::MulAdd(
                                  mul, hn::LoadU(df, raw + i + 1 * N), add))));

       hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i + 0 * N);
@@ -311,7 +311,7 @@ class IntCodec {

     if (remaining > N) {
       const VI8 val0 = hn::DemoteTo(
-          di8, hn::DemoteTo(di16, NearestInt(hn::MulAdd(
+          di8, hn::DemoteTo(di16, hn::NearestInt(hn::MulAdd(
                                  mul, hn::LoadU(df, raw + i), add))));
       hn::StoreU(val0, di8, &packed.ptr->i + current_offset + i);

       const VI8 val1 = hn::DemoteTo(
           di8, hn::DemoteTo(di16,
-                            NearestInt(hn::MulAdd(
+                            hn::NearestInt(hn::MulAdd(
                                 mul, hn::LoadN(df, raw + i + N, remaining1), add))));
       hn::StoreN(val1, di8, &packed.ptr->i + current_offset + i + N,
                  remaining1);
     } else {  // remaining <= N
       const VI8 val0 = hn::DemoteTo(
           di8, hn::DemoteTo(di16,
-                            NearestInt(hn::MulAdd(
+                            hn::NearestInt(hn::MulAdd(
                                 mul, hn::LoadN(df, raw + i, remaining), add))));
       hn::StoreN(val0, di8, &packed.ptr->i + current_offset + i, remaining);
     }

From 6d43d6ee192f6f5335c9a35bb73eff672a3f604b Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Fri, 9 Jan 2026 02:55:28 -0800
Subject: [PATCH 6/7] Build fix for Arm SVE (invalid template arg on op)

PiperOrigin-RevId: 854110884
---
 gemma/flash_attention.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc
index ba985fc..e018ab8 100644
--- a/gemma/flash_attention.cc
+++ b/gemma/flash_attention.cc
@@ -452,7 +452,7 @@ static VF4 HWY_INLINE Reduce4(DF df, VF x_0, VF x_1, VF x_2, VF x_3,
   constexpr size_t kMaxLanes = hn::MaxLanes(df);
   HWY_LANES_CONSTEXPR size_t kLanes = hn::Lanes(df);
   HWY_ALIGN T x_transposed[4 * kMaxLanes];
-  hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed);
+  hn::StoreInterleaved4(x_0, x_1, x_2, x_3, df, x_transposed);
   VF4 result = hn::Load(df4, x_transposed);
   for (int i = 1; i < kLanes; ++i) {
     result = reducer(result, hn::Load(df4, x_transposed + i * 4));

From 16a7ba2d6e517d7747e9daa98501e2df38788cf6 Mon Sep 17 00:00:00 2001
From: Krzysztof Rymski
Date: Fri, 9 Jan 2026 06:35:05 -0800
Subject: [PATCH 7/7] Internal changes

PiperOrigin-RevId: 854171429
---
 BUILD.bazel            |  2 ++
 gemma/activations.h    |  1 +
 gemma/gemma_args.h     |  2 ++
 gemma/kv_cache.cc      |  4 ++++
 gemma/kv_cache.h       |  2 ++
 gemma/kv_cache_test.cc |  9 +++++++--
 util/mat.h             | 32 ++++++++++++++++++++++++++++++++
 7 files changed, 50 insertions(+), 2 deletions(-)

diff --git a/BUILD.bazel b/BUILD.bazel
index 130f18f..9eb60a4 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -523,6 +523,7 @@ cc_library(
         ":configs",
":gemma_args", ":mat", + "//compression:types", "@highway//:hwy", ], ) @@ -575,6 +576,7 @@ cc_library( ":configs", ":mat", ":threading_context", + "//compression:types", "//io", "@highway//:hwy", "@highway//:profiler", diff --git a/gemma/activations.h b/gemma/activations.h index 11e2b1c..c1b943e 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -35,6 +35,7 @@ namespace gcpp { typedef std::vector> AlignedFloatVector; +typedef std::vector> AlignedBF16Vector; // Returns the scale value to use for the query in the attention computation. // Also called by ops_test. diff --git a/gemma/gemma_args.h b/gemma/gemma_args.h index ba72db6..6ccb5b3 100644 --- a/gemma/gemma_args.h +++ b/gemma/gemma_args.h @@ -22,8 +22,10 @@ #include #include +#include #include +#include "compression/types.h" #include "gemma/configs.h" #include "io/io.h" // Path #include "util/args.h" // IWYU pragma: export diff --git a/gemma/kv_cache.cc b/gemma/kv_cache.cc index 2fe6885..49276f8 100644 --- a/gemma/kv_cache.cc +++ b/gemma/kv_cache.cc @@ -16,8 +16,12 @@ #include "gemma/kv_cache.h" #include + +#include +#include #include +#include "compression/types.h" #include "gemma/configs.h" #include "gemma/gemma_args.h" #include "util/mat.h" // ZeroInit diff --git a/gemma/kv_cache.h b/gemma/kv_cache.h index bad66fa..fe6a1ff 100644 --- a/gemma/kv_cache.h +++ b/gemma/kv_cache.h @@ -19,12 +19,14 @@ #include #include +#include #include #include "gemma/configs.h" // ModelConfig #include "gemma/gemma_args.h" // InferenceArgs #include "util/basics.h" // BF16 #include "util/mat.h" +#include "hwy/base.h" namespace gcpp { diff --git a/gemma/kv_cache_test.cc b/gemma/kv_cache_test.cc index 157b3d9..7b7bed2 100644 --- a/gemma/kv_cache_test.cc +++ b/gemma/kv_cache_test.cc @@ -35,8 +35,13 @@ TEST(KVCacheTest, KVCacheToPtrs) { std::vector ptrs = ToKVCachePtrs({caches.data(), caches.size()}); ASSERT_EQ(ptrs.size(), 2); - EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0)); - EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0)); + if (caches[0].IsTiled()) { + EXPECT_EQ(ptrs[0].cache, &caches[0]); + EXPECT_EQ(ptrs[1].cache, &caches[1]); + } else { + EXPECT_EQ(ptrs[0].kv_cache.Row(0), caches[0].kv_cache.Row(0)); + EXPECT_EQ(ptrs[1].kv_cache.Row(0), caches[1].kv_cache.Row(0)); + } } } // namespace diff --git a/util/mat.h b/util/mat.h index 83d03b1..0830046 100644 --- a/util/mat.h +++ b/util/mat.h @@ -469,6 +469,38 @@ decltype(auto) CallUpcastedKV(const MatPtr* base, const Func& func, } } +// Calls 'func' with a span of MatPtrT for all elements in `base`. +// T is dynamic type, read from base. It is assumed that all elements in `base` +// have the same type. +template +decltype(auto) CallUpcastedKVs(hwy::Span base, const Func& func, + Args&&... 
+  Type type = base[0].GetType();
+  for ([[maybe_unused]] auto&& mat : base) {
+    HWY_DASSERT(mat.GetType() == type);
+  }
+  auto convert_to_matptr_t = [&base]<typename T>() {
+    std::vector<MatPtrT<T>> matptrs;
+    matptrs.reserve(base.size());
+    for (auto&& mat : base) {
+      matptrs.emplace_back(mat);
+    }
+    return matptrs;
+  };
+  if (type == Type::kF32) {
+    auto matptrs = convert_to_matptr_t.template operator()<float>();
+    hwy::Span<MatPtrT<float>> matptrs_span(matptrs.data(),
+                                           matptrs.size());
+    return func(matptrs_span, std::forward<Args>(args)...);
+  } else if (type == Type::kBF16) {
+    auto matptrs = convert_to_matptr_t.template operator()<BF16>();
+    hwy::Span<MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
+    return func(matptrs_span, std::forward<Args>(args)...);
+  } else {
+    HWY_ABORT("Unhandled type %s.", TypeName(type));
+  }
+}
+
 void CopyMat(const MatPtr& from, MatPtr& to);

 void ZeroInit(MatPtr& mat);
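A usage sketch for the new CallUpcastedKVs helper added by patch 7. This is
not part of the patches above: TotalRows and its callback are hypothetical,
and only the gcpp declarations from util/mat.h are assumed. The point is that
one generic lambda handles both dynamic element types without the caller
writing the F32/BF16 dispatch itself:

    #include "util/mat.h"

    namespace gcpp {

    // Hypothetical caller: sums the row counts of a homogeneous-typed span
    // of type-erased MatPtr. CallUpcastedKVs reads the dynamic Type from the
    // first element and invokes the lambda with either
    // hwy::Span<MatPtrT<float>> or hwy::Span<MatPtrT<BF16>>.
    size_t TotalRows(hwy::Span<const MatPtr> mats) {
      return CallUpcastedKVs(mats, [](auto span) -> size_t {
        size_t rows = 0;
        for (const auto& mat : span) rows += mat.Rows();
        return rows;
      });
    }

    }  // namespace gcpp

Any further arguments passed to CallUpcastedKVs are forwarded to the callback
after the span; unsupported element types abort, matching the behavior of the
existing CallUpcastedKV above it.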