// Copyright 2023 Google LLC // SPDX-License-Identifier: Apache-2.0 // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // Target-independent definitions. #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_ #define COMPRESS_STATS 0 #include #include #if COMPRESS_STATS #include #endif #include #include #include "compression/types.h" // IWYU pragma: export #if COMPRESS_STATS #include "compression/distortion.h" #include "hwy/stats.h" #endif namespace gcpp { #if COMPRESS_STATS class CompressStats { public: void Notify(const DistortionStats& stats) { const float pnorm = stats.PNorm(); const float snr = stats.GeomeanValueDivL1(); num_exact_ += stats.NumExact(); s_pnorm_.Notify(pnorm); // No loss - skip to avoid dragging down the average. if (snr != 0.0f) { s_snr_.Notify(snr); } } void NotifyIn(int sfp) { hist_weights_.Notify(sfp); } void Assimilate(const CompressStats& other) { s_pnorm_.Assimilate(other.s_pnorm_); s_snr_.Assimilate(other.s_snr_); num_exact_ += other.num_exact_; hist_weights_.Assimilate(other.hist_weights_); } void PrintAll() { const int skip = hwy::Stats::kNoGeomean; fprintf(stderr, " pnorm %s\n", s_pnorm_.ToString(skip).c_str()); fprintf(stderr, " SNR %s\n", s_snr_.ToString(skip).c_str()); fprintf(stderr, " #exact %.3E\n", static_cast(num_exact_)); // hist_weights_.Print("indices"); } void Reset() { s_pnorm_.Reset(); s_snr_.Reset(); num_exact_ = 0; hist_weights_.Reset(); } private: hwy::Stats s_pnorm_; hwy::Stats s_snr_; size_t num_exact_ = 0; hwy::Bins<1000> hist_weights_; char padding_[64]; // prevent false sharing }; #else class DistortionStats; struct CompressStats { void Notify(const DistortionStats&) {} void NotifyIn(int) {} void Assimilate(const CompressStats&) {} void PrintAll() {} void Reset() {} }; #endif // COMPRESS_STATS struct CompressPerThread { // Allocated the first time NUQ is used. std::unique_ptr buf; CompressStats stats; }; struct CompressWorkingSet { std::vector tls; }; // Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales // them such that the largest magnitude is `SfpStream::kMax`, and returns the // multiplier with which to restore the original values. This is only necessary // before compressing to `SfpStream` and `NuqStream`. float ScaleWeights(float* HWY_RESTRICT raw, size_t num); } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_