mirror of https://github.com/google/gemma.cpp.git
Remove no longer required stats.h - use Highway version instead
PiperOrigin-RevId: 640440379
This commit is contained in:
parent
175e389c3c
commit
5c3e5f7038
|
|
@ -48,23 +48,12 @@ cc_library(
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Deprecated because it is also implemented in Highway; will be removed once
|
|
||||||
# that Highway version is sufficiently widespread.
|
|
||||||
cc_library(
|
|
||||||
name = "stats",
|
|
||||||
srcs = ["stats.cc"],
|
|
||||||
hdrs = ["stats.h"],
|
|
||||||
deps = [
|
|
||||||
"@hwy//:hwy",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "distortion",
|
name = "distortion",
|
||||||
hdrs = ["distortion.h"],
|
hdrs = ["distortion.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":stats",
|
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:stats",
|
||||||
"@hwy//hwy/contrib/sort:vqsort",
|
"@hwy//hwy/contrib/sort:vqsort",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -88,9 +77,9 @@ cc_library(
|
||||||
hdrs = ["test_util.h"],
|
hdrs = ["test_util.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":distortion",
|
":distortion",
|
||||||
":stats",
|
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@hwy//:hwy_test_util",
|
||||||
|
"@hwy//:stats",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -169,9 +158,9 @@ cc_library(
|
||||||
":io",
|
":io",
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
":sfp",
|
||||||
":stats",
|
|
||||||
"@hwy//:dot",
|
"@hwy//:dot",
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
|
"@hwy//:stats",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -184,9 +173,9 @@ cc_library(
|
||||||
":distortion",
|
":distortion",
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
":sfp",
|
||||||
":stats",
|
|
||||||
"@hwy//:hwy",
|
"@hwy//:hwy",
|
||||||
"@hwy//:nanobenchmark", # timer
|
"@hwy//:nanobenchmark", # timer
|
||||||
|
"@hwy//:stats",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
"@hwy//hwy/contrib/sort:vqsort",
|
"@hwy//hwy/contrib/sort:vqsort",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,9 @@
|
||||||
|
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
#include "compression/nuq.h"
|
#include "compression/nuq.h"
|
||||||
#include "compression/stats.h"
|
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
#include "hwy/stats.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_ANALYZE_H_
|
||||||
|
|
@ -55,7 +55,7 @@ namespace HWY_NAMESPACE {
|
||||||
class PerThread {
|
class PerThread {
|
||||||
public:
|
public:
|
||||||
void NotifyGroup(const float* group) {
|
void NotifyGroup(const float* group) {
|
||||||
Stats s_group;
|
hwy::Stats s_group;
|
||||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||||
// Skip zero so we can see the lowest actual magnitude
|
// Skip zero so we can see the lowest actual magnitude
|
||||||
if (group[i] == 0.0f || group[i] == -0.0f) continue;
|
if (group[i] == 0.0f || group[i] == -0.0f) continue;
|
||||||
|
|
@ -119,7 +119,7 @@ class PerThread {
|
||||||
}
|
}
|
||||||
|
|
||||||
void PrintAll() {
|
void PrintAll() {
|
||||||
const int skip = Stats::kNoGeomean;
|
const int skip = hwy::Stats::kNoGeomean;
|
||||||
fprintf(stderr, "num tiny %zu\n", num_tiny_);
|
fprintf(stderr, "num tiny %zu\n", num_tiny_);
|
||||||
fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str());
|
fprintf(stderr, "weights %s\n", s_all_.ToString(skip).c_str());
|
||||||
fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str());
|
fprintf(stderr, " ranges %s\n", s_group_ranges_.ToString(skip).c_str());
|
||||||
|
|
@ -140,18 +140,18 @@ class PerThread {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
size_t num_tiny_ = 0;
|
size_t num_tiny_ = 0;
|
||||||
Stats s_all_;
|
hwy::Stats s_all_;
|
||||||
Stats s_group_ranges_;
|
hwy::Stats s_group_ranges_;
|
||||||
Stats s_group_mins_;
|
hwy::Stats s_group_mins_;
|
||||||
Stats s_group_maxs_;
|
hwy::Stats s_group_maxs_;
|
||||||
Stats s_group_max_vs_min_;
|
hwy::Stats s_group_max_vs_min_;
|
||||||
Stats s_erange_;
|
hwy::Stats s_erange_;
|
||||||
Stats s_km_1_;
|
hwy::Stats s_km_1_;
|
||||||
Stats s_km_2_;
|
hwy::Stats s_km_2_;
|
||||||
Stats s_cut15_;
|
hwy::Stats s_cut15_;
|
||||||
Bins<100> b_magn100_;
|
hwy::Bins<100> b_magn100_;
|
||||||
Bins<256> b_exp256_;
|
hwy::Bins<256> b_exp256_;
|
||||||
Bins<16> b_m4_;
|
hwy::Bins<16> b_m4_;
|
||||||
uint8_t padding_[64]; // prevent false sharing
|
uint8_t padding_[64]; // prevent false sharing
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -172,11 +172,11 @@ class PerLayer {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const Stats& GetStats() const { return s_layer_; }
|
const hwy::Stats& GetStats() const { return s_layer_; }
|
||||||
size_t Outliers() const { return num_outliers_; }
|
size_t Outliers() const { return num_outliers_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Stats s_layer_;
|
hwy::Stats s_layer_;
|
||||||
size_t num_outliers_ = 0;
|
size_t num_outliers_ = 0;
|
||||||
uint8_t padding[64]; // prevent false sharing
|
uint8_t padding[64]; // prevent false sharing
|
||||||
};
|
};
|
||||||
|
|
@ -207,7 +207,7 @@ static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
|
||||||
per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer);
|
per_layer[idx_layer].UpdateOutliers(layer, weights_per_layer);
|
||||||
});
|
});
|
||||||
|
|
||||||
const int skip = Stats::kNoGeomean;
|
const int skip = hwy::Stats::kNoGeomean;
|
||||||
fprintf(stderr, "\n------------%s\n", caption);
|
fprintf(stderr, "\n------------%s\n", caption);
|
||||||
|
|
||||||
for (size_t i = 1; i < pool.NumThreads(); ++i) {
|
for (size_t i = 1; i < pool.NumThreads(); ++i) {
|
||||||
|
|
@ -215,8 +215,8 @@ static HWY_NOINLINE void Analyze(const char* caption, float* mat, size_t layers,
|
||||||
}
|
}
|
||||||
tls[0].PrintAll();
|
tls[0].PrintAll();
|
||||||
|
|
||||||
Stats s_layer_ranges;
|
hwy::Stats s_layer_ranges;
|
||||||
Stats s_layer_outliers;
|
hwy::Stats s_layer_outliers;
|
||||||
for (size_t i = 0; i < layers; ++i) {
|
for (size_t i = 0; i < layers; ++i) {
|
||||||
fprintf(stderr, " %02zu %s\n", i,
|
fprintf(stderr, " %02zu %s\n", i,
|
||||||
per_layer[i].GetStats().ToString(skip).c_str());
|
per_layer[i].GetStats().ToString(skip).c_str());
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "hwy/base.h" // hwy::bfloat16_t
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#if COMPRESS_STATS
|
#if COMPRESS_STATS
|
||||||
#include "compression/stats.h"
|
#include "hwy/stats.h"
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -117,7 +117,7 @@ class CompressStats {
|
||||||
}
|
}
|
||||||
|
|
||||||
void PrintAll() {
|
void PrintAll() {
|
||||||
const int skip = Stats::kNoGeomean;
|
const int skip = hwy::Stats::kNoGeomean;
|
||||||
fprintf(stderr, " pnorm %s\n", s_pnorm_.ToString(skip).c_str());
|
fprintf(stderr, " pnorm %s\n", s_pnorm_.ToString(skip).c_str());
|
||||||
fprintf(stderr, " SNR %s\n", s_snr_.ToString(skip).c_str());
|
fprintf(stderr, " SNR %s\n", s_snr_.ToString(skip).c_str());
|
||||||
fprintf(stderr, " #exact %.3E\n", static_cast<double>(num_exact_));
|
fprintf(stderr, " #exact %.3E\n", static_cast<double>(num_exact_));
|
||||||
|
|
@ -132,10 +132,10 @@ class CompressStats {
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Stats s_pnorm_;
|
hwy::Stats s_pnorm_;
|
||||||
Stats s_snr_;
|
hwy::Stats s_snr_;
|
||||||
size_t num_exact_ = 0;
|
size_t num_exact_ = 0;
|
||||||
Bins<1000> hist_weights_;
|
hwy::Bins<1000> hist_weights_;
|
||||||
char padding_[64]; // prevent false sharing
|
char padding_[64]; // prevent false sharing
|
||||||
};
|
};
|
||||||
#else
|
#else
|
||||||
|
|
|
||||||
|
|
@ -22,10 +22,10 @@
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/stats.h"
|
|
||||||
#include "hwy/aligned_allocator.h" // HWY_ALIGNMENT
|
#include "hwy/aligned_allocator.h" // HWY_ALIGNMENT
|
||||||
#include "hwy/base.h" // ScalarAbs
|
#include "hwy/base.h" // ScalarAbs
|
||||||
#include "hwy/contrib/sort/vqsort.h"
|
#include "hwy/contrib/sort/vqsort.h"
|
||||||
|
#include "hwy/stats.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -198,13 +198,13 @@ class DistortionStats {
|
||||||
return weighted_sum / sum_weights;
|
return weighted_sum / sum_weights;
|
||||||
}
|
}
|
||||||
|
|
||||||
Stats& L1() { return s_l1_; }
|
hwy::Stats& L1() { return s_l1_; }
|
||||||
Stats& Original() { return s_original_; }
|
hwy::Stats& Original() { return s_original_; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
Stats s_original_;
|
hwy::Stats s_original_;
|
||||||
Stats s_l1_;
|
hwy::Stats s_l1_;
|
||||||
Bins<100> b_l1_;
|
hwy::Bins<100> b_l1_;
|
||||||
CascadedSummation<double> sum_l1_; // all
|
CascadedSummation<double> sum_l1_; // all
|
||||||
CascadedSummation<double> sum_l1_rounded_; // only if rounded_to_zero
|
CascadedSummation<double> sum_l1_rounded_; // only if rounded_to_zero
|
||||||
std::vector<float> l1_;
|
std::vector<float> l1_;
|
||||||
|
|
|
||||||
|
|
@ -188,7 +188,7 @@ struct TestNormal {
|
||||||
HWY_ASSERT(in);
|
HWY_ASSERT(in);
|
||||||
|
|
||||||
hwy::RandomState rng;
|
hwy::RandomState rng;
|
||||||
Stats in_stats;
|
hwy::Stats in_stats;
|
||||||
for (size_t i = 0; i < kGroupSize; ++i) {
|
for (size_t i = 0; i < kGroupSize; ++i) {
|
||||||
const double r = RandomGaussian(rng);
|
const double r = RandomGaussian(rng);
|
||||||
in_stats.Notify(r);
|
in_stats.Notify(r);
|
||||||
|
|
@ -288,7 +288,7 @@ struct TestStream {
|
||||||
HWY_ASSERT(in && out && nuq);
|
HWY_ASSERT(in && out && nuq);
|
||||||
|
|
||||||
hwy::RandomState rng;
|
hwy::RandomState rng;
|
||||||
Stats in_stats;
|
hwy::Stats in_stats;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
in[i] = static_cast<float>(RandomGaussian(rng));
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
in_stats.Notify(in[i]);
|
in_stats.Notify(in[i]);
|
||||||
|
|
@ -358,7 +358,7 @@ struct TestDot {
|
||||||
|
|
||||||
// Generate inputs and verify their distribution.
|
// Generate inputs and verify their distribution.
|
||||||
hwy::RandomState rng;
|
hwy::RandomState rng;
|
||||||
Stats in_stats;
|
hwy::Stats in_stats;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
in[i] = static_cast<float>(RandomGaussian(rng));
|
in[i] = static_cast<float>(RandomGaussian(rng));
|
||||||
in_stats.Notify(in[i]);
|
in_stats.Notify(in[i]);
|
||||||
|
|
@ -400,7 +400,7 @@ struct TestDot {
|
||||||
float exact = 0.0f; // using original input
|
float exact = 0.0f; // using original input
|
||||||
float expected = 0.0f; // using decoded NUQ
|
float expected = 0.0f; // using decoded NUQ
|
||||||
DistortionStats dec_stats;
|
DistortionStats dec_stats;
|
||||||
Stats ratios;
|
hwy::Stats ratios;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
dec_stats.Notify(in[i], dec[i]);
|
dec_stats.Notify(in[i], dec[i]);
|
||||||
const float v1 = hwy::ConvertScalarTo<float>(vec[i]);
|
const float v1 = hwy::ConvertScalarTo<float>(vec[i]);
|
||||||
|
|
|
||||||
|
|
@ -411,7 +411,7 @@ struct TestDot {
|
||||||
|
|
||||||
// Generate inputs and verify their distribution.
|
// Generate inputs and verify their distribution.
|
||||||
hwy::RandomState rng;
|
hwy::RandomState rng;
|
||||||
Stats in_stats;
|
hwy::Stats in_stats;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
const float r = static_cast<float>(RandomGaussian(rng));
|
const float r = static_cast<float>(RandomGaussian(rng));
|
||||||
in_stats.Notify(r);
|
in_stats.Notify(r);
|
||||||
|
|
@ -477,7 +477,7 @@ struct TestDot {
|
||||||
float exact = 0.0f; // using original input
|
float exact = 0.0f; // using original input
|
||||||
float expected = 0.0f; // using decoded SFP
|
float expected = 0.0f; // using decoded SFP
|
||||||
DistortionStats dec_stats;
|
DistortionStats dec_stats;
|
||||||
Stats ratios;
|
hwy::Stats ratios;
|
||||||
for (size_t i = 0; i < num; ++i) {
|
for (size_t i = 0; i < num; ++i) {
|
||||||
const float in1 = hwy::ConvertScalarTo<float>(in[i]);
|
const float in1 = hwy::ConvertScalarTo<float>(in[i]);
|
||||||
const float dec1 = hwy::ConvertScalarTo<float>(dec[i]);
|
const float dec1 = hwy::ConvertScalarTo<float>(dec[i]);
|
||||||
|
|
|
||||||
|
|
@ -1,120 +0,0 @@
|
||||||
// Copyright 2024 Google LLC
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// https://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#include "compression/stats.h"
|
|
||||||
|
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
#include <algorithm> // std::min
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "hwy/base.h" // HWY_ASSERT
|
|
||||||
|
|
||||||
namespace gcpp {
|
|
||||||
|
|
||||||
void Stats::Assimilate(const Stats& other) {
|
|
||||||
const int64_t total_n = n_ + other.n_;
|
|
||||||
if (total_n == 0) return; // Nothing to do; prevents div by zero.
|
|
||||||
|
|
||||||
min_ = std::min(min_, other.min_);
|
|
||||||
max_ = std::max(max_, other.max_);
|
|
||||||
|
|
||||||
product_ *= other.product_;
|
|
||||||
|
|
||||||
const double product_n = n_ * other.n_;
|
|
||||||
const double n2 = n_ * n_;
|
|
||||||
const double other_n2 = other.n_ * other.n_;
|
|
||||||
const int64_t total_n2 = total_n * total_n;
|
|
||||||
const double total_n3 = static_cast<double>(total_n2) * total_n;
|
|
||||||
// Precompute reciprocal for speed - used at least twice.
|
|
||||||
const double inv_total_n = 1.0 / total_n;
|
|
||||||
const double inv_total_n2 = 1.0 / total_n2;
|
|
||||||
|
|
||||||
const double delta = other.m1_ - m1_;
|
|
||||||
const double delta2 = delta * delta;
|
|
||||||
const double delta3 = delta * delta2;
|
|
||||||
const double delta4 = delta2 * delta2;
|
|
||||||
|
|
||||||
m1_ = (n_ * m1_ + other.n_ * other.m1_) * inv_total_n;
|
|
||||||
|
|
||||||
const double new_m2 = m2_ + other.m2_ + delta2 * product_n * inv_total_n;
|
|
||||||
|
|
||||||
const double new_m3 =
|
|
||||||
m3_ + other.m3_ + delta3 * product_n * (n_ - other.n_) * inv_total_n2 +
|
|
||||||
3.0 * delta * (n_ * other.m2_ - other.n_ * m2_) * inv_total_n;
|
|
||||||
|
|
||||||
m4_ += other.m4_ +
|
|
||||||
delta4 * product_n * (n2 - product_n + other_n2) / total_n3 +
|
|
||||||
6.0 * delta2 * (n2 * other.m2_ + other_n2 * m2_) * inv_total_n2 +
|
|
||||||
4.0 * delta * (n_ * other.m3_ - other.n_ * m3_) * inv_total_n;
|
|
||||||
|
|
||||||
m2_ = new_m2;
|
|
||||||
m3_ = new_m3;
|
|
||||||
n_ = total_n;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::string Stats::ToString(int exclude) const {
|
|
||||||
if (Count() == 0) return std::string("(none)");
|
|
||||||
|
|
||||||
char buf[300];
|
|
||||||
int pos = 0;
|
|
||||||
int ret; // snprintf - bytes written or negative for error.
|
|
||||||
|
|
||||||
if ((exclude & kNoCount) == 0) {
|
|
||||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Count=%9zu ",
|
|
||||||
static_cast<size_t>(Count()));
|
|
||||||
HWY_ASSERT(ret > 0);
|
|
||||||
pos += ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((exclude & kNoMeanSD) == 0) {
|
|
||||||
const float sd = StandardDeviation();
|
|
||||||
if (sd > 100) {
|
|
||||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.2E SD=%7.1E ",
|
|
||||||
Mean(), sd);
|
|
||||||
} else {
|
|
||||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%8.6f SD=%7.5f ",
|
|
||||||
Mean(), sd);
|
|
||||||
}
|
|
||||||
HWY_ASSERT(ret > 0);
|
|
||||||
pos += ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((exclude & kNoMinMax) == 0) {
|
|
||||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Min=%8.5e Max=%8.5e ", Min(),
|
|
||||||
Max());
|
|
||||||
HWY_ASSERT(ret > 0);
|
|
||||||
pos += ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((exclude & kNoSkewKurt) == 0) {
|
|
||||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "Skew=%5.2f Kurt=%7.2f ",
|
|
||||||
Skewness(), Kurtosis());
|
|
||||||
HWY_ASSERT(ret > 0);
|
|
||||||
pos += ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((exclude & kNoGeomean) == 0) {
|
|
||||||
ret = snprintf(buf + pos, sizeof(buf) - pos, "GeoMean=%9.6f ",
|
|
||||||
GeometricMean());
|
|
||||||
HWY_ASSERT(ret > 0);
|
|
||||||
pos += ret;
|
|
||||||
}
|
|
||||||
|
|
||||||
HWY_ASSERT(pos < static_cast<int>(sizeof(buf)));
|
|
||||||
return buf;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gcpp
|
|
||||||
|
|
@ -1,193 +0,0 @@
|
||||||
// Copyright 2024 Google LLC
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// https://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
|
|
||||||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
|
|
||||||
|
|
||||||
#include <stdint.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <string>
|
|
||||||
|
|
||||||
#include "hwy/base.h" // HWY_ASSERT
|
|
||||||
|
|
||||||
namespace gcpp {
|
|
||||||
|
|
||||||
// Thread-compatible.
|
|
||||||
template <size_t N>
|
|
||||||
class Bins {
|
|
||||||
public:
|
|
||||||
Bins() { Reset(); }
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void Notify(T bin) {
|
|
||||||
HWY_ASSERT(T{0} <= bin && bin < static_cast<T>(N));
|
|
||||||
counts_[static_cast<int32_t>(bin)]++;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Assimilate(const Bins<N>& other) {
|
|
||||||
for (size_t i = 0; i < N; ++i) {
|
|
||||||
counts_[i] += other.counts_[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Print(const char* caption) const {
|
|
||||||
fprintf(stderr, "\n%s [%zu]\n", caption, N);
|
|
||||||
size_t last_nonzero = 0;
|
|
||||||
for (size_t i = N - 1; i < N; --i) {
|
|
||||||
if (counts_[i] != 0) {
|
|
||||||
last_nonzero = i;
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (size_t i = 0; i <= last_nonzero; ++i) {
|
|
||||||
fprintf(stderr, " %zu\n", counts_[i]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
void Reset() {
|
|
||||||
for (size_t i = 0; i < N; ++i) {
|
|
||||||
counts_[i] = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
size_t counts_[N];
|
|
||||||
};
|
|
||||||
|
|
||||||
// Descriptive statistics of a variable (4 moments). Thread-compatible.
|
|
||||||
class Stats {
|
|
||||||
public:
|
|
||||||
Stats() { Reset(); }
|
|
||||||
|
|
||||||
void Notify(const float x) {
|
|
||||||
++n_;
|
|
||||||
|
|
||||||
min_ = HWY_MIN(min_, x);
|
|
||||||
max_ = HWY_MAX(max_, x);
|
|
||||||
|
|
||||||
product_ *= x;
|
|
||||||
|
|
||||||
// Online moments. Reference: https://goo.gl/9ha694
|
|
||||||
const double d = x - m1_;
|
|
||||||
const double d_div_n = d / n_;
|
|
||||||
const double d2n1_div_n = d * (n_ - 1) * d_div_n;
|
|
||||||
const int64_t n_poly = n_ * n_ - 3 * n_ + 3;
|
|
||||||
m1_ += d_div_n;
|
|
||||||
m4_ += d_div_n * (d_div_n * (d2n1_div_n * n_poly + 6.0 * m2_) - 4.0 * m3_);
|
|
||||||
m3_ += d_div_n * (d2n1_div_n * (n_ - 2) - 3.0 * m2_);
|
|
||||||
m2_ += d2n1_div_n;
|
|
||||||
}
|
|
||||||
|
|
||||||
void Assimilate(const Stats& other);
|
|
||||||
|
|
||||||
int64_t Count() const { return n_; }
|
|
||||||
|
|
||||||
float Min() const { return min_; }
|
|
||||||
float Max() const { return max_; }
|
|
||||||
|
|
||||||
double GeometricMean() const {
|
|
||||||
return n_ == 0 ? 0.0 : pow(product_, 1.0 / n_);
|
|
||||||
}
|
|
||||||
|
|
||||||
double Mean() const { return m1_; }
|
|
||||||
// Same as Mu2. Assumes n_ is large.
|
|
||||||
double SampleVariance() const {
|
|
||||||
return n_ == 0 ? 0.0 : m2_ / static_cast<int>(n_);
|
|
||||||
}
|
|
||||||
// Unbiased estimator for population variance even for smaller n_.
|
|
||||||
double Variance() const {
|
|
||||||
if (n_ == 0) return 0.0;
|
|
||||||
if (n_ == 1) return m2_;
|
|
||||||
return m2_ / static_cast<int>(n_ - 1);
|
|
||||||
}
|
|
||||||
double StandardDeviation() const { return std::sqrt(Variance()); }
|
|
||||||
// Near zero for normal distributions; if positive on a unimodal distribution,
|
|
||||||
// the right tail is fatter. Assumes n_ is large.
|
|
||||||
double SampleSkewness() const {
|
|
||||||
if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
|
|
||||||
return m3_ * std::sqrt(static_cast<double>(n_)) / std::pow(m2_, 1.5);
|
|
||||||
}
|
|
||||||
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
|
||||||
double Skewness() const {
|
|
||||||
if (n_ == 0) return 0.0;
|
|
||||||
const double biased = SampleSkewness();
|
|
||||||
const double r = (n_ - 1.0) / n_;
|
|
||||||
return biased * std::pow(r, 1.5);
|
|
||||||
}
|
|
||||||
// Near zero for normal distributions; smaller values indicate fewer/smaller
|
|
||||||
// outliers and larger indicates more/larger outliers. Assumes n_ is large.
|
|
||||||
double SampleKurtosis() const {
|
|
||||||
if (hwy::ScalarAbs(m2_) < 1E-7) return 0.0;
|
|
||||||
return m4_ * n_ / (m2_ * m2_);
|
|
||||||
}
|
|
||||||
// Corrected for bias (same as Wikipedia and Minitab but not Excel).
|
|
||||||
double Kurtosis() const {
|
|
||||||
if (n_ == 0) return 0.0;
|
|
||||||
const double biased = SampleKurtosis();
|
|
||||||
const double r = (n_ - 1.0) / n_;
|
|
||||||
return biased * r * r;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Central moments, useful for "method of moments"-based parameter estimation
|
|
||||||
// of a mixture of two Gaussians. Assumes Count() != 0.
|
|
||||||
double Mu1() const { return m1_; }
|
|
||||||
double Mu2() const { return m2_ / static_cast<int>(n_); }
|
|
||||||
double Mu3() const { return m3_ / static_cast<int>(n_); }
|
|
||||||
double Mu4() const { return m4_ / static_cast<int>(n_); }
|
|
||||||
|
|
||||||
// Which statistics to EXCLUDE in ToString
|
|
||||||
enum {
|
|
||||||
kNoCount = 1,
|
|
||||||
kNoMeanSD = 2,
|
|
||||||
kNoMinMax = 4,
|
|
||||||
kNoSkewKurt = 8,
|
|
||||||
kNoGeomean = 16
|
|
||||||
};
|
|
||||||
std::string ToString(int exclude = 0) const;
|
|
||||||
|
|
||||||
void Reset() {
|
|
||||||
n_ = 0;
|
|
||||||
|
|
||||||
min_ = hwy::HighestValue<float>();
|
|
||||||
max_ = hwy::LowestValue<float>();
|
|
||||||
|
|
||||||
product_ = 1.0;
|
|
||||||
|
|
||||||
m1_ = 0.0;
|
|
||||||
m2_ = 0.0;
|
|
||||||
m3_ = 0.0;
|
|
||||||
m4_ = 0.0;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
int64_t n_; // signed for faster conversion + safe subtraction
|
|
||||||
|
|
||||||
float min_;
|
|
||||||
float max_;
|
|
||||||
|
|
||||||
double product_; // for geomean
|
|
||||||
|
|
||||||
// Moments
|
|
||||||
double m1_;
|
|
||||||
double m2_;
|
|
||||||
double m3_;
|
|
||||||
double m4_;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace gcpp
|
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_STATS_H_
|
|
||||||
|
|
@ -25,7 +25,7 @@
|
||||||
|
|
||||||
// IWYU pragma: begin_exports
|
// IWYU pragma: begin_exports
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
#include "compression/stats.h"
|
#include "hwy/stats.h"
|
||||||
#include "hwy/tests/test_util.h" // RandomState
|
#include "hwy/tests/test_util.h" // RandomState
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
|
||||||
|
|
@ -60,7 +60,7 @@ static inline bool IsNear(T expected, T val, T epsilon = T{1E-6}) {
|
||||||
return IsInside(expected - epsilon, expected + epsilon, val);
|
return IsInside(expected - epsilon, expected + epsilon, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
HWY_INLINE void VerifyGaussian(Stats& stats) {
|
HWY_INLINE void VerifyGaussian(hwy::Stats& stats) {
|
||||||
// Inputs are roughly [-1, 1] and symmetric about zero.
|
// Inputs are roughly [-1, 1] and symmetric about zero.
|
||||||
HWY_ASSERT(IsNear(-1.0f, stats.Min(), 0.10f));
|
HWY_ASSERT(IsNear(-1.0f, stats.Min(), 0.10f));
|
||||||
HWY_ASSERT(IsNear(+1.0f, stats.Max(), 0.10f));
|
HWY_ASSERT(IsNear(+1.0f, stats.Max(), 0.10f));
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,8 @@
|
||||||
|
|
||||||
// Lightweight C++ implementation of the gemma model.
|
// Lightweight C++ implementation of the gemma model.
|
||||||
|
|
||||||
|
#include "gemma/common.h"
|
||||||
|
|
||||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||||
// which we pass the filename via macro 'argument'.
|
// which we pass the filename via macro 'argument'.
|
||||||
#undef HWY_TARGET_INCLUDE
|
#undef HWY_TARGET_INCLUDE
|
||||||
|
|
@ -42,9 +44,7 @@
|
||||||
#include <array>
|
#include <array>
|
||||||
#include <cctype>
|
#include <cctype>
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
#include <iostream>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <random>
|
|
||||||
#include <regex> // NOLINT
|
#include <regex> // NOLINT
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
|
|
@ -1410,7 +1410,8 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
switch (model_type) {
|
switch (model_type) {
|
||||||
case Model::GEMMA_2B:
|
case Model::GEMMA_2B:
|
||||||
impl_.reset(CreateGemmaImpl<ConfigGemma2B>(tokenizer_path, weights, pool));
|
impl_.reset(
|
||||||
|
CreateGemmaImpl<ConfigGemma2B>(tokenizer_path, weights, pool));
|
||||||
break;
|
break;
|
||||||
case Model::GEMMA_7B:
|
case Model::GEMMA_7B:
|
||||||
impl_.reset(CreateGemmaImpl<ConfigGemma7B>(tokenizer_path, weights, pool));
|
impl_.reset(CreateGemmaImpl<ConfigGemma7B>(tokenizer_path, weights, pool));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue