Fix setting scales in Py binding

PiperOrigin-RevId: 655284183
This commit is contained in:
The gemma.cpp Authors 2024-07-23 13:32:09 -07:00 committed by Copybara-Service
parent 2346b5a434
commit c1f243c351
1 changed files with 4 additions and 1 deletions

View File

@ -72,7 +72,9 @@ class SbsWriterImpl : public WriterInterface {
}
void AddScales(const std::vector<float>& scales) override {
compressor_.AddScales(scales.data(), scales.size());
HWY_ASSERT(scales_.empty());
scales_ = scales;
compressor_.AddScales(scales_.data(), scales_.size());
}
void Write(std::string path) override {
@ -85,6 +87,7 @@ class SbsWriterImpl : public WriterInterface {
std::vector<std::vector<SfpStream>> sfp_streams_;
std::vector<std::vector<NuqStream>> nuq_streams_;
std::vector<std::vector<hwy::bfloat16_t>> bf16_streams_;
std::vector<float> scales_;
};
WriterInterface* NewSbsWriter() { return new SbsWriterImpl(); }