// Mirror of https://github.com/google/gemma.cpp.git (115 lines, 3.1 KiB, C++).
// Copyright 2023 Google LLC
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
// Target-independent definitions.
|
|
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
|
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|
|
|
|
#define COMPRESS_STATS 0
|
|
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
|
|
#if COMPRESS_STATS
|
|
#include <stdio.h>
|
|
#endif
|
|
|
|
#include <memory>
|
|
#include <vector>
|
|
|
|
#include "compression/shared.h" // IWYU pragma: export
|
|
#if COMPRESS_STATS
|
|
#include "compression/distortion.h"
|
|
#include "hwy/stats.h"
|
|
#endif
|
|
|
|
namespace gcpp {
|
|
|
|
#if COMPRESS_STATS
|
|
class CompressStats {
|
|
public:
|
|
void Notify(const DistortionStats& stats) {
|
|
const float pnorm = stats.PNorm();
|
|
const float snr = stats.GeomeanValueDivL1();
|
|
num_exact_ += stats.NumExact();
|
|
s_pnorm_.Notify(pnorm);
|
|
// No loss - skip to avoid dragging down the average.
|
|
if (snr != 0.0f) {
|
|
s_snr_.Notify(snr);
|
|
}
|
|
}
|
|
|
|
void NotifyIn(int sfp) { hist_weights_.Notify(sfp); }
|
|
|
|
void Assimilate(const CompressStats& other) {
|
|
s_pnorm_.Assimilate(other.s_pnorm_);
|
|
s_snr_.Assimilate(other.s_snr_);
|
|
num_exact_ += other.num_exact_;
|
|
hist_weights_.Assimilate(other.hist_weights_);
|
|
}
|
|
|
|
void PrintAll() {
|
|
const int skip = hwy::Stats::kNoGeomean;
|
|
fprintf(stderr, " pnorm %s\n", s_pnorm_.ToString(skip).c_str());
|
|
fprintf(stderr, " SNR %s\n", s_snr_.ToString(skip).c_str());
|
|
fprintf(stderr, " #exact %.3E\n", static_cast<double>(num_exact_));
|
|
// hist_weights_.Print("indices");
|
|
}
|
|
|
|
void Reset() {
|
|
s_pnorm_.Reset();
|
|
s_snr_.Reset();
|
|
num_exact_ = 0;
|
|
hist_weights_.Reset();
|
|
}
|
|
|
|
private:
|
|
hwy::Stats s_pnorm_;
|
|
hwy::Stats s_snr_;
|
|
size_t num_exact_ = 0;
|
|
hwy::Bins<1000> hist_weights_;
|
|
char padding_[64]; // prevent false sharing
|
|
};
|
|
#else
|
|
// Only referenced by const reference below, so a declaration is sufficient.
class DistortionStats;

// Stand-in used when COMPRESS_STATS is 0. It offers the same member functions
// as the statistics-collecting class, but every one of them does nothing and
// the struct holds no data.
struct CompressStats {
  // Ignores the distortion result.
  void Notify(const DistortionStats& /*stats*/) {}

  // Ignores the value.
  void NotifyIn(int /*sfp*/) {}

  // Nothing to merge.
  void Assimilate(const CompressStats& /*other*/) {}

  // Prints nothing.
  void PrintAll() {}

  // Nothing to clear.
  void Reset() {}
};
|
|
#endif // COMPRESS_STATS
|
|
|
|
struct CompressPerThread {
|
|
// Allocated the first time NUQ is used.
|
|
std::unique_ptr<NuqStream::ClusterBuf> buf;
|
|
CompressStats stats;
|
|
};
|
|
|
|
struct CompressWorkingSet {
|
|
std::vector<CompressPerThread> tls;
|
|
};
|
|
|
|
// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
|
|
// them such that the largest magnitude is `SfpStream::kMax`, and returns the
|
|
// multiplier with which to restore the original values. This is only necessary
|
|
// before compressing to `SfpStream` and `NuqStream`.
|
|
float ScaleWeights(float* HWY_RESTRICT raw, size_t num);
|
|
|
|
} // namespace gcpp
|
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
|