mirror of https://github.com/google/gemma.cpp.git
Add tensor stats and output
tensor_info: add missing header
io: fix mode
weights.h: add layer_idx to LayerWeightsPtrs
PiperOrigin-RevId: 843531051
This commit is contained in:
parent bfc0dfcfca
commit 73c3627b67
@@ -112,6 +112,7 @@ cc_library(
        ":threading",
        ":topology",
        ":zones",
        "//io",
        "@highway//:hwy",
        "@highway//:hwy_test_util",
        "@highway//:profiler",
@@ -555,6 +556,7 @@ cc_library(
        "gemma/attention.cc",
        "gemma/flash_attention.cc",
        "gemma/gemma.cc",
        "gemma/tensor_stats.cc",
        "gemma/vit.cc",
    ],
    hdrs = [
@@ -563,6 +565,7 @@ cc_library(
        "gemma/flash_attention.h",
        "gemma/flash_structs.h",
        "gemma/gemma.h",
        "gemma/tensor_stats.h",
        "gemma/vit.h",
    ],
    exec_properties = {
@@ -597,6 +600,7 @@ cc_library(
        "@highway//:hwy",
        "@highway//:nanobenchmark",  # timer
        "@highway//:profiler",
        "@highway//:stats",
        "@highway//:thread_pool",
        "@highway//hwy/contrib/sort:vqsort",
    ] +
@@ -1,4 +1,4 @@
# Weight compression and analysis.
# Compressed tensor types.

load("@rules_cc//cc:cc_library.bzl", "cc_library")
load("@rules_cc//cc:cc_test.bzl", "cc_test")
@@ -208,19 +208,3 @@ cc_test(
        "@highway//:hwy_test_util",
    ],
)

# For internal experimentation
cc_library(
    name = "analyze",
    textual_hdrs = ["analyze.h"],
    deps = [
        ":int",
        ":nuq",
        ":sfp",
        ":types",
        "@highway//:hwy",
        "@highway//:stats",
        "@highway//:thread_pool",
        "@highway//hwy/contrib/sort:vqsort",
    ],
)
@@ -1,238 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Normal include guard to placate lint.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <string.h>  // memcpy

#include <cmath>    // std::signbit
#include <cstdlib>  // std::abs
#include <vector>

#include "compression/types.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/stats.h"

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_

// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_ANALYZE_TOGGLE
#endif

#include "compression/nuq-inl.h"
#include "compression/sfp-inl.h"
#include "hwy/contrib/sort/vqsort-inl.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

class PerThread {
 public:
  void NotifyGroup(const float* group) {
    constexpr size_t kGroupSize = NuqStream::kGroupSize;
    hwy::Stats s_group;
    for (size_t i = 0; i < kGroupSize; ++i) {
      // Skip zero so we can see the lowest actual magnitude
      if (group[i] == 0.0f || group[i] == -0.0f) continue;
      s_all_.Notify(group[i]);
      s_group.Notify(group[i]);

      num_tiny_ += std::abs(group[i]) < 1e-3f;

      // b_magn100_.Notify(group[i] * 40.0f + 20.0f);
      const uint32_t binary32 =
          hwy::BitCastScalar<uint32_t>(std::abs(group[i]));

      // const int32_t exp = (binary32 >> 23) - 127;
      b_exp256_.Notify(binary32 >> 23);
      const uint32_t m4 = (binary32 & 0x7FFFFF) >> (23 - 4);
      b_m4_.Notify(m4);
    }
    s_group_ranges_.Notify(s_group.Max() - s_group.Min());
    s_group_mins_.Notify(s_group.Min());
    s_group_maxs_.Notify(s_group.Max());

    float desc[kGroupSize];
    memcpy(desc, group, kGroupSize * sizeof(group[0]));
    hn::VQSortStatic(desc, kGroupSize, hwy::SortDescending());

    // Find largest |max/min| (dynamic range)
    float max_ratio = 0.0f;
    for (size_t i = 0; i < kGroupSize; ++i) {
      if (desc[i] != 0.0f && desc[i] != -0.0f) {
        max_ratio = std::max(max_ratio, std::abs(desc[0] / desc[i]));
      }
    }
    s_group_max_vs_min_.Notify(max_ratio);

    // Relative errors
    float diffs[kGroupSize];
    for (size_t i = 0; i < kGroupSize - 1; ++i) {
      // was in descending order. Avoid div by 0. Ignore sign changes.
      diffs[i] = std::abs(desc[i]) < 1e-5
                     ? 0
                     : std::abs((desc[i] - desc[i + 1]) / desc[i]);
    }
    hn::VQSortStatic(diffs, kGroupSize, hwy::SortDescending());
    s_cut15_.Notify(diffs[15]);
  }

  void Assimilate(const PerThread& other) {
    num_tiny_ += other.num_tiny_;
    s_all_.Assimilate(other.s_all_);
    s_group_ranges_.Assimilate(other.s_group_ranges_);
    s_group_mins_.Assimilate(other.s_group_mins_);
    s_group_maxs_.Assimilate(other.s_group_maxs_);
    s_group_max_vs_min_.Assimilate(other.s_group_max_vs_min_);
    s_erange_.Assimilate(other.s_erange_);
    s_km_1_.Assimilate(other.s_km_1_);
    s_km_2_.Assimilate(other.s_km_2_);
    s_cut15_.Assimilate(other.s_cut15_);
    b_magn100_.Assimilate(other.b_magn100_);
    b_exp256_.Assimilate(other.b_exp256_);
    b_m4_.Assimilate(other.b_m4_);
  }

  void PrintAll() {
    const int skip = hwy::Stats::kNoGeomean;
    fprintf(stderr, "num tiny %zu\n", num_tiny_);
    fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str());
    fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str());
    fprintf(stderr, " mins %s\n", s_group_mins_.ToString(skip).c_str());
    fprintf(stderr, " maxs %s\n", s_group_maxs_.ToString(skip).c_str());
    fprintf(stderr, " Mvm %s\n", s_group_max_vs_min_.ToString(skip).c_str());
    fprintf(stderr, " cut15 %s\n", s_cut15_.ToString(skip).c_str());
    fprintf(stderr, " erange %s\n", s_erange_.ToString(skip).c_str());
    fprintf(stderr, " km1 %s\n", s_km_1_.ToString(skip).c_str());
    fprintf(stderr, " km2 %s\n", s_km_2_.ToString(skip).c_str());

    // b_magn100_.Print("magn100");
    // b_exp256_.Print("exp");
    // b_m4_.Print("mantissa bits4");

    fprintf(stderr, "\n");
  }

 private:
  size_t num_tiny_ = 0;
  hwy::Stats s_all_;
  hwy::Stats s_group_ranges_;
  hwy::Stats s_group_mins_;
  hwy::Stats s_group_maxs_;
  hwy::Stats s_group_max_vs_min_;
  hwy::Stats s_erange_;
  hwy::Stats s_km_1_;
  hwy::Stats s_km_2_;
  hwy::Stats s_cut15_;
  hwy::Bins<100> b_magn100_;
  hwy::Bins<256> b_exp256_;
  hwy::Bins<16> b_m4_;
  uint8_t padding_[64];  // prevent false sharing
};

class PerLayer {
 public:
  void NotifyGroup(const float* group) {
    for (size_t i = 0; i < NuqStream::kGroupSize; ++i) {
      s_layer_.Notify(group[i]);
    }
  }

  void UpdateOutliers(const float* layer, size_t weights_per_layer) {
    const float layer_mean = s_layer_.Mean();
    const float layer_sd = s_layer_.StandardDeviation();
    for (size_t i = 0; i < weights_per_layer; ++i) {
      num_outliers_ +=
          std::abs(std::abs(layer[i]) - layer_mean) >= 3.0f * layer_sd;
    }
  }

  const hwy::Stats& GetStats() const { return s_layer_; }
  size_t Outliers() const { return num_outliers_; }

 private:
  hwy::Stats s_layer_;
  size_t num_outliers_ = 0;
  uint8_t padding[64];  // prevent false sharing
};

static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
                                 size_t weights_per_layer,
                                 hwy::ThreadPool& pool) {
  std::vector<PerThread> tls;
  std::vector<PerLayer> per_layer(layers);
  const auto init = [&](size_t num_threads) {
    tls.resize(num_threads);
    return true;
  };

  pool.Run(0, static_cast<uint32_t>(layers), init,
           [&](uint32_t idx_layer, size_t idx_thread) {
             PerThread& self = tls[idx_thread];
             const float* layer = &mat[idx_layer * weights_per_layer];
             // For each whole group in the layer
             for (size_t group_start = 0;
                  group_start + NuqStream::kGroupSize <= weights_per_layer;
                  group_start += NuqStream::kGroupSize) {
               const float* group = layer + group_start;
               per_layer[idx_layer].NotifyGroup(group);
               self.NotifyGroup(group);
             }

             per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer);
           });

  const int skip = hwy::Stats::kNoGeomean;
  fprintf(stderr, "\n------------%s\n", caption);

  for (size_t i = 1; i < pool.NumWorkers(); ++i) {
    tls[0].Assimilate(tls[i]);
  }
  tls[0].PrintAll();

  hwy::Stats s_layer_ranges;
  hwy::Stats s_layer_outliers;
  for (size_t i = 0; i < layers; ++i) {
    fprintf(stderr, " %02zu %s\n", i,
            per_layer[i].GetStats().ToString(skip).c_str());
    const float range =
        per_layer[i].GetStats().Max() - per_layer[i].GetStats().Min();
    s_layer_ranges.Notify(range);
    s_layer_outliers.Notify((100.0 * per_layer[i].Outliers()) /
                            weights_per_layer);
  }
  fprintf(stderr, "layer outliers%% %s\n",
          s_layer_outliers.ToString(skip).c_str());
  fprintf(stderr, "layer ranges %s\n", s_layer_ranges.ToString(skip).c_str());
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
@@ -26,6 +26,7 @@
#include "gemma/configs.h"  // ModelConfig
#include "gemma/gemma_args.h"  // AttentionImpl
#include "gemma/kv_cache.h"
#include "gemma/tensor_stats.h"
#include "ops/ops.h"  // CreateInvTimescale
#include "util/basics.h"  // BF16
#include "util/mat.h"  // MatStorageT
@@ -268,6 +269,14 @@ struct Activations {
        ffw_out(
            MatFactory("ffw_out", batch_size, config.model_dim, ctx.allocator)),

        max_workers(ctx.pools.MaxWorkers()),
        s_ffw_in(config.num_layers, max_workers),
        s_ffw_hidden(config.num_layers, max_workers),
        s_ffw_out(config.num_layers, max_workers),

        s_w_gating_einsum_w1(config.num_layers, max_workers),
        s_w_gating_einsum_w2(config.num_layers, max_workers),
        s_w_linear_w(config.num_layers, max_workers),
        attention_impl(runtime_config.attention_impl),
        attention_storage(config, layer_config, batch_size, seq_len,
                          runtime_config.attention_impl, ctx.allocator,
@@ -288,6 +297,12 @@ struct Activations {
    // Note that BindC on any MatMul output considerably slows down Prefill.
  }

  ~Activations() {
    s_ffw_in.ReduceAndPrint("ffw_in");
    s_ffw_hidden.ReduceAndPrint("ffw_hidden");
    s_ffw_out.ReduceAndPrint("ffw_out");
  }

  // Negligible CPU time.
  void SetBatchSize(size_t batch_size) {
    x.OverrideRows(batch_size);
@@ -319,6 +334,15 @@ struct Activations {
  MatStorageT<BF16> C2;
  MatStorageT<float> ffw_out;

  const size_t max_workers;
  TensorStats s_ffw_in;
  TensorStats s_ffw_hidden;  // after Activation+gating
  TensorStats s_ffw_out;

  TensorStats s_w_gating_einsum_w1;
  TensorStats s_w_gating_einsum_w2;
  TensorStats s_w_linear_w;

  AttentionImpl attention_impl;

  AttentionActivations attention_storage;
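The wiring above is the whole pattern for instrumenting a tensor: a TensorStats member sized by (num_layers, max_workers), Notify calls during the forward pass, and ReduceAndPrint when the owning object is destroyed. A minimal sketch of that pattern as a standalone helper, using only the API declared in this diff; ActivationProbe and its member names are illustrative, not part of the commit, and the class compiles to a no-op when GCPP_TENSOR_STATS is 0.

#include <stddef.h>

#include "gemma/tensor_stats.h"      // TensorStats
#include "util/mat.h"                // MatPtr
#include "util/threading_context.h"  // ThreadingContext

namespace gcpp {

// Illustrative wrapper mirroring how Activations wires up s_ffw_in et al.
class ActivationProbe {
 public:
  ActivationProbe(size_t num_layers, size_t max_workers)
      : stats_(num_layers, max_workers) {}

  // Call at the point of interest in the forward pass, once per layer.
  void Observe(size_t layer_idx, const MatPtr& tensor, ThreadingContext& ctx) {
    stats_.Notify(layer_idx, tensor, ctx);
  }

  // Reduces per-worker accumulators and prints per-layer plus across-layer
  // summaries to stderr (only when GCPP_TENSOR_STATS is enabled).
  ~ActivationProbe() { stats_.ReduceAndPrint("probe"); }

 private:
  TensorStats stats_;
};

}  // namespace gcpp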
@@ -20,6 +20,7 @@

#include "gemma/activations.h"
#include "gemma/configs.h"
#include "gemma/tensor_stats.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/mat.h"
@@ -158,6 +159,9 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,

  HWY_DASSERT(!layer_config.ff_biases);  // Only used in Vit.

  activations.s_ffw_in.Notify(layer.layer_idx, activations.pre_ffw_rms_out,
                              env.ctx);

#if GEMMA_FUSED_FFN
  const auto fused = [&](RowPtrsBF C1, IndexRange range_r, IndexRange range_c,
                         StridedViewBF C2, size_t worker) {
@@ -179,8 +183,12 @@ static inline void FFWNoVit(const LayerWeightsPtrs& layer,
             env.ctx);
#endif

  activations.s_ffw_hidden.Notify(layer.layer_idx, activations.C1, env.ctx);

  // Hidden layer -> output layer.
  CallMatMul(activations.C1, layer.linear_w, nullptr, env, activations.ffw_out);

  activations.s_ffw_out.Notify(layer.layer_idx, activations.ffw_out, env.ctx);
}

// Sums encoded (`att_out`) over num_heads (`layer_config.heads`) and
@@ -21,11 +21,13 @@
#include <optional>

#include "compression/types.h"  // GEMMA_DISABLED_TARGETS
#include "util/zones.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif  // HWY_DISABLED_TARGETS

#include "gemma/tensor_stats.h"
#include "util/zones.h"

// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
@@ -568,12 +570,26 @@ static void StreamAndUpdateEOSAfterPrefill(const ModelConfig& config,
      config, runtime_config, qbatch, update_pos, non_eos);
}

void SetWeightStats(const LayerWeightsPtrs& layer, Activations& a,
                    ThreadingContext& ctx) {
  const size_t layer_idx = layer.layer_idx;
  a.s_w_gating_einsum_w1.Notify(layer_idx, layer.gating_einsum_w1, ctx,
                                kTensorStatsIsWeight);
  a.s_w_gating_einsum_w2.Notify(layer_idx, layer.gating_einsum_w2, ctx,
                                kTensorStatsIsWeight);
  a.s_w_linear_w.Notify(layer_idx, layer.linear_w, ctx, kTensorStatsIsWeight);
}

// Decode: generates one continuation token for each query in `qbatch`.
static void GenerateT(const ModelConfig& config,
                      const RuntimeConfig& runtime_config,
                      const AesCtrEngine& engine, const WeightsPtrs& weights,
                      Activations& activations, QBatch& qbatch, MatMulEnv& env,
                      TimingInfo& timing_info) {
  for (const LayerWeightsPtrs& layer : weights.c_layers) {
    SetWeightStats(layer, activations, env.ctx);
  }

  const size_t max_gen_steps = PrefillTBatchOrQBatch(
      config, runtime_config, weights, activations, qbatch, env, timing_info);
@@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/tensor_info.h"

#include <stddef.h>
@@ -1,3 +1,18 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INFO_H_

@@ -0,0 +1,205 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "gemma/tensor_stats.h"

#if GCPP_TENSOR_STATS
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include <atomic>
#include <cmath>
#include <memory>

#include "io/io.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "util/zones.h"
#include "hwy/profiler.h"  // StringTable

// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/tensor_stats.cc"  // NOLINT
// clang-format on
#include "hwy/foreach_target.h"  // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/dot-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

float Correlation(const float* x, size_t num) {
  double sum = 0.0;
  for (size_t i = 0; i < num; ++i) {
    sum += x[i];
  }
  const double mean = sum / static_cast<double>(num);

  double numerator = 0.0;
  double sum_sq_current = 0.0;
  double sum_sq_next = 0.0;

  for (size_t i = 0; i < num - 1; ++i) {
    const double diff_current = static_cast<double>(x[i]) - mean;
    const double diff_next = static_cast<double>(x[i + 1]) - mean;

    numerator += diff_current * diff_next;
    sum_sq_current += diff_current * diff_current;
    sum_sq_next += diff_next * diff_next;
  }

  if (sum_sq_current == 0.0 || sum_sq_next == 0.0) return 0.0f;
  const double denominator = std::sqrt(sum_sq_current * sum_sq_next);
  const float corr = static_cast<float>(numerator / denominator);
  HWY_DASSERT(-1.0f <= corr && corr <= 1.0f);
  return corr;
}

// Only write tensor data the first time it is encountered per layer. This is
// a concurrent string+layer -> flag map which avoids std::mutex (incompatible
// with fibers). We use a string table to index into per-layer atomic flags.
static bool ShouldWrite(const char* name, size_t layer_idx) {
  constexpr size_t kMaxNames = 128;
  constexpr size_t kMaxLayers = 128;
  HWY_DASSERT(layer_idx < kMaxLayers);
  static hwy::StringTable<kMaxNames> s_table;
  const size_t name_idx = s_table.Add(name);
  static std::atomic_flag flags[kMaxNames * kMaxLayers] = {};
  return !flags[name_idx * kMaxLayers + layer_idx].test_and_set(
      std::memory_order_acq_rel);
}

std::unique_ptr<File> MaybeOpenFile(size_t layer_idx, const MatPtr& type_erased,
                                    const Path& tensor_output) {
  if (tensor_output.Empty()) return nullptr;
  if (!ShouldWrite(type_erased.Name(), layer_idx)) return nullptr;
  char path[1024];
  snprintf(path, sizeof(path), "%s/%s_L%02zu_%zux%zu_%s.bin",
           tensor_output.path.c_str(), type_erased.Name(), layer_idx,
           type_erased.Rows(), type_erased.Cols(),
           TypeName(type_erased.GetType()));
  return OpenFileOrAbort(Path(path), "wb");
}

void MaybeWriteRow(const std::unique_ptr<File>& file, const MatPtr& type_erased,
                   size_t row_idx) {
  if (!file) return;
  const size_t bytes_per_row = type_erased.Cols() * type_erased.ElementBytes();
  file->Write(type_erased.RowBytes(row_idx), bytes_per_row,
              bytes_per_row * row_idx);
}

// First dispatch to the type, then parallel over rows, then vectorized
// decompress and Notify for each value.
void UpdateStatsT(TensorStats& stats, size_t layer_idx,
                  const MatPtr& type_erased, ThreadingContext& ctx, int flags,
                  size_t cluster_idx, Parallelism parallelism) {
  std::unique_ptr<File> file =
      MaybeOpenFile(layer_idx, type_erased, ctx.tensor_output);

  if ((flags & kTensorStatsIsWeight) && layer_idx != 0) {
    // Still compute stats, but remember not to print them.
    stats.Get(layer_idx, 0).DoNotPrint();
  }

  CallUpcasted(&type_erased, [&](const auto* mat) {
    const size_t cols = mat->Cols();

    ParallelFor(
        parallelism, mat->Rows(), ctx, cluster_idx, Callers::kTensorStats,
        [&](size_t row_idx, size_t global_idx) {
          GCPP_ZONE(ctx, global_idx, Zones::kGenStats);

          auto* HWY_RESTRICT row = mat->Row(row_idx);
          MaybeWriteRow(file, type_erased, row_idx);

          using Packed = hwy::RemoveCvRef<decltype(*row)>;
          PackedSpan<Packed> packed(const_cast<Packed*>(row), cols);

          TensorStatsAccumulator& my_stats = stats.Get(layer_idx, global_idx);
          my_stats.NotifyCond(ConditionNumber(row, cols));

          namespace hn = hwy::HWY_NAMESPACE;
          hn::ScalableTag<float> df;
          using VF = hn::Vec<decltype(df)>;
          HWY_LANES_CONSTEXPR size_t NF = hn::Lanes(df);
          HWY_ALIGN float buf[2 * hn::MaxLanes(df)];

          size_t packed_ofs = 0;
          if (cols >= 2 * NF) {
            for (; packed_ofs <= cols - 2 * NF; packed_ofs += 2 * NF) {
              VF v0, v1;
              Decompress2(df, packed, packed_ofs, v0, v1);
              hn::Store(v0, df, buf);
              hn::Store(v1, df, buf + NF);
              const VF min_mag = hn::Min(hn::Abs(v0), hn::Abs(v1));
              const VF max_mag = hn::Max(hn::Abs(v0), hn::Abs(v1));
              const float min = hn::ReduceMin(df, min_mag);
              if (min != 0.0f) {  // Avoid division by zero.
                my_stats.NotifyGroup(min, hn::ReduceMax(df, max_mag));
              }

              for (size_t i = 0; i < 2 * NF; ++i) {
                my_stats.Notify(buf[i], row_idx, packed_ofs + i);
              }
              my_stats.NotifyCorr(Correlation(buf, 2 * NF));
            }
          }

          // Zero to two vectors remaining.
          for (; packed_ofs < cols; packed_ofs += NF) {
            const size_t remaining = HWY_MIN(NF, cols - packed_ofs);
            DecompressAndZeroPad(df, packed, packed_ofs, buf, remaining);
            // Skip NotifyGroup for this partial group.
            for (size_t i = 0; i < remaining; ++i) {
              my_stats.Notify(buf[i], row_idx, packed_ofs + i);
            }
            my_stats.NotifyCorr(Correlation(buf, remaining));
          }
        });
  });
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace gcpp {
HWY_EXPORT(UpdateStatsT);

// Must reside in .cc file so that we can #include compress-inl.h.
void TensorStats::Notify(size_t layer_idx, const MatPtr& type_erased,
                         ThreadingContext& ctx, int flags, size_t cluster_idx,
                         Parallelism parallelism) {
  // Ignore empty tensors.
  if (type_erased.GetType() == Type::kUnknown || type_erased.Cols() == 0) {
    return;
  }
  HWY_DYNAMIC_DISPATCH(UpdateStatsT)(*this, layer_idx, type_erased, ctx, flags,
                                     cluster_idx, parallelism);
}

}  // namespace gcpp
#endif  // HWY_ONCE

#endif  // GCPP_TENSOR_STATS
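When ctx.tensor_output is set, MaybeOpenFile and MaybeWriteRow above dump each instrumented tensor once per layer as raw row-major element data, with the name, layer, rows, cols and element type encoded in the filename. A minimal offline reader sketch for the simplest case, assuming the dump was written with 32-bit float elements (other element types would need the matching decoder); this standalone program is illustrative and not part of the commit.

#include <cstddef>
#include <cstdio>
#include <cstdlib>
#include <vector>

// Reads a dump written by MaybeWriteRow: `rows` x `cols` elements, row-major,
// no header. Only valid if the filename's type suffix indicates 32-bit floats.
std::vector<float> ReadDumpedF32(const char* path, size_t rows, size_t cols) {
  std::vector<float> data(rows * cols);
  FILE* f = std::fopen(path, "rb");
  if (!f) {
    std::fprintf(stderr, "failed to open %s\n", path);
    std::exit(1);
  }
  const size_t read = std::fread(data.data(), sizeof(float), data.size(), f);
  if (read != data.size()) {
    std::fprintf(stderr, "short read: %zu of %zu\n", read, data.size());
    std::exit(1);
  }
  std::fclose(f);
  return data;
}

int main(int argc, char** argv) {
  if (argc != 4) {
    std::fprintf(stderr, "usage: %s file.bin rows cols\n", argv[0]);
    return 1;
  }
  const size_t rows = std::strtoull(argv[2], nullptr, 10);
  const size_t cols = std::strtoull(argv[3], nullptr, 10);
  const std::vector<float> data = ReadDumpedF32(argv[1], rows, cols);
  // Example: print the first value of each row.
  for (size_t r = 0; r < rows; ++r) {
    std::printf("row %zu: %f\n", r, static_cast<double>(data[r * cols]));
  }
  return 0;
}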
@@ -0,0 +1,347 @@
// Copyright 2025 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

#include "util/basics.h"
#include "hwy/base.h"

#ifndef GCPP_TENSOR_STATS
#define GCPP_TENSOR_STATS 0
#endif

#include "util/mat.h"
#include "util/threading_context.h"

#if GCPP_TENSOR_STATS
#include <cmath>
#include <vector>

#include "hwy/stats.h"
#endif  // GCPP_TENSOR_STATS

namespace gcpp {

// For flags. Used to inhibit printing per-layer stats for weights.
HWY_INLINE_VAR constexpr int kTensorStatsIsWeight = 1;

#if GCPP_TENSOR_STATS

HWY_INLINE_VAR constexpr size_t kStatsMaxCols = 8192;

// Separate summary of the per-layer stats, updated by `TensorStatsAccumulator`.
// We pass per-layer statistics such as the mean value to `hwy::Stats::Notify`
// to see the distribution of per-layer means.
struct TensorStatsAcrossLayers {
  bool IsEmpty() const { return s_frobenius.Count() == 0; }

  void Print() {
    const int skip = hwy::Stats::kNoGeomean;
    fprintf(stderr, "frob %s\n", s_frobenius.ToString(skip).c_str());
    fprintf(stderr, "cnd.min %s\n", s_cond_min.ToString(skip).c_str());
    fprintf(stderr, "cnd.avg %s\n", s_cond_avg.ToString(skip).c_str());
    fprintf(stderr, "cnd.max %s\n", s_cond_max.ToString(skip).c_str());
    fprintf(stderr, "val.min %s\n", s_val_min.ToString(skip).c_str());
    fprintf(stderr, "val.avg %s\n", s_val_avg.ToString(skip).c_str());
    fprintf(stderr, "val.krt %s\n", s_val_kurt.ToString(skip).c_str());
    fprintf(stderr, "mag.min %s\n", s_mag_min.ToString(skip).c_str());
    fprintf(stderr, "mag.avg %s\n", s_mag_avg.ToString(skip).c_str());
    fprintf(stderr, "mag.max %s\n", s_mag_max.ToString(skip).c_str());
    if (hwy::ScalarAbs(s_corr_avg.Max()) > 0.05f) {
      fprintf(stderr, "cor.avg %s\n", s_corr_avg.ToString(skip).c_str());
    }
    fprintf(stderr, "cor.max %s\n", s_corr_max.ToString(skip).c_str());
    fprintf(stderr, "rng_avg %s\n", s_range_avg.ToString(skip).c_str());
    fprintf(stderr, "exp.min %s\n", s_exp_min.ToString(skip).c_str());
    fprintf(stderr, "exp.max %s\n", s_exp_max.ToString(skip).c_str());
    fprintf(stderr, "exp.mod %s\n", s_exp_mode.ToString(skip).c_str());
    if (s_exp_subnormal.Min() != 0.0f) {
      fprintf(stderr, "exp.sub %s\n", s_exp_subnormal.ToString(skip).c_str());
    }
    if (s_big_cols.Count() != 0) {
      fprintf(stderr, "bigCols %s\n", s_big_cols.ToString(skip).c_str());
      const size_t modal_col = b_big_cols.ModalBinIdx();
      const size_t num_outlier_cols = b_big_cols.NumNonzero();
      if (num_outlier_cols > 256) {
        fprintf(stderr, "bigCols: all up to %zu (max at %zu: %u layers):\n",
                b_big_cols.LastNonzero(), modal_col, b_big_cols.Bin(modal_col));
      } else {
        fprintf(stderr, "bigCols (max at %zu: %u layers):\n", modal_col,
                b_big_cols.Bin(modal_col));
        for (size_t i = 0; i < kStatsMaxCols; ++i) {
          if (b_big_cols.Bin(i) > 2) {
            fprintf(stderr, " %3zu: %u\n", i, b_big_cols.Bin(i));
          }
        }
      }
    }
    fprintf(stderr, "\n");
  }

  hwy::Stats s_frobenius;

  hwy::Stats s_cond_min;
  hwy::Stats s_cond_avg;
  hwy::Stats s_cond_max;

  hwy::Stats s_val_min;
  hwy::Stats s_val_avg;
  hwy::Stats s_val_kurt;

  hwy::Stats s_mag_min;
  hwy::Stats s_mag_avg;
  hwy::Stats s_mag_max;

  hwy::Stats s_corr_avg;
  hwy::Stats s_corr_max;

  hwy::Stats s_range_avg;

  hwy::Stats s_exp_min;
  hwy::Stats s_exp_max;
  hwy::Stats s_exp_mode;
  hwy::Stats s_exp_subnormal;

  hwy::Stats s_big_cols;  // total number of outlier cols
  hwy::Bins<kStatsMaxCols> b_big_cols;  // # layers with outlier per col
};

// Per-thread and layer.
class TensorStatsAccumulator {
 public:
  void Notify(float val, size_t row_idx, size_t col_idx) {
    const double dval = static_cast<double>(val);
    sum_sq_ += dval * dval;

    s_val_.Notify(val);
    const float mag = hwy::ScalarAbs(val);

    if (HWY_UNLIKELY(mag >= 64.0f)) {
      if (row_idx < kMaxBatchSize) b_big_row_.Notify(row_idx);
      if (col_idx < kStatsMaxCols) b_big_col_.Notify(col_idx);
    }

    // Skip zero so we can see the lowest actual magnitude
    if (mag != 0.0f && mag != -0.0f) s_mag_.Notify(mag);

    const uint32_t binary32 = hwy::BitCastScalar<uint32_t>(mag);
    // Use biased exponent because Bins wants unsigned values.
    const uint32_t biased_exp = binary32 >> 23;
    HWY_DASSERT(biased_exp < 256);  // already cleared sign bit
    b_exp256_.Notify(biased_exp);
  }

  void DoNotPrint() { skip_.fetch_or(1); }
  bool ShouldPrint() const { return skip_.load() == 0; }

  // Vector code computed the min/max of a group (= two vectors); this is
  // faster than doing it in `Notify`.
  void NotifyGroup(float min, float max) {
    s_group_min_.Notify(min);
    s_group_max_.Notify(max);
    // Caller ensures min != 0.
    s_group_range_.Notify(max / min);
  }

  void NotifyCorr(float corr) { s_corr_.Notify(corr); }
  void NotifyCond(double cond) { s_cond_.Notify(cond); }

  void Assimilate(const TensorStatsAccumulator& other) {
    skip_.fetch_or(other.skip_.load());

    sum_sq_ += other.sum_sq_;
    b_exp256_.Assimilate(other.b_exp256_);
    b_big_row_.Assimilate(other.b_big_row_);
    b_big_col_.Assimilate(other.b_big_col_);
    s_val_.Assimilate(other.s_val_);
    s_mag_.Assimilate(other.s_mag_);
    s_corr_.Assimilate(other.s_corr_);
    s_group_min_.Assimilate(other.s_group_min_);
    s_group_max_.Assimilate(other.s_group_max_);
    s_group_range_.Assimilate(other.s_group_range_);
  }

  // Called on the per-layer representative after reducing across threads.
  void NotifyAcrossLayer(TensorStatsAcrossLayers& s) {
    s.s_frobenius.Notify(std::sqrt(sum_sq_));

    s.s_cond_min.Notify(s_cond_.Min());
    s.s_cond_avg.Notify(s_cond_.Mean());
    s.s_cond_max.Notify(s_cond_.Max());

    s.s_val_min.Notify(s_val_.Min());
    s.s_val_avg.Notify(s_val_.Mean());
    s.s_val_kurt.Notify(s_val_.Kurtosis());

    s.s_mag_min.Notify(s_mag_.Min());
    s.s_mag_avg.Notify(s_mag_.Mean());
    s.s_mag_max.Notify(s_mag_.Max());

    s.s_corr_avg.Notify(s_corr_.Mean());
    s.s_corr_max.Notify(s_corr_.Max());

    s.s_range_avg.Notify(s_group_range_.Mean());

    const uint32_t subnormals = b_exp256_.Bin(0);
    // Prevent subnormals from hiding the min exponent.
    b_exp256_.ResetBin(0);
    s.s_exp_min.Notify(b_exp256_.FirstNonzero());
    s.s_exp_max.Notify(b_exp256_.LastNonzero());
    s.s_exp_mode.Notify(b_exp256_.ModalBinIdx());
    s.s_exp_subnormal.Notify(subnormals);

    const uint32_t num_outliers = b_big_col_.NumNonzero();
    if (num_outliers != 0) {
      s.s_big_cols.Notify(num_outliers);
      // For each col, count the number of layers that have an outlier there.
      for (size_t i = 0; i < kStatsMaxCols; ++i) {
        if (b_big_col_.Bin(i) != 0) s.b_big_cols.Notify(i);
      }
    }
  }

  bool IsEmpty() const { return s_val_.Count() == 0; }

  void PrintAll() {
    fprintf(stderr, "Frob %.2E\n", std::sqrt(sum_sq_));
    const int skip = hwy::Stats::kNoGeomean;
    fprintf(stderr, "cnd %s\n", s_cond_.ToString(skip).c_str());
    fprintf(stderr, "val %s\n", s_val_.ToString(skip).c_str());
    fprintf(stderr, "mag %s\n", s_mag_.ToString(skip).c_str());
    fprintf(stderr, "corr %s\n", s_corr_.ToString(skip).c_str());
    fprintf(stderr, "group_min %s\n", s_group_min_.ToString(skip).c_str());
    fprintf(stderr, "group_max %s\n", s_group_max_.ToString(skip).c_str());
    fprintf(stderr, "group_range %s\n", s_group_range_.ToString(skip).c_str());
    b_exp256_.Print("exp");
    PrintBinRanges(b_big_row_, "big row");
    PrintBinRanges(b_big_col_, "big col");
    fprintf(stderr, "\n");
  }

 private:
  template <size_t N>
  void PrintBinRanges(const hwy::Bins<N>& b, const char* name) {
    uint64_t total = 0;
    for (size_t i = 0; i < N; ++i) {
      total += b.Bin(i);
    }
    if (total == 0) return;

    // If all bins are at least 10% of a uniform distribution, print the range
    // to vastly reduce the log size.
    const size_t min = HWY_MAX(1, total / (N * 10));
    size_t last = 0;
    for (; last < N; ++last) {
      if (b.Bin(last) < min) break;
    }
    if (last >= N / 2) {
      // Also require all subsequent bins to be zero, otherwise we should
      // print the outlier bins.
      bool all_zero = true;
      for (size_t i = last + 1; i < N; ++i) {
        if (b.Bin(last) != 0) {
          all_zero = false;
          break;
        }
      }
      if (all_zero) {
        fprintf(stderr, "%s: uniform up to %zu\n", name, last);
        return;
      }
    }

    b.Print(name, /*skip_zero=*/true);
  }

  double sum_sq_ = 0.0;  // for Frobenius norm
  hwy::Bins<256> b_exp256_;  // exponent
  hwy::Bins<kMaxBatchSize> b_big_row_;
  hwy::Bins<kStatsMaxCols> b_big_col_;
  hwy::Stats s_val_;
  hwy::Stats s_mag_;
  hwy::Stats s_cond_;  // condition number
  hwy::Stats s_corr_;  // lag-1 autocorrelation
  hwy::Stats s_group_min_;
  hwy::Stats s_group_max_;
  hwy::Stats s_group_range_;
  std::atomic<int> skip_{0};
};

class TensorStats {
 public:
  TensorStats(size_t num_layers, size_t max_workers)
      : num_layers_(num_layers),
        max_workers_(max_workers),
        acc_(num_layers * max_workers) {}

  // Parallelized across rows. If `ctx.tensor_output` is not empty, writes
  // tensor data to disk for offline analysis, once per tensor and layer.
  void Notify(size_t layer_idx, const MatPtr& type_erased,
              ThreadingContext& ctx, int flags = 0, size_t cluster_idx = 0,
              Parallelism parallelism = Parallelism::kFlat);

  // For use by `UpdateStatsT`.
  TensorStatsAccumulator& Get(size_t layer_idx, size_t global_idx) {
    const size_t idx = layer_idx * max_workers_ + global_idx;
    HWY_DASSERT(idx < acc_.size());
    return acc_[idx];
  }

  void ReduceAndPrint(const char* prefix) {
    for (size_t layer_idx = 0; layer_idx < num_layers_; ++layer_idx) {
      TensorStatsAccumulator& per_layer = Get(layer_idx, 0);
      for (size_t global_idx = 1; global_idx < max_workers_; ++global_idx) {
        per_layer.Assimilate(Get(layer_idx, global_idx));
      }
      if (per_layer.IsEmpty()) continue;
      per_layer.NotifyAcrossLayer(across_layers_);
      if (per_layer.ShouldPrint()) {
        fprintf(stderr, "-------------------- %s %zu\n", prefix, layer_idx);
        per_layer.PrintAll();
      }
    }

    if (!across_layers_.IsEmpty()) {
      fprintf(stderr, "================= across layers %s\n", prefix);
      across_layers_.Print();
    }
  }

 private:
  size_t num_layers_;
  size_t max_workers_;
  std::vector<TensorStatsAccumulator> acc_;
  TensorStatsAcrossLayers across_layers_;
};

#else  // GCPP_TENSOR_STATS

class TensorStats {
 public:
  TensorStats(size_t, size_t) {}
  void Notify(size_t, const MatPtr&, ThreadingContext&, int = 0, size_t = 0,
              Parallelism = Parallelism::kFlat) {}
  void ReduceAndPrint(const char*) {}
};

#endif  // GCPP_TENSOR_STATS
}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_STATS_H_
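NotifyGroup above receives a precomputed min/max magnitude per group of two vectors and records their ratio as the group's dynamic range. A scalar sketch of that reduction, assuming plain float input and a nonempty group; it is illustrative only, since the real code in gemma/tensor_stats.cc does this with Highway vectors.

#include <stddef.h>

#include <algorithm>
#include <cmath>

// Scalar equivalent of the per-group min/max magnitude reduction: returns the
// dynamic range max|x| / min|x| of `group`, or 0 if the smallest magnitude is
// zero (the vector code skips NotifyGroup in that case). Requires n >= 1.
float GroupDynamicRange(const float* group, size_t n) {
  float min_mag = std::fabs(group[0]);
  float max_mag = min_mag;
  for (size_t i = 1; i < n; ++i) {
    const float mag = std::fabs(group[i]);
    min_mag = std::min(min_mag, mag);
    max_mag = std::max(max_mag, mag);
  }
  return (min_mag == 0.0f) ? 0.0f : max_mag / min_mag;
}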
@@ -96,7 +96,8 @@ struct LayerWeightsPtrs {
  // other values for purposes of the KV cache.
  LayerWeightsPtrs(size_t layer_idx, const LayerConfig& config,
                   const TensorInfoRegistry& tensors)
      : finder_(LayerSuffix(layer_idx), tensors),
      : layer_idx(layer_idx),
        finder_(LayerSuffix(layer_idx), tensors),
        qkv_einsum_w(finder_("qkv_ein")),
        qkv_einsum_w1(finder_("qkv1_w")),
        qkv_einsum_w2(finder_("qkv2_w")),
@@ -135,6 +136,7 @@ struct LayerWeightsPtrs {
  }
  ~LayerWeightsPtrs() = default;

  const size_t layer_idx;
  const MatFinder finder_;

  // Files either have qkv_einsum_w with 2 stacked matrices or separate
io/io.cc
@@ -196,9 +196,9 @@ std::unique_ptr<File> OpenFileOrNull(const Path& filename, const char* mode) {
namespace gcpp {

std::unique_ptr<File> OpenFileOrAbort(const Path& filename, const char* mode) {
  std::unique_ptr<File> file = OpenFileOrNull(filename, "r");
  std::unique_ptr<File> file = OpenFileOrNull(filename, mode);
  if (!file) {
    HWY_ABORT("Failed to open %s", filename.path.c_str());
    HWY_ABORT("Failed to open %s, errno %d", filename.path.c_str(), errno);
  }
  return file;
}
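The fix above makes OpenFileOrAbort honor its mode argument, which the tensor dump path relies on when opening files with "wb". A small hedged sketch of such a write, using the same File::Write(pointer, bytes, offset) call shape as MaybeWriteRow in gemma/tensor_stats.cc; the helper name and buffer are illustrative, and the exact pointer parameter type is assumed compatible with a byte pointer.

#include <stddef.h>
#include <stdint.h>

#include <memory>
#include <vector>

#include "io/io.h"  // File, OpenFileOrAbort, Path

namespace gcpp {

// Illustrative: dump a float buffer at offset 0 of `path`. Before this fix,
// OpenFileOrAbort ignored `mode` and always opened read-only.
void WriteFloats(const Path& path, const std::vector<float>& values) {
  const std::unique_ptr<File> file = OpenFileOrAbort(path, "wb");
  file->Write(reinterpret_cast<const uint8_t*>(values.data()),
              values.size() * sizeof(float), /*offset=*/0);
}

}  // namespace gcpp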
@@ -97,7 +97,8 @@ ThreadingContext::ThreadingContext(const ThreadingArgs& args)
                       BoundedSlice(args.skip_lps, args.max_lps)),
      cache_info(topology),
      allocator(topology, cache_info, args.bind != Tristate::kFalse),
      pools(topology, allocator, args.max_threads, args.pin) {
      pools(topology, allocator, args.max_threads, args.pin),
      tensor_output(args.tensor_output) {
  PROFILER_ZONE("Startup.ThreadingContext autotune");
  TunePools(hwy::PoolWaitMode::kSpin, *this);
  // kBlock is the default, hence set/tune it last.
@@ -23,6 +23,7 @@
#include <stdint.h>

// IWYU pragma: begin_exports
#include "io/io.h"  // Path
#include "util/allocator.h"
#include "util/args.h"
#include "util/basics.h"  // Tristate
@@ -55,6 +56,8 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
  Tristate pin;   // pin threads?
  Tristate spin;  // use spin waits?

  Path tensor_output;  // empty, or directory for tensor output

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
    // These can be used to partition CPU packages/sockets and their
@@ -85,6 +88,9 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {

    visitor(bind, "bind", Tristate::kDefault,
            "Bind memory to sockets? -1 = auto, 0 = no, 1 = yes.", 2);

    visitor(tensor_output, "tensor_output", Path(),
            "Empty, or directory for tensor output.", 2);
  }
};

@@ -124,6 +130,8 @@ struct ThreadingContext {

  // Per-package/cluster/within cluster pools of threads, matching `topology`.
  NestedPools pools;

  Path tensor_output;  // used by `TensorStats::Notify`.
};

#define GCPP_ZONE(ctx, global_idx, zone_enum) \
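The new --tensor_output flag flows from ThreadingArgs into ThreadingContext::tensor_output, which TensorStats::Notify consults when deciding whether to dump tensors. A sketch of setting it programmatically; it assumes ThreadingArgs is default-constructible as at other gemma.cpp call sites, and the directory name is illustrative.

#include "io/io.h"                   // Path
#include "util/threading_context.h"  // ThreadingArgs, ThreadingContext

namespace gcpp {

void RunWithTensorOutput() {
  ThreadingArgs args;
  // The directory must already exist; each instrumented tensor is written
  // there once per layer when the build enables GCPP_TENSOR_STATS.
  args.tensor_output = Path("/tmp/gemma_tensors");
  ThreadingContext ctx(args);
  // Pass `ctx` on to MatMulEnv / Gemma as usual; TensorStats::Notify will now
  // also dump instrumented tensors to the directory above.
}

}  // namespace gcpp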
@@ -51,6 +51,8 @@ const char* ZoneName(Zones zone) {
      return "Gen.SampleTop1";
    case Zones::kGenSampleTopK:
      return "Gen.SampleTopK";
    case Zones::kGenStats:
      return "Gen.Stats";
    case Zones::kMMDecompressA:
      return "MM.DecompressA";
    case Zones::kMMDispatch:
@@ -163,6 +165,8 @@ const char* CallerName(Callers caller) {
      return "ReadBatches";
    case Callers::kSampleAndStream:
      return "SampleAndStream";
    case Callers::kTensorStats:
      return "TensorStats";
    case Callers::kTest:  // only for unit tests.
      return "Test-only!";
    case Callers::kTunePool:
@@ -31,6 +31,7 @@ enum class Zones {  // Keep sorted
  kGenFFW,
  kGenSampleTop1,
  kGenSampleTopK,
  kGenStats,
  kMMDecompressA,
  kMMDispatch,
  kMMMatMul,
@@ -96,6 +97,7 @@ enum class Callers {  // Keep sorted
  kReadAllToBF16,
  kReadBatches,
  kSampleAndStream,
  kTensorStats,
  kTest,  // only for unit tests.
  kTunePool,
  kVitDotSoftmax1,