Add tensor stats and output

tensor_info: add missing header
io: fix mode
weights.h: add layer_idx to LayerWeightsPtrs
PiperOrigin-RevId: 843531051
Jan Wassenberg 2025-12-11 22:51:50 -08:00 committed by Copybara-Service
parent bfc0dfcfca
commit 73c3627b67
16 changed files with 661 additions and 264 deletions
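For orientation, a minimal usage sketch of the TensorStats API added below. This is a hypothetical helper assembled from the declarations in gemma/tensor_stats.h in this diff (it is not code from the commit); with the default GCPP_TENSOR_STATS=0 the calls compile to no-ops.

#include <stddef.h>

#include "gemma/tensor_stats.h"  // TensorStats, GCPP_TENSOR_STATS
#include "util/mat.h"            // MatPtr
#include "util/threading_context.h"

// Hypothetical example, not part of the commit.
void SketchTensorStats(size_t num_layers, size_t max_workers,
                       const gcpp::MatPtr& tensor, gcpp::ThreadingContext& ctx) {
  // One accumulator slot per (layer, worker) pair to avoid contention.
  gcpp::TensorStats stats(num_layers, max_workers);
  // Accumulates value/magnitude/exponent statistics in parallel over rows.
  // If ctx.tensor_output is non-empty, also writes the tensor to disk once
  // per tensor name and layer.
  stats.Notify(/*layer_idx=*/0, tensor, ctx);
  // Reduces across workers, then prints per-layer and across-layer summaries.
  stats.ReduceAndPrint("sketch");
}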

View File

@@ -112,6 +112,7 @@ cc_library(
":threading",
":topology",
":zones",
"//io",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:profiler",
@@ -555,6 +556,7 @@ cc_library(
"gemma/attention.cc",
"gemma/flash_attention.cc",
"gemma/gemma.cc",
"gemma/tensor_stats.cc",
"gemma/vit.cc",
],
hdrs = [
@@ -563,6 +565,7 @@ cc_library(
"gemma/flash_attention.h",
"gemma/flash_structs.h",
"gemma/gemma.h",
"gemma/tensor_stats.h",
"gemma/vit.h",
],
exec_properties = {
@@ -597,6 +600,7 @@ cc_library(
"@highway//:hwy",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
] +

View File

@@ -1,4 +1,4 @@
# Weight compression and analysis.
# Compressed tensor types.
load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
@@ -208,19 +208,3 @@ cc_test(
"@highway//:hwy_test_util",
],
)
# For internal experimentation
cc_library(
name = "analyze",
textual_hdrs = ["analyze.h"],
deps = [
":int",
":nuq",
":sfp",
":types",
"@highway//:hwy",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)

View File

@@ -1,238 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard to placate lint.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h> // memcpy
#include <cmath> // std::signbit
#include <cstdlib> // std::abs
#include <vector>
#include "compression/types.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#endif
#include "compression/nuq-inl.h"
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
class PerThread {
public:
void NotifyGroup(const float* group) {
constexpr size_t kGroupSize = NuqStream::kGroupSize;
hwy::Stats s_group;
for (size_t i = 0; i < kGroupSize; ++i) {
// Skip zero so we can see the lowest actual magnitude
if (group[i] == 0.0f || group[i] == -0.0f) continue;
s_all_.Notify(group[i]);
s_group.Notify(group[i]);
num_tiny_ += std::abs(group[i]) < 1e-3f;
// b_magn100_.Notify(group[i] * 40.0f + 20.0f);
const uint32_t binary32 =
hwy::BitCastScalar<uint32_t>(std::abs(group[i]));
// const int32_t exp = (binary32 >> 23) - 127;
b_exp256_.Notify(binary32 >> 23);
const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4);
b_m4_.Notify(m4);
}
s_group_ranges_.Notify(s_group.Max() - s_group.Min());
s_group_mins_.Notify(s_group.Min());
s_group_maxs_.Notify(s_group.Max());
float desc[kGroupSize];
memcpy(desc, group, kGroupSize * sizeof(group[0]));
hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending());
// Find largest |max/min| (dynamic range)
float max_ratio = 0.0f;
for (size_t i = 0; i < kGroupSize; ++i) {
if (desc[i] != 0.0f && desc[i] != -0.0f) {
max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i]));
}
}
s_group_max_vs_min_.Notify(max_ratio);
// Relative errors
float diffs[kGroupSize];
for (size_t i = 0; i < kGroupSize - 1; ++i) {
// was in descending order. Avoid div by 0. Ignore sign changes.
diffs[i] = std::abs(desc[i]) < 1e-5
? 0
: std::abs((desc[i] - desc[i + 1]) / desc[i]);
}
hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending());
s_cut15_.Notify(diffs[15]);
}
void Assimilate(const PerThread& other) {
num_tiny_ += other.num_tiny_;
s_all_.Assimilate(other.s_all_);
s_group_ranges_.Assimilate(other.s_group_ranges_);
s_group_mins_.Assimilate(other.s_group_mins_);
s_group_maxs_.Assimilate(other.s_group_maxs_);
s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_);
s_erange_.Assimilate(other.s_erange_);
s_km_1_.Assimilate(other.s_km_1_);
s_km_2_.Assimilate(other.s_km_2_);
s_cut15_.Assimilate(other.s_cut15_);
b_magn100_.Assimilate(other.b_magn100_);
b_exp256_.Assimilate(other.b_exp256_);
b_m4_.Assimilate(other.b_m4_);
}
void PrintAll() {
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "num tiny %zu\n", num_tiny_);
fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str());
fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str());
fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str());
fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str());
fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str());
fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str());
fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str());
fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str());
fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str());
// b_magn100_.Print("magn100");
// b_exp256_.Print("exp");
// b_m4_.Print("mantissa bits4");
fprintf(stderr, "\n");
}
private:
size_t num_tiny_ = 0;
hwy::Stats s_all_;
hwy::Stats s_group_ranges_;
hwy::Stats s_group_mins_;
hwy::Stats s_group_maxs_;
hwy::Stats s_group_max_vs_min_;
hwy::Stats s_erange_;
hwy::Stats s_km_1_;
hwy::Stats s_km_2_;
hwy::Stats s_cut15_;
hwy::Bins<100> b_magn100_;
hwy::Bins<256> b_exp256_;
hwy::Bins<16> b_m4_;
uint8_t padding_[64]; // prevent false sharing
};
class PerLayer {
public:
void NotifyGroup(const float* group) {
for (size_t i = 0; i < NuqStream::kGroupSize; ++i) {
s_layer_.Notify(group[i]);
}
}
void UpdateOutliers(const float* layer, size_t weights_per_layer) {
const float layer_mean = s_layer_.Mean();
const float layer_sd = s_layer_.StandardDeviation();
for (size_t i = 0; i < weights_per_layer; ++i) {
num_outliers_ +=
std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd;
}
}
const hwy::Stats& GetStats() const { return s_layer_; }
size_t Outliers() const { return num_outliers_; }
private:
hwy::Stats s_layer_;
size_t num_outliers_ = 0;
uint8_t padding[64]; // prevent false sharing
};
static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
size_t weights_per_layer,
hwy::ThreadPool& pool) {
std::vector<PerThread> tls;
std::vector<PerLayer> per_layer(layers);
const auto init = [&](size_t num_threads) {
tls.resize(num_threads);
return true;
};
pool.Run(0, static_cast<uint32_t>(layers), init,
[&](uint32_t idx_layer, size_t idx_thread) {
PerThread& self = tls[idx_thread];
const float* layer = &mat[idx_layer * weights_per_layer];
// For each whole group in the layer
for (size_t group_start = 0;
group_start + NuqStream::kGroupSize <= weights_per_layer;
group_start += NuqStream::kGroupSize) {
const float* group = layer + group_start;
per_layer[idx_layer].NotifyGroup(group);
self.NotifyGroup(group);
}
per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer);
});
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "\n------------%s\n", caption);
for (size_t i = 1; i < pool.NumWorkers(); ++i) {
tls[0].Assimilate(tls[i]);
}
tls[0].PrintAll();
hwy::Stats s_layer_ranges;
hwy::Stats s_layer_outliers;
for (size_t i = 0; i < layers; ++i) {
fprintf(stderr, " %02zu %s\n", i,
per_layer[i].GetStats().ToString(skip).c_str());
const float range =
per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min();
s_layer_ranges.Notify(range);
s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) /
weights_per_layer);
}
fprintf(stderr, "layer outliers%% %s\n",
s_layer_outliers.ToString(skip).c_str());
fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_

View File

@@ -26,6 +26,7 @@
#include "gemma/configs.h" // ModelConfig
#include "gemma/gemma_args.h" // AttentionImpl
#include "gemma/kv_cache.h"
#include "gemma/tensor_stats.h"
#include "ops/ops.h" // CreateInvTimescale
#include "util/basics.h" // BF16
#include "util/mat.h" // MatStorageT
@@ -268,6 +269,14 @@ struct Activations {
ffw_out(
MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),
max_workers(ctx.pools.MaxWorkers()),
s_ffw_in(config.num_layers, max_workers),
s_ffw_hidden(config.num_layers, max_workers),
s_ffw_out(config.num_layers, max_workers),
s_w_gating_einsum_w1(config.num_layers, max_workers),
s_w_gating_einsum_w2(config.num_layers, max_workers),
s_w_linear_w(config.num_layers, max_workers),
attention_impl(runtime_config.attention_impl),
attention_storage(config, layer_config, batch_size, seq_len,
runtime_config.attention_impl, ctx.allocator,
@@ -288,6 +297,12 @@ struct Activations {
// Note that BindC on any MatMul output considerably slows down Prefill.
}
~Activations() {
s_ffw_in.ReduceAndPrint("ffw_in");
s_ffw_hidden.ReduceAndPrint("ffw_hidden");
s_ffw_out.ReduceAndPrint("ffw_out");
}
// Negligible CPU time.
void SetBatchSize(size_t batch_size) {
x.OverrideRows(batch_size);
@@ -319,6 +334,15 @@ struct Activations {
MatStorageT<BF16> C2;
MatStorageT<float> ffw_out;
const size_t max_workers;
TensorStats s_ffw_in;
TensorStats s_ffw_hidden; // after Activation+gating
TensorStats s_ffw_out;
TensorStats s_w_gating_einsum_w1;
TensorStats s_w_gating_einsum_w2;
TensorStats s_w_linear_w;
AttentionImpl attention_impl;
AttentionActivations attention_storage;

View File

@@ -20,6 +20,7 @@
#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/tensor_stats.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/mat.h"
@@ -158,6 +159,9 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
HWY_DASSERT(!layer_config.ff_biases); // Only used in Vit.
activations.s_ffw_in.Notify(layer.layer_idx, activations.pre_ffw_rms_out,
env.ctx);
#if GEMMA_FUSED_FFN
const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
StridedViewBF C2, size_t worker) {
@@ -179,8 +183,12 @@
env.ctx);
#endif
activations.s_ffw_hidden.Notify(layer.layer_idx, activations.C1, env.ctx);
// Hidden layer -> output layer.
CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);
activations.s_ffw_out.Notify(layer.layer_idx, activations.ffw_out, env.ctx);
}
// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and

View File

@@ -21,11 +21,13 @@
#include <optional>
#include "compression/types.h" // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
#include "gemma/tensor_stats.h"
#include "util/zones.h"
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
@@ -568,12 +570,26 @@ static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
config, runtime_config, qbatch, update_pos, non_eos);
}
void SetWeightStats(const LayerWeightsPtrs& layer, Activations& a,
ThreadingContext& ctx) {
const size_t layer_idx = layer.layer_idx;
a.s_w_gating_einsum_w1.Notify(layer_idx, layer.gating_einsum_w1, ctx,
kTensorStatsIsWeight);
a.s_w_gating_einsum_w2.Notify(layer_idx, layer.gating_einsum_w2, ctx,
kTensorStatsIsWeight);
a.s_w_linear_w.Notify(layer_idx, layer.linear_w, ctx, kTensorStatsIsWeight);
}
// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
const RuntimeConfig& runtime_config,
const AesCtrEngine& engine, const WeightsPtrs& weights,
Activations& activations, QBatch& qbatch, MatMulEnv& env,
TimingInfo& timing_info) {
for (const LayerWeightsPtrs& layer : weights.c_layers) {
SetWeightStats(layer, activations, env.ctx);
}
const size_t max_gen_steps = PrefillTBatchOrQBatch(
config, runtime_config, weights, activations, qbatch, env, timing_info);

View File

@@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/tensor_info.h"
#include <stddef.h>

View File

@@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_

205
gemma/tensor_stats.cc Normal file
View File

@@ -0,0 +1,205 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "gemma/tensor_stats.h"
#if GCPP_TENSOR_STATS
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <atomic>
#include <cmath>
#include <memory>
#include "io/io.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/profiler.h" // StringTable
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/tensor_stats.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
float Correlation(const float* x, size_t num) {
double sum = 0.0;
for (size_t i = 0; i < num; ++i) {
sum += x[i];
}
const double mean = sum / static_cast<double>(num);
double numerator = 0.0;
double sum_sq_current = 0.0;
double sum_sq_next = 0.0;
for (size_t i = 0; i < num - 1; ++i) {
const double diff_current = static_cast<double>(x[i]) - mean;
const double diff_next = static_cast<double>(x[i + 1]) - mean;
numerator += diff_current * diff_next;
sum_sq_current += diff_current * diff_current;
sum_sq_next += diff_next * diff_next;
}
if (sum_sq_current == 0.0 || sum_sq_next == 0.0) return 0.0f;
const double denominator = std::sqrt(sum_sq_current * sum_sq_next);
const float corr = static_cast<float>(numerator / denominator);
HWY_DASSERT(-1.0f <= corr && corr <= 1.0f);
return corr;
}
// Only write tensor data the first time it is encountered per layer. This is
// a concurrent string+layer -> flag map which avoids std::mutex (incompatible
// with fibers). We use a string table to index into per-layer atomic flags.
static bool ShouldWrite(const char* name, size_t layer_idx) {
constexpr size_t kMaxNames = 128;
constexpr size_t kMaxLayers = 128;
HWY_DASSERT(layer_idx < kMaxLayers);
static hwy::StringTable<kMaxNames> s_table;
const size_t name_idx = s_table.Add(name);
static std::atomic_flag flags[kMaxNames * kMaxLayers] = {};
return !flags[name_idx * kMaxLayers + layer_idx].test_and_set(
std::memory_order_acq_rel);
}
std::unique_ptr<File> MaybeOpenFile(size_t layer_idx, const MatPtr& type_erased,
const Path& tensor_output) {
if (tensor_output.Empty()) return nullptr;
if (!ShouldWrite(type_erased.Name(), layer_idx)) return nullptr;
char path[1024];
snprintf(path, sizeof(path), "%s/%s_L%02zu_%zux%zu_%s.bin",
tensor_output.path.c_str(), type_erased.Name(), layer_idx,
type_erased.Rows(), type_erased.Cols(),
TypeName(type_erased.GetType()));
return OpenFileOrAbort(Path(path), "wb");
}
void MaybeWriteRow(const std::unique_ptr<File>& file, const MatPtr& type_erased,
size_t row_idx) {
if (!file) return;
const size_t bytes_per_row = type_erased.Cols() * type_erased.ElementBytes();
file->Write(type_erased.RowBytes(row_idx), bytes_per_row,
bytes_per_row * row_idx);
}
// First dispatch to the type, then parallel over rows, then vectorized
// decompress and Notify for each value.
void UpdateStatsT(TensorStats& stats, size_t layer_idx,
const MatPtr& type_erased, ThreadingContext& ctx, int flags,
size_t cluster_idx, Parallelism parallelism) {
std::unique_ptr<File> file =
MaybeOpenFile(layer_idx, type_erased, ctx.tensor_output);
if ((flags & kTensorStatsIsWeight) && layer_idx != 0) {
// Still compute stats, but remember not to print them.
stats.Get(layer_idx, 0).DoNotPrint();
}
CallUpcasted(&type_erased, [&](const auto* mat) {
const size_t cols = mat->Cols();
ParallelFor(
parallelism, mat->Rows(), ctx, cluster_idx, Callers::kTensorStats,
[&](size_t row_idx, size_t global_idx) {
GCPP_ZONE(ctx, global_idx, Zones::kGenStats);
auto* HWY_RESTRICT row = mat->Row(row_idx);
MaybeWriteRow(file, type_erased, row_idx);
using Packed = hwy::RemoveCvRef<decltype(*row)>;
PackedSpan<Packed> packed(const_cast<Packed*>(row), cols);
TensorStatsAccumulator& my_stats = stats.Get(layer_idx, global_idx);
my_stats.NotifyCond(ConditionNumber(row, cols));
namespace hn = hwy::HWY_NAMESPACE;
hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
HWY_ALIGN float buf[2 * hn::MaxLanes(df)];
size_t packed_ofs = 0;
if (cols >= 2 * NF) {
for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
VF v0, v1;
Decompress2(df, packed, packed_ofs, v0, v1);
hn::Store(v0, df, buf);
hn::Store(v1, df, buf + NF);
const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1));
const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1));
const float min = hn::ReduceMin(df, min_mag);
if (min != 0.0f) { // Avoid division by zero.
my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag));
}
for (size_t i = 0; i < 2 * NF; ++i) {
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
}
my_stats.NotifyCorr(Correlation(buf, 2 * NF));
}
}
// Zero to two vectors remaining.
for (; packed_ofs < cols; packed_ofs += NF) {
const size_t remaining = HWY_MIN(NF, cols - packed_ofs);
DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining);
// Skip NotifyGroup for this partial group.
for (size_t i = 0; i < remaining; ++i) {
my_stats.Notify(buf[i], row_idx, packed_ofs + i);
}
my_stats.NotifyCorr(Correlation(buf, remaining));
}
});
});
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(UpdateStatsT);
// Must reside in .cc file so that we can #include compress-inl.h.
void TensorStats::Notify(size_t layer_idx, const MatPtr& type_erased,
ThreadingContext& ctx, int flags, size_t cluster_idx,
Parallelism parallelism) {
// Ignore empty tensors.
if (type_erased.GetType() == Type::kUnknown || type_erased.Cols() == 0) {
return;
}
HWY_DYNAMIC_DISPATCH(UpdateStatsT)(*this, layer_idx, type_erased, ctx, flags,
cluster_idx, parallelism);
}
} // namespace gcpp
#endif // HWY_ONCE
#endif // GCPP_TENSOR_STATS
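For reference, the Correlation() helper above computes the lag-1 autocorrelation of a decompressed group of values x_0, ..., x_{n-1} (notation mine, not from the source; the function returns 0 when either denominator term vanishes):

  r_1 = \frac{\sum_{i=0}^{n-2} (x_i - \bar{x})(x_{i+1} - \bar{x})}
             {\sqrt{\sum_{i=0}^{n-2} (x_i - \bar{x})^2 \, \sum_{i=0}^{n-2} (x_{i+1} - \bar{x})^2}},
  \qquad \bar{x} = \frac{1}{n} \sum_{i=0}^{n-1} x_i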

347
gemma/tensor_stats.h Normal file
View File

@@ -0,0 +1,347 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include "util/basics.h"
#include "hwy/base.h"
#ifndef GCPP_TENSOR_STATS
#define GCPP_TENSOR_STATS 0
#endif
#include "util/mat.h"
#include "util/threading_context.h"
#if GCPP_TENSOR_STATS
#include <atomic>
#include <cmath>
#include <vector>
#include "hwy/stats.h"
#endif // GCPP_TENSOR_STATS
namespace gcpp {
// For flags. Used to inhibit printing per-layer stats for weights.
HWY_INLINE_VAR constexpr int kTensorStatsIsWeight = 1;
#if GCPP_TENSOR_STATS
HWY_INLINE_VAR constexpr size_t kStatsMaxCols = 8192;
// Separate summary of the per-layer stats, updated by `TensorStatsAccumulator`.
// We pass per-layer statistics such as the mean value to `hwy::Stats::Notify`
// to see the distribution of per-layer means.
struct TensorStatsAcrossLayers {
bool IsEmpty() const { return s_frobenius.Count() == 0; }
void Print() {
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "frob %s\n", s_frobenius.ToString(skip).c_str());
fprintf(stderr, "cnd.min %s\n", s_cond_min.ToString(skip).c_str());
fprintf(stderr, "cnd.avg %s\n", s_cond_avg.ToString(skip).c_str());
fprintf(stderr, "cnd.max %s\n", s_cond_max.ToString(skip).c_str());
fprintf(stderr, "val.min %s\n", s_val_min.ToString(skip).c_str());
fprintf(stderr, "val.avg %s\n", s_val_avg.ToString(skip).c_str());
fprintf(stderr, "val.krt %s\n", s_val_kurt.ToString(skip).c_str());
fprintf(stderr, "mag.min %s\n", s_mag_min.ToString(skip).c_str());
fprintf(stderr, "mag.avg %s\n", s_mag_avg.ToString(skip).c_str());
fprintf(stderr, "mag.max %s\n", s_mag_max.ToString(skip).c_str());
if (hwy::ScalarAbs(s_corr_avg.Max()) > 0.05f) {
fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str());
}
fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str());
fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str());
fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str());
fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str());
fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str());
if (s_exp_subnormal.Min() != 0.0f) {
fprintf(stderr, "exp.sub %s\n", s_exp_subnormal.ToString(skip).c_str());
}
if (s_big_cols.Count() != 0) {
fprintf(stderr, "bigCols %s\n", s_big_cols.ToString(skip).c_str());
const size_t modal_col = b_big_cols.ModalBinIdx();
const size_t num_outlier_cols = b_big_cols.NumNonzero();
if (num_outlier_cols > 256) {
fprintf(stderr, "bigCols: all up to %zu (max at %zu: %u layers):\n",
b_big_cols.LastNonzero(), modal_col, b_big_cols.Bin(modal_col));
} else {
fprintf(stderr, "bigCols (max at %zu: %u layers):\n", modal_col,
b_big_cols.Bin(modal_col));
for (size_t i = 0; i < kStatsMaxCols; ++i) {
if (b_big_cols.Bin(i) > 2) {
fprintf(stderr, " %3zu: %u\n", i, b_big_cols.Bin(i));
}
}
}
}
fprintf(stderr, "\n");
}
hwy::Stats s_frobenius;
hwy::Stats s_cond_min;
hwy::Stats s_cond_avg;
hwy::Stats s_cond_max;
hwy::Stats s_val_min;
hwy::Stats s_val_avg;
hwy::Stats s_val_kurt;
hwy::Stats s_mag_min;
hwy::Stats s_mag_avg;
hwy::Stats s_mag_max;
hwy::Stats s_corr_avg;
hwy::Stats s_corr_max;
hwy::Stats s_range_avg;
hwy::Stats s_exp_min;
hwy::Stats s_exp_max;
hwy::Stats s_exp_mode;
hwy::Stats s_exp_subnormal;
hwy::Stats s_big_cols; // total number of outlier cols
hwy::Bins<kStatsMaxCols> b_big_cols; // # layers with outlier per col
};
// Per-thread and layer.
class TensorStatsAccumulator {
public:
void Notify(float val, size_t row_idx, size_t col_idx) {
const double dval = static_cast<double>(val);
sum_sq_ += dval * dval;
s_val_.Notify(val);
const float mag = hwy::ScalarAbs(val);
if (HWY_UNLIKELY(mag >= 64.0f)) {
if (row_idx < kMaxBatchSize) b_big_row_.Notify(row_idx);
if (col_idx < kStatsMaxCols) b_big_col_.Notify(col_idx);
}
// Skip zero so we can see the lowest actual magnitude
if (mag != 0.0f && mag != -0.0f) s_mag_.Notify(mag);
const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(mag);
// Use biased exponent because Bins wants unsigned values.
const uint32_t biased_exp = binary32 >> 23;
HWY_DASSERT(biased_exp < 256); // already cleared sign bit
b_exp256_.Notify(biased_exp);
}
void DoNotPrint() { skip_.fetch_or(1); }
bool ShouldPrint() const { return skip_.load() == 0; }
// Vector code computed the min/max of a group (= two vectors); this is
// faster than doing it in `Notify`.
void NotifyGroup(float min, float max) {
s_group_min_.Notify(min);
s_group_max_.Notify(max);
// Caller ensures min != 0.
s_group_range_.Notify(max / min);
}
void NotifyCorr(float corr) { s_corr_.Notify(corr); }
void NotifyCond(double cond) { s_cond_.Notify(cond); }
void Assimilate(const TensorStatsAccumulator& other) {
skip_.fetch_or(other.skip_.load());
sum_sq_ += other.sum_sq_;
b_exp256_.Assimilate(other.b_exp256_);
b_big_row_.Assimilate(other.b_big_row_);
b_big_col_.Assimilate(other.b_big_col_);
s_val_.Assimilate(other.s_val_);
s_mag_.Assimilate(other.s_mag_);
s_corr_.Assimilate(other.s_corr_);
s_group_min_.Assimilate(other.s_group_min_);
s_group_max_.Assimilate(other.s_group_max_);
s_group_range_.Assimilate(other.s_group_range_);
}
// Called on the per-layer representative after reducing across threads.
void NotifyAcrossLayer(TensorStatsAcrossLayers& s) {
s.s_frobenius.Notify(std::sqrt(sum_sq_));
s.s_cond_min.Notify(s_cond_.Min());
s.s_cond_avg.Notify(s_cond_.Mean());
s.s_cond_max.Notify(s_cond_.Max());
s.s_val_min.Notify(s_val_.Min());
s.s_val_avg.Notify(s_val_.Mean());
s.s_val_kurt.Notify(s_val_.Kurtosis());
s.s_mag_min.Notify(s_mag_.Min());
s.s_mag_avg.Notify(s_mag_.Mean());
s.s_mag_max.Notify(s_mag_.Max());
s.s_corr_avg.Notify(s_corr_.Mean());
s.s_corr_max.Notify(s_corr_.Max());
s.s_range_avg.Notify(s_group_range_.Mean());
const uint32_t subnormals = b_exp256_.Bin(0);
// Prevent subnormals from hiding the min exponent.
b_exp256_.ResetBin(0);
s.s_exp_min.Notify(b_exp256_.FirstNonzero());
s.s_exp_max.Notify(b_exp256_.LastNonzero());
s.s_exp_mode.Notify(b_exp256_.ModalBinIdx());
s.s_exp_subnormal.Notify(subnormals);
const uint32_t num_outliers = b_big_col_.NumNonzero();
if (num_outliers != 0) {
s.s_big_cols.Notify(num_outliers);
// For each col, count the number of layers that have an outlier there.
for (size_t i = 0; i < kStatsMaxCols; ++i) {
if (b_big_col_.Bin(i) != 0) s.b_big_cols.Notify(i);
}
}
}
bool IsEmpty() const { return s_val_.Count() == 0; }
void PrintAll() {
fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_));
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str());
fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str());
fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str());
fprintf(stderr, "corr %s\n", s_corr_.ToString(skip).c_str());
fprintf(stderr, "group_min %s\n", s_group_min_.ToString(skip).c_str());
fprintf(stderr, "group_max %s\n", s_group_max_.ToString(skip).c_str());
fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str());
b_exp256_.Print("exp");
PrintBinRanges(b_big_row_, "big row");
PrintBinRanges(b_big_col_, "big col");
fprintf(stderr, "\n");
}
private:
template <size_t N>
void PrintBinRanges(const hwy::Bins<N>& b, const char* name) {
uint64_t total = 0;
for (size_t i = 0; i < N; ++i) {
total += b.Bin(i);
}
if (total == 0) return;
// If all bins are at least 10% of a uniform distribution, print the range
// to vastly reduce the log size.
const size_t min = HWY_MAX(1, total / (N * 10));
size_t last = 0;
for (; last < N; ++last) {
if (b.Bin(last) < min) break;
}
if (last >= N / 2) {
// Also require all subsequent bins to be zero, otherwise we should
// print the outlier bins.
bool all_zero = true;
for (size_t i = last + 1; i < N; ++i) {
if (b.Bin(i) != 0) {
all_zero = false;
break;
}
}
if (all_zero) {
fprintf(stderr, "%s: uniform up to %zu\n", name, last);
return;
}
}
b.Print(name, /*skip_zero=*/true);
}
double sum_sq_ = 0.0; // for Frobenius norm
hwy::Bins<256> b_exp256_; // exponent
hwy::Bins<kMaxBatchSize> b_big_row_;
hwy::Bins<kStatsMaxCols> b_big_col_;
hwy::Stats s_val_;
hwy::Stats s_mag_;
hwy::Stats s_cond_; // condition number
hwy::Stats s_corr_; // lag-1 autocorrelation
hwy::Stats s_group_min_;
hwy::Stats s_group_max_;
hwy::Stats s_group_range_;
std::atomic<int> skip_{0};
};
class TensorStats {
public:
TensorStats(size_t num_layers, size_t max_workers)
: num_layers_(num_layers),
max_workers_(max_workers),
acc_(num_layers * max_workers) {}
// Parallelized across rows. If `ctx.tensor_output` is not empty, writes
// tensor data to disk for offline analysis, once per tensor and layer.
void Notify(size_t layer_idx, const MatPtr& type_erased,
ThreadingContext& ctx, int flags = 0, size_t cluster_idx = 0,
Parallelism parallelism = Parallelism::kFlat);
// For use by `UpdateStatsT`.
TensorStatsAccumulator& Get(size_t layer_idx, size_t global_idx) {
const size_t idx = layer_idx * max_workers_ + global_idx;
HWY_DASSERT(idx < acc_.size());
return acc_[idx];
}
void ReduceAndPrint(const char* prefix) {
for (size_t layer_idx = 0; layer_idx < num_layers_; ++layer_idx) {
TensorStatsAccumulator& per_layer = Get(layer_idx, 0);
for (size_t global_idx = 1; global_idx < max_workers_; ++global_idx) {
per_layer.Assimilate(Get(layer_idx, global_idx));
}
if (per_layer.IsEmpty()) continue;
per_layer.NotifyAcrossLayer(across_layers_);
if (per_layer.ShouldPrint()) {
fprintf(stderr, "-------------------- %s %zu\n", prefix, layer_idx);
per_layer.PrintAll();
}
}
if (!across_layers_.IsEmpty()) {
fprintf(stderr, "================= across layers %s\n", prefix);
across_layers_.Print();
}
}
private:
size_t num_layers_;
size_t max_workers_;
std::vector<TensorStatsAccumulator> acc_;
TensorStatsAcrossLayers across_layers_;
};
#else // GCPP_TENSOR_STATS
class TensorStats {
public:
TensorStats(size_t, size_t) {}
void Notify(size_t, const MatPtr&, ThreadingContext&, int = 0, size_t = 0,
Parallelism = Parallelism::kFlat) {}
void ReduceAndPrint(const char*) {}
};
#endif // GCPP_TENSOR_STATS
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
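For reference, the "Frob" value printed by PrintAll() and summarized in s_frobenius above is the Frobenius norm of all values seen for a layer, accumulated via sum_sq_ (notation mine, not from the source):

  \|X\|_F = \sqrt{\sum_{i,j} x_{ij}^2}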

View File

@@ -96,7 +96,8 @@ struct LayerWeightsPtrs {
// other values for purposes of the KV cache.
LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
const TensorInfoRegistry& tensors)
: finder_(LayerSuffix(layer_idx), tensors),
: layer_idx(layer_idx),
finder_(LayerSuffix(layer_idx), tensors),
qkv_einsum_w(finder_("qkv_ein")),
qkv_einsum_w1(finder_("qkv1_w")),
qkv_einsum_w2(finder_("qkv2_w")),
@@ -135,6 +136,7 @@ struct LayerWeightsPtrs {
}
~LayerWeightsPtrs() = default;
const size_t layer_idx;
const MatFinder finder_;
// Files either have qkv_einsum_w with 2 stacked matrices or separate

View File

@@ -196,9 +196,9 @@ std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
namespace gcpp {
std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) {
std::unique_ptr<File> file = OpenFileOrNull(filename, "r");
std::unique_ptr<File> file = OpenFileOrNull(filename, mode);
if (!file) {
HWY_ABORT("Failed to open %s", filename.path.c_str());
HWY_ABORT("Failed to open %s, errno %d", filename.path.c_str(), errno);
}
return file;
}

View File

@@ -97,7 +97,8 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args)
BoundedSlice(args.skip_lps, args.max_lps)),
cache_info(topology),
allocator(topology, cache_info, args.bind != Tristate::kFalse),
pools(topology, allocator, args.max_threads, args.pin) {
pools(topology, allocator, args.max_threads, args.pin),
tensor_output(args.tensor_output) {
PROFILER_ZONE("Startup.ThreadingContext autotune");
TunePools(hwy::PoolWaitMode::kSpin, *this);
// kBlock is the default, hence set/tune it last.

View File

@@ -23,6 +23,7 @@
#include <stdint.h>
// IWYU pragma: begin_exports
#include "io/io.h" // Path
#include "util/allocator.h"
#include "util/args.h"
#include "util/basics.h" // Tristate
@@ -55,6 +56,8 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
Tristate pin; // pin threads?
Tristate spin; // use spin waits?
Path tensor_output; // empty, or directory for tensor output
template <class Visitor>
void ForEach(const Visitor& visitor) {
// These can be used to partition CPU packages/sockets and their
@@ -85,6 +88,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
visitor(bind, "bind", Tristate::kDefault,
"Bind memory to sockets? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(tensor_output, "tensor_output", Path(),
"Empty, or directory for tensor output.", 2);
}
};
@@ -124,6 +130,8 @@ struct ThreadingContext {
// Per-package/cluster/within cluster pools of threads, matching `topology`.
NestedPools pools;
Path tensor_output; // used by `TensorStats::Notify`.
};
#define GCPP_ZONE(ctx, global_idx, zone_enum) \

View File

@@ -51,6 +51,8 @@ const char* ZoneName(Zones zone) {
return "Gen.SampleTop1";
case Zones::kGenSampleTopK:
return "Gen.SampleTopK";
case Zones::kGenStats:
return "Gen.Stats";
case Zones::kMMDecompressA:
return "MM.DecompressA";
case Zones::kMMDispatch:
@@ -163,6 +165,8 @@ const char* CallerName(Callers caller) {
return "ReadBatches";
case Callers::kSampleAndStream:
return "SampleAndStream";
case Callers::kTensorStats:
return "TensorStats";
case Callers::kTest: // only for unit tests.
return "Test-only!";
case Callers::kTunePool:

View File

@@ -31,6 +31,7 @@ enum class Zones { // Keep sorted
kGenFFW,
kGenSampleTop1,
kGenSampleTopK,
kGenStats,
kMMDecompressA,
kMMDispatch,
kMMMatMul,
@@ -96,6 +97,7 @@ enum class Callers { // Keep sorted
kReadAllToBF16,
kReadBatches,
kSampleAndStream,
kTensorStats,
kTest, // only for unit tests.
kTunePool,
kVitDotSoftmax1,