diff --git a/BUILD.bazel b/BUILD.bazel index ab04fd3..84de4c4 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -112,6 +112,7 @@ cc_library( ":threading", ":topology", ":zones", + "//io", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:profiler", @@ -555,6 +556,7 @@ cc_library( "gemma/attention.cc", "gemma/flash_attention.cc", "gemma/gemma.cc", + "gemma/tensor_stats.cc", "gemma/vit.cc", ], hdrs = [ @@ -563,6 +565,7 @@ cc_library( "gemma/flash_attention.h", "gemma/flash_structs.h", "gemma/gemma.h", + "gemma/tensor_stats.h", "gemma/vit.h", ], exec_properties = { @@ -597,6 +600,7 @@ cc_library( "@highway//:hwy", "@highway//:nanobenchmark", # timer "@highway//:profiler", + "@highway//:stats", "@highway//:thread_pool", "@highway//hwy/contrib/sort:vqsort", ] + diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index 0fb43d7..7d042da 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -1,4 +1,4 @@ -# Weight compression and analysis. +# Compressed tensor types. load("@rules_cc//cc:cc_library.bzl", "cc_library") load("@rules_cc//cc:cc_test.bzl", "cc_test") @@ -208,19 +208,3 @@ cc_test( "@highway//:hwy_test_util", ], ) - -# For internal experimentation -cc_library( - name = "analyze", - textual_hdrs = ["analyze.h"], - deps = [ - ":int", - ":nuq", - ":sfp", - ":types", - "@highway//:hwy", - "@highway//:stats", - "@highway//:thread_pool", - "@highway//hwy/contrib/sort:vqsort", - ], -) diff --git a/compression/analyze.h b/compression/analyze.h deleted file mode 100644 index 7d41633..0000000 --- a/compression/analyze.h +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Normal include guard to placate lint. -#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ -#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ - -#include -#include -#include -#include // memcpy - -#include // std::signbit -#include // std::abs -#include - -#include "compression/types.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" -#include "hwy/stats.h" - -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ - -// Actual per-target include guard. 
-#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE) -#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE -#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE -#else -#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE -#endif - -#include "compression/nuq-inl.h" -#include "compression/sfp-inl.h" -#include "hwy/contrib/sort/vqsort-inl.h" -#include "hwy/highway.h" - -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -class PerThread { - public: - void NotifyGroup(const float* group) { - constexpr size_t kGroupSize = NuqStream::kGroupSize; - hwy::Stats s_group; - for (size_t i = 0; i < kGroupSize; ++i) { - // Skip zero so we can see the lowest actual magnitude - if (group[i] == 0.0f || group[i] == -0.0f) continue; - s_all_.Notify(group[i]); - s_group.Notify(group[i]); - - num_tiny_ += std::abs(group[i]) < 1e-3f; - - // b_magn100_.Notify(group[i] * 40.0f + 20.0f); - const uint32_t binary32 = - hwy::BitCastScalar(std::abs(group[i])); - - // const int32_t exp = (binary32 >> 23) - 127; - b_exp256_.Notify(binary32 >> 23); - const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4); - b_m4_.Notify(m4); - } - s_group_ranges_.Notify(s_group.Max() - s_group.Min()); - s_group_mins_.Notify(s_group.Min()); - s_group_maxs_.Notify(s_group.Max()); - - float desc[kGroupSize]; - memcpy(desc, group, kGroupSize * sizeof(group[0])); - hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending()); - - // Find largest |max/min| (dynamic range) - float max_ratio = 0.0f; - for (size_t i = 0; i < kGroupSize; ++i) { - if (desc[i] != 0.0f && desc[i] != -0.0f) { - max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i])); - } - } - s_group_max_vs_min_.Notify(max_ratio); - - // Relative errors - float diffs[kGroupSize]; - for (size_t i = 0; i < kGroupSize - 1; ++i) { - // was in descending order. Avoid div by 0. Ignore sign changes. - diffs[i] = std::abs(desc[i]) < 1e-5 - ? 
0 - : std::abs((desc[i] - desc[i + 1]) / desc[i]); - } - hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending()); - s_cut15_.Notify(diffs[15]); - } - - void Assimilate(const PerThread& other) { - num_tiny_ += other.num_tiny_; - s_all_.Assimilate(other.s_all_); - s_group_ranges_.Assimilate(other.s_group_ranges_); - s_group_mins_.Assimilate(other.s_group_mins_); - s_group_maxs_.Assimilate(other.s_group_maxs_); - s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_); - s_erange_.Assimilate(other.s_erange_); - s_km_1_.Assimilate(other.s_km_1_); - s_km_2_.Assimilate(other.s_km_2_); - s_cut15_.Assimilate(other.s_cut15_); - b_magn100_.Assimilate(other.b_magn100_); - b_exp256_.Assimilate(other.b_exp256_); - b_m4_.Assimilate(other.b_m4_); - } - - void PrintAll() { - const int skip = hwy::Stats::kNoGeomean; - fprintf(stderr, "num tiny %zu\n", num_tiny_); - fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str()); - fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str()); - fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str()); - fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str()); - fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str()); - fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str()); - fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str()); - fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str()); - fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str()); - - // b_magn100_.Print("magn100"); - // b_exp256_.Print("exp"); - // b_m4_.Print("mantissa bits4"); - - fprintf(stderr, "\n"); - } - - private: - size_t num_tiny_ = 0; - hwy::Stats s_all_; - hwy::Stats s_group_ranges_; - hwy::Stats s_group_mins_; - hwy::Stats s_group_maxs_; - hwy::Stats s_group_max_vs_min_; - hwy::Stats s_erange_; - hwy::Stats s_km_1_; - hwy::Stats s_km_2_; - hwy::Stats s_cut15_; - hwy::Bins<100> b_magn100_; - hwy::Bins<256> b_exp256_; - hwy::Bins<16> b_m4_; - uint8_t padding_[64]; // prevent false sharing -}; - -class PerLayer { - public: - void NotifyGroup(const float* group) { - for (size_t i = 0; i < NuqStream::kGroupSize; ++i) { - s_layer_.Notify(group[i]); - } - } - - void UpdateOutliers(const float* layer, size_t weights_per_layer) { - const float layer_mean = s_layer_.Mean(); - const float layer_sd = s_layer_.StandardDeviation(); - for (size_t i = 0; i < weights_per_layer; ++i) { - num_outliers_ += - std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd; - } - } - - const hwy::Stats& GetStats() const { return s_layer_; } - size_t Outliers() const { return num_outliers_; } - - private: - hwy::Stats s_layer_; - size_t num_outliers_ = 0; - uint8_t padding[64]; // prevent false sharing -}; - -static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers, - size_t weights_per_layer, - hwy::ThreadPool& pool) { - std::vector tls; - std::vector per_layer(layers); - const auto init = [&](size_t num_threads) { - tls.resize(num_threads); - return true; - }; - - pool.Run(0, static_cast(layers), init, - [&](uint32_t idx_layer, size_t idx_thread) { - PerThread& self = tls[idx_thread]; - const float* layer = &mat[idx_layer * weights_per_layer]; - // For each whole group in the layer - for (size_t group_start = 0; - group_start + NuqStream::kGroupSize <= weights_per_layer; - group_start += NuqStream::kGroupSize) { - const float* group = layer + group_start; - per_layer[idx_layer].NotifyGroup(group); - self.NotifyGroup(group); - } - - per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer); - 
}); - - const int skip = hwy::Stats::kNoGeomean; - fprintf(stderr, "\n------------%s\n", caption); - - for (size_t i = 1; i < pool.NumWorkers(); ++i) { - tls[0].Assimilate(tls[i]); - } - tls[0].PrintAll(); - - hwy::Stats s_layer_ranges; - hwy::Stats s_layer_outliers; - for (size_t i = 0; i < layers; ++i) { - fprintf(stderr, " %02zu %s\n", i, - per_layer[i].GetStats().ToString(skip).c_str()); - const float range = - per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min(); - s_layer_ranges.Notify(range); - s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) / - weights_per_layer); - } - fprintf(stderr, "layer outliers%% %s\n", - s_layer_outliers.ToString(skip).c_str()); - fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str()); -} - -// NOLINTNEXTLINE(google-readability-namespace-comments) -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_ diff --git a/gemma/activations.h b/gemma/activations.h index 021b2fd..704c3ee 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -23,12 +23,13 @@ #include #include -#include "gemma/configs.h" // ModelConfig +#include "gemma/configs.h" // ModelConfig #include "gemma/gemma_args.h" // AttentionImpl #include "gemma/kv_cache.h" -#include "ops/ops.h" // CreateInvTimescale -#include "util/basics.h" // BF16 -#include "util/mat.h" // MatStorageT +#include "gemma/tensor_stats.h" +#include "ops/ops.h" // CreateInvTimescale +#include "util/basics.h" // BF16 +#include "util/mat.h" // MatStorageT #include "util/threading_context.h" namespace gcpp { @@ -268,6 +269,14 @@ struct Activations { ffw_out( MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)), + max_workers(ctx.pools.MaxWorkers()), + s_ffw_in(config.num_layers, max_workers), + s_ffw_hidden(config.num_layers, max_workers), + s_ffw_out(config.num_layers, max_workers), + + s_w_gating_einsum_w1(config.num_layers, max_workers), + s_w_gating_einsum_w2(config.num_layers, max_workers), + s_w_linear_w(config.num_layers, max_workers), attention_impl(runtime_config.attention_impl), attention_storage(config, layer_config, batch_size, seq_len, runtime_config.attention_impl, ctx.allocator, @@ -288,6 +297,12 @@ struct Activations { // Note that BindC on any MatMul output considerably slows down Prefill. } + ~Activations() { + s_ffw_in.ReduceAndPrint("ffw_in"); + s_ffw_hidden.ReduceAndPrint("ffw_hidden"); + s_ffw_out.ReduceAndPrint("ffw_out"); + } + // Negligible CPU time. void SetBatchSize(size_t batch_size) { x.OverrideRows(batch_size); @@ -319,6 +334,15 @@ struct Activations { MatStorageT C2; MatStorageT ffw_out; + const size_t max_workers; + TensorStats s_ffw_in; + TensorStats s_ffw_hidden; // after Activation+gating + TensorStats s_ffw_out; + + TensorStats s_w_gating_einsum_w1; + TensorStats s_w_gating_einsum_w2; + TensorStats s_w_linear_w; + AttentionImpl attention_impl; AttentionActivations attention_storage; diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index 2c55ef4..0cd364a 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -20,6 +20,7 @@ #include "gemma/activations.h" #include "gemma/configs.h" +#include "gemma/tensor_stats.h" #include "gemma/weights.h" #include "ops/matmul.h" #include "util/mat.h" @@ -158,6 +159,9 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit. 
+ activations.s_ffw_in.Notify(layer.layer_idx, activations.pre_ffw_rms_out, + env.ctx); + #if GEMMA_FUSED_FFN const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c, StridedViewBF C2, size_t worker) { @@ -179,8 +183,12 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer, env.ctx); #endif + activations.s_ffw_hidden.Notify(layer.layer_idx, activations.C1, env.ctx); + // Hidden layer -> output layer. CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out); + + activations.s_ffw_out.Notify(layer.layer_idx, activations.ffw_out, env.ctx); } // Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 1af520a..055caae 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -21,11 +21,13 @@ #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS -#include "util/zones.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS +#include "gemma/tensor_stats.h" +#include "util/zones.h" + // Compiles this file for multiple architectures via "foreach_target.h", to // which we pass the filename via macro 'argument'. // clang-format off @@ -568,12 +570,26 @@ static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config, config, runtime_config, qbatch, update_pos, non_eos); } +void SetWeightStats(const LayerWeightsPtrs& layer, Activations& a, + ThreadingContext& ctx) { + const size_t layer_idx = layer.layer_idx; + a.s_w_gating_einsum_w1.Notify(layer_idx, layer.gating_einsum_w1, ctx, + kTensorStatsIsWeight); + a.s_w_gating_einsum_w2.Notify(layer_idx, layer.gating_einsum_w2, ctx, + kTensorStatsIsWeight); + a.s_w_linear_w.Notify(layer_idx, layer.linear_w, ctx, kTensorStatsIsWeight); +} + // Decode: generates one continuation token for each query in `qbatch`. static void GenerateT(const ModelConfig& config, const RuntimeConfig& runtime_config, const AesCtrEngine& engine, const WeightsPtrs& weights, Activations& activations, QBatch& qbatch, MatMulEnv& env, TimingInfo& timing_info) { + for (const LayerWeightsPtrs& layer : weights.c_layers) { + SetWeightStats(layer, activations, env.ctx); + } + const size_t max_gen_steps = PrefillTBatchOrQBatch( config, runtime_config, weights, activations, qbatch, env, timing_info); diff --git a/gemma/tensor_info.cc b/gemma/tensor_info.cc index 05f829b..2810307 100644 --- a/gemma/tensor_info.cc +++ b/gemma/tensor_info.cc @@ -1,3 +1,18 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + #include "gemma/tensor_info.h" #include diff --git a/gemma/tensor_info.h b/gemma/tensor_info.h index 6becb29..60decf8 100644 --- a/gemma/tensor_info.h +++ b/gemma/tensor_info.h @@ -1,3 +1,18 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
 #define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
diff --git a/gemma/tensor_stats.cc b/gemma/tensor_stats.cc
new file mode 100644
index 0000000..53203b6
--- /dev/null
+++ b/gemma/tensor_stats.cc
@@ -0,0 +1,205 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "gemma/tensor_stats.h"
+
+#if GCPP_TENSOR_STATS
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include <atomic>
+#include <cmath>
+#include <memory>
+
+#include "io/io.h"
+#include "util/mat.h"
+#include "util/threading_context.h"
+#include "util/zones.h"
+#include "hwy/profiler.h"  // StringTable
+
+// Compiles this file for multiple architectures via "foreach_target.h", to
+// which we pass the filename via macro 'argument'.
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "gemma/tensor_stats.cc"  // NOLINT
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#include "hwy/highway.h"
+// After highway.h
+#include "compression/compress-inl.h"
+#include "ops/dot-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+
+// Returns the lag-1 autocorrelation of x[0, num).
+float Correlation(const float* x, size_t num) {
+  double sum = 0.0;
+  for (size_t i = 0; i < num; ++i) {
+    sum += x[i];
+  }
+  const double mean = sum / static_cast<double>(num);
+
+  double numerator = 0.0;
+  double sum_sq_current = 0.0;
+  double sum_sq_next = 0.0;
+
+  for (size_t i = 0; i < num - 1; ++i) {
+    const double diff_current = static_cast<double>(x[i]) - mean;
+    const double diff_next = static_cast<double>(x[i + 1]) - mean;
+
+    numerator += diff_current * diff_next;
+    sum_sq_current += diff_current * diff_current;
+    sum_sq_next += diff_next * diff_next;
+  }
+
+  if (sum_sq_current == 0.0 || sum_sq_next == 0.0) return 0.0f;
+  const double denominator = std::sqrt(sum_sq_current * sum_sq_next);
+  const float corr = static_cast<float>(numerator / denominator);
+  HWY_DASSERT(-1.0f <= corr && corr <= 1.0f);
+  return corr;
+}
+
+// Only write tensor data the first time it is encountered per layer. This is
+// a concurrent string+layer -> flag map which avoids std::mutex (incompatible
+// with fibers). We use a string table to index into per-layer atomic flags.
+static bool ShouldWrite(const char* name, size_t layer_idx) {
+  constexpr size_t kMaxNames = 128;
+  constexpr size_t kMaxLayers = 128;
+  HWY_DASSERT(layer_idx < kMaxLayers);
+  static hwy::StringTable s_table;
+  const size_t name_idx = s_table.Add(name);
+  static std::atomic_flag flags[kMaxNames * kMaxLayers] = {};
+  return !flags[name_idx * kMaxLayers + layer_idx].test_and_set(
+      std::memory_order_acq_rel);
+}
+
+std::unique_ptr<File> MaybeOpenFile(size_t layer_idx,
+                                    const MatPtr& type_erased,
+                                    const Path& tensor_output) {
+  if (tensor_output.Empty()) return nullptr;
+  if (!ShouldWrite(type_erased.Name(), layer_idx)) return nullptr;
+  char path[1024];
+  snprintf(path, sizeof(path), "%s/%s_L%02zu_%zux%zu_%s.bin",
+           tensor_output.path.c_str(), type_erased.Name(), layer_idx,
+           type_erased.Rows(), type_erased.Cols(),
+           TypeName(type_erased.GetType()));
+  return OpenFileOrAbort(Path(path), "wb");
+}
+
+void MaybeWriteRow(const std::unique_ptr<File>& file,
+                   const MatPtr& type_erased, size_t row_idx) {
+  if (!file) return;
+  const size_t bytes_per_row = type_erased.Cols() * type_erased.ElementBytes();
+  file->Write(type_erased.RowBytes(row_idx), bytes_per_row,
+              bytes_per_row * row_idx);
+}
+
+// First dispatch to the type, then parallel over rows, then vectorized
+// decompress and Notify for each value.
+void UpdateStatsT(TensorStats& stats, size_t layer_idx,
+                  const MatPtr& type_erased, ThreadingContext& ctx, int flags,
+                  size_t cluster_idx, Parallelism parallelism) {
+  std::unique_ptr<File> file =
+      MaybeOpenFile(layer_idx, type_erased, ctx.tensor_output);
+
+  if ((flags & kTensorStatsIsWeight) && layer_idx != 0) {
+    // Still compute stats, but remember not to print them.
+    stats.Get(layer_idx, 0).DoNotPrint();
+  }
+
+  CallUpcasted(&type_erased, [&](const auto* mat) {
+    const size_t cols = mat->Cols();
+
+    ParallelFor(
+        parallelism, mat->Rows(), ctx, cluster_idx, Callers::kTensorStats,
+        [&](size_t row_idx, size_t global_idx) {
+          GCPP_ZONE(ctx, global_idx, Zones::kGenStats);
+
+          auto* HWY_RESTRICT row = mat->Row(row_idx);
+          MaybeWriteRow(file, type_erased, row_idx);
+
+          using Packed = hwy::RemoveCvRef<decltype(*row)>;
+          PackedSpan<Packed> packed(const_cast<Packed*>(row), cols);
+
+          TensorStatsAccumulator& my_stats = stats.Get(layer_idx, global_idx);
+          my_stats.NotifyCond(ConditionNumber(row, cols));
+
+          namespace hn = hwy::HWY_NAMESPACE;
+          hn::ScalableTag<float> df;
+          using VF = hn::Vec<decltype(df)>;
+          HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
+          HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
+
+          size_t packed_ofs = 0;
+          if (cols >= 2 * NF) {
+            for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
+              VF v0, v1;
+              Decompress2(df, packed, packed_ofs, v0, v1);
+              hn::Store(v0, df, buf);
+              hn::Store(v1, df, buf + NF);
+              const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1));
+              const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1));
+              const float min = hn::ReduceMin(df, min_mag);
+              if (min != 0.0f) {  // Avoid division by zero.
+                my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag));
+              }
+
+              for (size_t i = 0; i < 2 * NF; ++i) {
+                my_stats.Notify(buf[i], row_idx, packed_ofs + i);
+              }
+              my_stats.NotifyCorr(Correlation(buf, 2 * NF));
+            }
+          }
+
+          // Zero to two vectors remaining.
+          for (; packed_ofs < cols; packed_ofs += NF) {
+            const size_t remaining = HWY_MIN(NF, cols - packed_ofs);
+            DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining);
+            // Skip NotifyGroup for this partial group.
+ for (size_t i = 0; i < remaining; ++i) { + my_stats.Notify(buf[i], row_idx, packed_ofs + i); + } + my_stats.NotifyCorr(Correlation(buf, remaining)); + } + }); + }); +} + +// NOLINTNEXTLINE(google-readability-namespace-comments) +} // namespace HWY_NAMESPACE +} // namespace gcpp +HWY_AFTER_NAMESPACE(); + +#if HWY_ONCE +namespace gcpp { +HWY_EXPORT(UpdateStatsT); + +// Must reside in .cc file so that we can #include compress-inl.h. +void TensorStats::Notify(size_t layer_idx, const MatPtr& type_erased, + ThreadingContext& ctx, int flags, size_t cluster_idx, + Parallelism parallelism) { + // Ignore empty tensors. + if (type_erased.GetType() == Type::kUnknown || type_erased.Cols() == 0) { + return; + } + HWY_DYNAMIC_DISPATCH(UpdateStatsT)(*this, layer_idx, type_erased, ctx, flags, + cluster_idx, parallelism); +} + +} // namespace gcpp +#endif // HWY_ONCE + +#endif // GCPP_TENSOR_STATS diff --git a/gemma/tensor_stats.h b/gemma/tensor_stats.h new file mode 100644 index 0000000..6975ab5 --- /dev/null +++ b/gemma/tensor_stats.h @@ -0,0 +1,347 @@ +// Copyright 2025 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_ + +#include +#include +#include + +#include "util/basics.h" +#include "hwy/base.h" + +#ifndef GCPP_TENSOR_STATS +#define GCPP_TENSOR_STATS 0 +#endif + +#include "util/mat.h" +#include "util/threading_context.h" + +#if GCPP_TENSOR_STATS +#include +#include + +#include "hwy/stats.h" +#endif // GCPP_TENSOR_STATS + +namespace gcpp { + +// For flags. Used to inhibit printing per-layer stats for weights. +HWY_INLINE_VAR constexpr int kTensorStatsIsWeight = 1; + +#if GCPP_TENSOR_STATS + +HWY_INLINE_VAR constexpr size_t kStatsMaxCols = 8192; + +// Separate summary of the per-layer stats, updated by `TensorStatsAccumulator`. +// We pass per-layer statistics such as the mean value to `hwy::Stats::Notify`` +// to see the distribution of per-layer means. 
+struct TensorStatsAcrossLayers { + bool IsEmpty() const { return s_frobenius.Count() == 0; } + + void Print() { + const int skip = hwy::Stats::kNoGeomean; + fprintf(stderr, "frob %s\n", s_frobenius.ToString(skip).c_str()); + fprintf(stderr, "cnd.min %s\n", s_cond_min.ToString(skip).c_str()); + fprintf(stderr, "cnd.avg %s\n", s_cond_avg.ToString(skip).c_str()); + fprintf(stderr, "cnd.max %s\n", s_cond_max.ToString(skip).c_str()); + fprintf(stderr, "val.min %s\n", s_val_min.ToString(skip).c_str()); + fprintf(stderr, "val.avg %s\n", s_val_avg.ToString(skip).c_str()); + fprintf(stderr, "val.krt %s\n", s_val_kurt.ToString(skip).c_str()); + fprintf(stderr, "mag.min %s\n", s_mag_min.ToString(skip).c_str()); + fprintf(stderr, "mag.avg %s\n", s_mag_avg.ToString(skip).c_str()); + fprintf(stderr, "mag.max %s\n", s_mag_max.ToString(skip).c_str()); + if (hwy::ScalarAbs(s_corr_avg.Max()) > 0.05f) { + fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str()); + } + fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str()); + fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str()); + fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str()); + fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str()); + fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str()); + if (s_exp_subnormal.Min() != 0.0f) { + fprintf(stderr, "exp.sub %s\n", s_exp_subnormal.ToString(skip).c_str()); + } + if (s_big_cols.Count() != 0) { + fprintf(stderr, "bigCols %s\n", s_big_cols.ToString(skip).c_str()); + const size_t modal_col = b_big_cols.ModalBinIdx(); + const size_t num_outlier_cols = b_big_cols.NumNonzero(); + if (num_outlier_cols > 256) { + fprintf(stderr, "bigCols: all up to %zu (max at %zu: %u layers):\n", + b_big_cols.LastNonzero(), modal_col, b_big_cols.Bin(modal_col)); + } else { + fprintf(stderr, "bigCols (max at %zu: %u layers):\n", modal_col, + b_big_cols.Bin(modal_col)); + for (size_t i = 0; i < kStatsMaxCols; ++i) { + if (b_big_cols.Bin(i) > 2) { + fprintf(stderr, " %3zu: %u\n", i, b_big_cols.Bin(i)); + } + } + } + } + fprintf(stderr, "\n"); + } + + hwy::Stats s_frobenius; + + hwy::Stats s_cond_min; + hwy::Stats s_cond_avg; + hwy::Stats s_cond_max; + + hwy::Stats s_val_min; + hwy::Stats s_val_avg; + hwy::Stats s_val_kurt; + + hwy::Stats s_mag_min; + hwy::Stats s_mag_avg; + hwy::Stats s_mag_max; + + hwy::Stats s_corr_avg; + hwy::Stats s_corr_max; + + hwy::Stats s_range_avg; + + hwy::Stats s_exp_min; + hwy::Stats s_exp_max; + hwy::Stats s_exp_mode; + hwy::Stats s_exp_subnormal; + + hwy::Stats s_big_cols; // total number of outlier cols + hwy::Bins b_big_cols; // # layers with outlier per col +}; + +// Per-thread and layer. +class TensorStatsAccumulator { + public: + void Notify(float val, size_t row_idx, size_t col_idx) { + const double dval = static_cast(val); + sum_sq_ += dval * dval; + + s_val_.Notify(val); + const float mag = hwy::ScalarAbs(val); + + if (HWY_UNLIKELY(mag >= 64.0f)) { + if (row_idx < kMaxBatchSize) b_big_row_.Notify(row_idx); + if (col_idx < kStatsMaxCols) b_big_col_.Notify(col_idx); + } + + // Skip zero so we can see the lowest actual magnitude + if (mag != 0.0f && mag != -0.0f) s_mag_.Notify(mag); + + const uint32_t binary32 = hwy::BitCastScalar(mag); + // Use biased exponent because Bins wants unsigned values. 
+ const uint32_t biased_exp = binary32 >> 23; + HWY_DASSERT(biased_exp < 256); // already cleared sign bit + b_exp256_.Notify(biased_exp); + } + + void DoNotPrint() { skip_.fetch_or(1); } + bool ShouldPrint() const { return skip_.load() == 0; } + + // Vector code computed the min/max of a group (= two vectors); this is + // faster than doing it in `Notify`. + void NotifyGroup(float min, float max) { + s_group_min_.Notify(min); + s_group_max_.Notify(max); + // Caller ensures min != 0. + s_group_range_.Notify(max / min); + } + + void NotifyCorr(float corr) { s_corr_.Notify(corr); } + void NotifyCond(double cond) { s_cond_.Notify(cond); } + + void Assimilate(const TensorStatsAccumulator& other) { + skip_.fetch_or(other.skip_.load()); + + sum_sq_ += other.sum_sq_; + b_exp256_.Assimilate(other.b_exp256_); + b_big_row_.Assimilate(other.b_big_row_); + b_big_col_.Assimilate(other.b_big_col_); + s_val_.Assimilate(other.s_val_); + s_mag_.Assimilate(other.s_mag_); + s_corr_.Assimilate(other.s_corr_); + s_group_min_.Assimilate(other.s_group_min_); + s_group_max_.Assimilate(other.s_group_max_); + s_group_range_.Assimilate(other.s_group_range_); + } + + // Called on the per-layer representative after reducing across threads. + void NotifyAcrossLayer(TensorStatsAcrossLayers& s) { + s.s_frobenius.Notify(std::sqrt(sum_sq_)); + + s.s_cond_min.Notify(s_cond_.Min()); + s.s_cond_avg.Notify(s_cond_.Mean()); + s.s_cond_max.Notify(s_cond_.Max()); + + s.s_val_min.Notify(s_val_.Min()); + s.s_val_avg.Notify(s_val_.Mean()); + s.s_val_kurt.Notify(s_val_.Kurtosis()); + + s.s_mag_min.Notify(s_mag_.Min()); + s.s_mag_avg.Notify(s_mag_.Mean()); + s.s_mag_max.Notify(s_mag_.Max()); + + s.s_corr_avg.Notify(s_corr_.Mean()); + s.s_corr_max.Notify(s_corr_.Max()); + + s.s_range_avg.Notify(s_group_range_.Mean()); + + const uint32_t subnormals = b_exp256_.Bin(0); + // Prevent subnormals from hiding the min exponent. + b_exp256_.ResetBin(0); + s.s_exp_min.Notify(b_exp256_.FirstNonzero()); + s.s_exp_max.Notify(b_exp256_.LastNonzero()); + s.s_exp_mode.Notify(b_exp256_.ModalBinIdx()); + s.s_exp_subnormal.Notify(subnormals); + + const uint32_t num_outliers = b_big_col_.NumNonzero(); + if (num_outliers != 0) { + s.s_big_cols.Notify(num_outliers); + // For each col, count the number of layers that have an outlier there. + for (size_t i = 0; i < kStatsMaxCols; ++i) { + if (b_big_col_.Bin(i) != 0) s.b_big_cols.Notify(i); + } + } + } + + bool IsEmpty() const { return s_val_.Count() == 0; } + + void PrintAll() { + fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_)); + const int skip = hwy::Stats::kNoGeomean; + fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str()); + fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str()); + fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str()); + fprintf(stderr, "corr %s\n", s_corr_.ToString(skip).c_str()); + fprintf(stderr, "group_min %s\n", s_group_min_.ToString(skip).c_str()); + fprintf(stderr, "group_max %s\n", s_group_max_.ToString(skip).c_str()); + fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str()); + b_exp256_.Print("exp"); + PrintBinRanges(b_big_row_, "big row"); + PrintBinRanges(b_big_col_, "big col"); + fprintf(stderr, "\n"); + } + + private: + template + void PrintBinRanges(const hwy::Bins& b, const char* name) { + uint64_t total = 0; + for (size_t i = 0; i < N; ++i) { + total += b.Bin(i); + } + if (total == 0) return; + + // If all bins are at least 10% of a uniform distribution, print the range + // to vastly reduce the log size. 
+    const size_t min = HWY_MAX(1, total / (N * 10));
+    size_t last = 0;
+    for (; last < N; ++last) {
+      if (b.Bin(last) < min) break;
+    }
+    if (last >= N / 2) {
+      // Also require all subsequent bins to be zero, otherwise we should
+      // print the outlier bins.
+      bool all_zero = true;
+      for (size_t i = last + 1; i < N; ++i) {
+        if (b.Bin(i) != 0) {
+          all_zero = false;
+          break;
+        }
+      }
+      if (all_zero) {
+        fprintf(stderr, "%s: uniform up to %zu\n", name, last);
+        return;
+      }
+    }
+
+    b.Print(name, /*skip_zero=*/true);
+  }
+
+  double sum_sq_ = 0.0;       // for Frobenius norm
+  hwy::Bins<256> b_exp256_;   // exponent
+  hwy::Bins<kMaxBatchSize> b_big_row_;
+  hwy::Bins<kStatsMaxCols> b_big_col_;
+  hwy::Stats s_val_;
+  hwy::Stats s_mag_;
+  hwy::Stats s_cond_;  // condition number
+  hwy::Stats s_corr_;  // lag-1 autocorrelation
+  hwy::Stats s_group_min_;
+  hwy::Stats s_group_max_;
+  hwy::Stats s_group_range_;
+  std::atomic<uint32_t> skip_{0};
+};
+
+class TensorStats {
+ public:
+  TensorStats(size_t num_layers, size_t max_workers)
+      : num_layers_(num_layers),
+        max_workers_(max_workers),
+        acc_(num_layers * max_workers) {}
+
+  // Parallelized across rows. If `ctx.tensor_output` is not empty, writes
+  // tensor data to disk for offline analysis, once per tensor and layer.
+  void Notify(size_t layer_idx, const MatPtr& type_erased,
+              ThreadingContext& ctx, int flags = 0, size_t cluster_idx = 0,
+              Parallelism parallelism = Parallelism::kFlat);
+
+  // For use by `UpdateStatsT`.
+  TensorStatsAccumulator& Get(size_t layer_idx, size_t global_idx) {
+    const size_t idx = layer_idx * max_workers_ + global_idx;
+    HWY_DASSERT(idx < acc_.size());
+    return acc_[idx];
+  }
+
+  void ReduceAndPrint(const char* prefix) {
+    for (size_t layer_idx = 0; layer_idx < num_layers_; ++layer_idx) {
+      TensorStatsAccumulator& per_layer = Get(layer_idx, 0);
+      for (size_t global_idx = 1; global_idx < max_workers_; ++global_idx) {
+        per_layer.Assimilate(Get(layer_idx, global_idx));
+      }
+      if (per_layer.IsEmpty()) continue;
+      per_layer.NotifyAcrossLayer(across_layers_);
+      if (per_layer.ShouldPrint()) {
+        fprintf(stderr, "-------------------- %s %zu\n", prefix, layer_idx);
+        per_layer.PrintAll();
+      }
+    }
+
+    if (!across_layers_.IsEmpty()) {
+      fprintf(stderr, "================= across layers %s\n", prefix);
+      across_layers_.Print();
+    }
+  }
+
+ private:
+  size_t num_layers_;
+  size_t max_workers_;
+  std::vector<TensorStatsAccumulator> acc_;
+  TensorStatsAcrossLayers across_layers_;
+};
+
+#else  // GCPP_TENSOR_STATS
+
+class TensorStats {
+ public:
+  TensorStats(size_t, size_t) {}
+  void Notify(size_t, const MatPtr&, ThreadingContext&, int = 0, size_t = 0,
+              Parallelism = Parallelism::kFlat) {}
+  void ReduceAndPrint(const char*) {}
+};
+
+#endif  // GCPP_TENSOR_STATS
+}  // namespace gcpp
+
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
diff --git a/gemma/weights.h b/gemma/weights.h
index 3661869..4476e22 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -96,7 +96,8 @@ struct LayerWeightsPtrs {
   // other values for purposes of the KV cache.
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config, const TensorInfoRegistry& tensors) - : finder_(LayerSuffix(layer_idx), tensors), + : layer_idx(layer_idx), + finder_(LayerSuffix(layer_idx), tensors), qkv_einsum_w(finder_("qkv_ein")), qkv_einsum_w1(finder_("qkv1_w")), qkv_einsum_w2(finder_("qkv2_w")), @@ -135,6 +136,7 @@ struct LayerWeightsPtrs { } ~LayerWeightsPtrs() = default; + const size_t layer_idx; const MatFinder finder_; // Files either have qkv_einsum_w with 2 stacked matrices or separate diff --git a/io/io.cc b/io/io.cc index 8114276..2f479b2 100644 --- a/io/io.cc +++ b/io/io.cc @@ -196,9 +196,9 @@ std::unique_ptr OpenFileOrNull(const Path& filename, const char* mode) { namespace gcpp { std::unique_ptr OpenFileOrAbort(const Path& filename, const char* mode) { - std::unique_ptr file = OpenFileOrNull(filename, "r"); + std::unique_ptr file = OpenFileOrNull(filename, mode); if (!file) { - HWY_ABORT("Failed to open %s", filename.path.c_str()); + HWY_ABORT("Failed to open %s, errno %d", filename.path.c_str(), errno); } return file; } diff --git a/util/threading_context.cc b/util/threading_context.cc index d3aa74f..4a3b927 100644 --- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -97,7 +97,8 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args) BoundedSlice(args.skip_lps, args.max_lps)), cache_info(topology), allocator(topology, cache_info, args.bind != Tristate::kFalse), - pools(topology, allocator, args.max_threads, args.pin) { + pools(topology, allocator, args.max_threads, args.pin), + tensor_output(args.tensor_output) { PROFILER_ZONE("Startup.ThreadingContext autotune"); TunePools(hwy::PoolWaitMode::kSpin, *this); // kBlock is the default, hence set/tune it last. diff --git a/util/threading_context.h b/util/threading_context.h index 7a5c3f5..07c8089 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -23,6 +23,7 @@ #include // IWYU pragma: begin_exports +#include "io/io.h" // Path #include "util/allocator.h" #include "util/args.h" #include "util/basics.h" // Tristate @@ -55,6 +56,8 @@ class ThreadingArgs : public ArgsBase { Tristate pin; // pin threads? Tristate spin; // use spin waits? + Path tensor_output; // empty, or directory for tensor output + template void ForEach(const Visitor& visitor) { // These can be used to partition CPU packages/sockets and their @@ -85,6 +88,9 @@ class ThreadingArgs : public ArgsBase { visitor(bind, "bind", Tristate::kDefault, "Bind memory to sockets? -1 = auto, 0 = no, 1 = yes.", 2); + + visitor(tensor_output, "tensor_output", Path(), + "Empty, or directory for tensor output.", 2); } }; @@ -124,6 +130,8 @@ struct ThreadingContext { // Per-package/cluster/within cluster pools of threads, matching `topology`. NestedPools pools; + + Path tensor_output; // used by `TensorStats::Notify`. }; #define GCPP_ZONE(ctx, global_idx, zone_enum) \ diff --git a/util/zones.cc b/util/zones.cc index a474311..6480b96 100644 --- a/util/zones.cc +++ b/util/zones.cc @@ -51,6 +51,8 @@ const char* ZoneName(Zones zone) { return "Gen.SampleTop1"; case Zones::kGenSampleTopK: return "Gen.SampleTopK"; + case Zones::kGenStats: + return "Gen.Stats"; case Zones::kMMDecompressA: return "MM.DecompressA"; case Zones::kMMDispatch: @@ -163,6 +165,8 @@ const char* CallerName(Callers caller) { return "ReadBatches"; case Callers::kSampleAndStream: return "SampleAndStream"; + case Callers::kTensorStats: + return "TensorStats"; case Callers::kTest: // only for unit tests. 
return "Test-only!"; case Callers::kTunePool: diff --git a/util/zones.h b/util/zones.h index 5624e24..f324086 100644 --- a/util/zones.h +++ b/util/zones.h @@ -31,6 +31,7 @@ enum class Zones { // Keep sorted kGenFFW, kGenSampleTop1, kGenSampleTopK, + kGenStats, kMMDecompressA, kMMDispatch, kMMMatMul, @@ -96,6 +97,7 @@ enum class Callers { // Keep sorted kReadAllToBF16, kReadBatches, kSampleAndStream, + kTensorStats, kTest, // only for unit tests. kTunePool, kVitDotSoftmax1,