// Source: gemma.cpp/compression/compress.h (C++; viewer metadata: 115 lines,
// 3.1 KiB).

// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Target-independent definitions.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
#define COMPRESS_STATS 0
#include <stddef.h>
#include <stdint.h>
#if COMPRESS_STATS
#include <stdio.h>
#endif
#include <memory>
#include <vector>
#include "compression/shared.h" // IWYU pragma: export
#if COMPRESS_STATS
#include "compression/distortion.h"
#include "hwy/stats.h"
#endif
namespace gcpp {
#if COMPRESS_STATS
class CompressStats {
public:
void Notify(const DistortionStats& stats) {
const float pnorm = stats.PNorm();
const float snr = stats.GeomeanValueDivL1();
num_exact_ += stats.NumExact();
s_pnorm_.Notify(pnorm);
// No loss - skip to avoid dragging down the average.
if (snr != 0.0f) {
s_snr_.Notify(snr);
}
}
void NotifyIn(int sfp) { hist_weights_.Notify(sfp); }
void Assimilate(const CompressStats& other) {
s_pnorm_.Assimilate(other.s_pnorm_);
s_snr_.Assimilate(other.s_snr_);
num_exact_ += other.num_exact_;
hist_weights_.Assimilate(other.hist_weights_);
}
void PrintAll() {
const int skip = hwy::Stats::kNoGeomean;
fprintf(stderr, " pnorm %s\n", s_pnorm_.ToString(skip).c_str());
fprintf(stderr, " SNR %s\n", s_snr_.ToString(skip).c_str());
fprintf(stderr, " #exact %.3E\n", static_cast<double>(num_exact_));
// hist_weights_.Print("indices");
}
void Reset() {
s_pnorm_.Reset();
s_snr_.Reset();
num_exact_ = 0;
hist_weights_.Reset();
}
private:
hwy::Stats s_pnorm_;
hwy::Stats s_snr_;
size_t num_exact_ = 0;
hwy::Bins<1000> hist_weights_;
char padding_[64]; // prevent false sharing
};
#else
class DistortionStats;

// Stand-in used when statistics are compiled out: it offers the same member
// functions as the statistics-collecting variant, but each one does nothing.
struct CompressStats {
  // Discards the distortion result.
  void Notify(const DistortionStats& /*stats*/) {
  }

  // Discards the value.
  void NotifyIn(int /*sfp*/) {
  }

  // Nothing to merge.
  void Assimilate(const CompressStats& /*other*/) {
  }

  // Nothing to print.
  void PrintAll() {
  }

  // Nothing to clear.
  void Reset() {
  }
};
#endif // COMPRESS_STATS
// Bundles an NUQ cluster buffer with a `CompressStats` accumulator.
// NOTE(review): the name, and `CompressWorkingSet::tls` holding a vector of
// these, suggest one instance per worker thread - confirm against callers.
struct CompressPerThread {
// Allocated the first time NUQ is used; a default-constructed instance holds
// no buffer.
std::unique_ptr<NuqStream::ClusterBuf> buf;
// Statistics accumulator (a no-op type when COMPRESS_STATS is 0).
CompressStats stats;
};
// Owns the `CompressPerThread` entries used during compression.
struct CompressWorkingSet {
// NOTE(review): presumably one entry per worker thread, sized by the caller -
// confirm against the code that fills this vector.
std::vector<CompressPerThread> tls;
};
// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
// them such that the largest magnitude is `SfpStream::kMax`, and returns the
// multiplier with which to restore the original values. This is only necessary
// before compressing to `SfpStream` and `NuqStream`.
// NOTE(review): presumably `raw` points to `num` floats that are rescaled in
// place (the pointer is non-const) - confirm against the definition.
float ScaleWeights(float* HWY_RESTRICT raw, size_t num);
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_