mirror of https://github.com/google/gemma.cpp.git
Fix setting scales in Py binding
PiperOrigin-RevId: 655284183
This commit is contained in:
parent
2346b5a434
commit
c1f243c351
|
|
@ -72,7 +72,9 @@ class SbsWriterImpl : public WriterInterface {
|
|||
}
|
||||
|
||||
void AddScales(const std::vector<float>& scales) override {
|
||||
compressor_.AddScales(scales.data(), scales.size());
|
||||
HWY_ASSERT(scales_.empty());
|
||||
scales_ = scales;
|
||||
compressor_.AddScales(scales_.data(), scales_.size());
|
||||
}
|
||||
|
||||
void Write(std::string path) override {
|
||||
|
|
@ -85,6 +87,7 @@ class SbsWriterImpl : public WriterInterface {
|
|||
std::vector<std::vector<SfpStream>> sfp_streams_;
|
||||
std::vector<std::vector<NuqStream>> nuq_streams_;
|
||||
std::vector<std::vector<hwy::bfloat16_t>> bf16_streams_;
|
||||
std::vector<float> scales_;
|
||||
};
|
||||
|
||||
WriterInterface* NewSbsWriter() { return new SbsWriterImpl(); }
|
||||
|
|
|
|||
Loading…
Reference in New Issue