diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index a293681..dce62cd 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -527,14 +527,27 @@ class Compressor {
   template
   void operator()(const char* name, const float* weights,
                   CompressedArray& compressed) {
-    fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name,
-            kCapacity / (1000 * 1000));
-    Compress(weights, kCapacity, work_, kCapacity, compressed.data(), 0, pool_);
-    writer_.Add(CacheKey(name), compressed.data(),
-                compressed.CompressedSize());
+    Insert(name, weights, kCapacity, work_, compressed.CompressedSize(),
+           compressed.data(), 0, pool_);
   }
 
-  void AddScales(float* scales, size_t len) {
+  // Compresses `weights_count` floats from `weights` into `out`, then adds
+  // `out` to the writer under the key derived from `name`. `out_capacity` is
+  // forwarded to the writer as the blob size (presumably bytes, since the
+  // caller above passes CompressedSize() - confirm). Uses the caller-supplied
+  // `work` and `pool`. NOTE(review): `out_ofs` is not forwarded yet; Compress
+  // is still called with offset 0, which is what the visible caller passes.
+  template
+  void Insert(const char* name, const float* weights, size_t weights_count,
+              CompressWorkingSet& work, size_t out_capacity, MatT* out,
+              size_t out_ofs, hwy::ThreadPool& pool) {
+    fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name,
+            weights_count / (1000 * 1000));
+    Compress(weights, weights_count, work, weights_count, out, 0, pool);
+    writer_.Add(CacheKey(name), out, out_capacity);
+  }
+
+  void AddScales(const float* scales, size_t len) {
     if (len) {
       writer_.Add(CacheKey("scales"), scales, len * sizeof(scales[0]));
     }
diff --git a/compression/compress.h b/compression/compress.h
index 549ea6f..8870bb4 100644
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -57,6 +57,13 @@ constexpr size_t CompressedArrayLen(size_t capacity) {
 }
 }  // namespace detail
 
+// Returns the number of bytes required to store a compressed array with the
+// given type and capacity.
+template
+constexpr size_t CompressedArraySize(size_t capacity) {
+  return detail::CompressedArrayLen(capacity) * sizeof(MatT);
+}
+
 // Compressed representation of floating-point elements. The array length may
 // differ from the number of elements. Associated operations such as Dot are
 // implemented in SIMD code and are thus non-member functions.