mirror of https://github.com/google/gemma.cpp.git
Fix setting scales in Py binding
PiperOrigin-RevId: 655284183
This commit is contained in:
parent
2346b5a434
commit
c1f243c351
|
|
@ -72,7 +72,9 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
}
|
}
|
||||||
|
|
||||||
void AddScales(const std::vector<float>& scales) override {
|
void AddScales(const std::vector<float>& scales) override {
|
||||||
compressor_.AddScales(scales.data(), scales.size());
|
HWY_ASSERT(scales_.empty());
|
||||||
|
scales_ = scales;
|
||||||
|
compressor_.AddScales(scales_.data(), scales_.size());
|
||||||
}
|
}
|
||||||
|
|
||||||
void Write(std::string path) override {
|
void Write(std::string path) override {
|
||||||
|
|
@ -85,6 +87,7 @@ class SbsWriterImpl : public WriterInterface {
|
||||||
std::vector<std::vector<SfpStream>> sfp_streams_;
|
std::vector<std::vector<SfpStream>> sfp_streams_;
|
||||||
std::vector<std::vector<NuqStream>> nuq_streams_;
|
std::vector<std::vector<NuqStream>> nuq_streams_;
|
||||||
std::vector<std::vector<hwy::bfloat16_t>> bf16_streams_;
|
std::vector<std::vector<hwy::bfloat16_t>> bf16_streams_;
|
||||||
|
std::vector<float> scales_;
|
||||||
};
|
};
|
||||||
|
|
||||||
WriterInterface* NewSbsWriter() { return new SbsWriterImpl(); }
|
WriterInterface* NewSbsWriter() { return new SbsWriterImpl(); }
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue