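Minor internal refactoring: factor the body of Compressor::operator() into a new Insert() method, make AddScales() take a const pointer, and add a CompressedArraySize() helper.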

PiperOrigin-RevId: 635852078
This commit is contained in:
Paul Chang 2024-05-21 10:29:12 -07:00 committed by Copybara-Service
parent 59a1f87d63
commit c0643577c3
2 changed files with 20 additions and 6 deletions

View File

@ -527,14 +527,21 @@ class Compressor {
// Compresses the `kCapacity` floats in `weights` into `compressed` and adds the
// result to `writer_` under the cache key built from `MatT` and `name`.
// NOTE(review): this listing looks like a diff rendered without +/- markers
// (the commit header reports 6 deletions). The inline fprintf / Compress /
// writer_.Add sequence below performs the same steps as the Insert() call that
// follows it, so presumably only one of the two exists in the real file -
// confirm against the repository before relying on this text.
template <typename MatT, size_t kCapacity>
void operator()(const char* name, const float* weights,
CompressedArray<MatT, kCapacity>& compressed) {
// Progress message: tensor name and element count in millions.
fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name,
kCapacity / (1000 * 1000));
Compress(weights, kCapacity, work_, kCapacity, compressed.data(), 0, pool_);
writer_.Add(CacheKey<MatT>(name), compressed.data(),
compressed.CompressedSize());
// Delegates to Insert() with the member working set and pool and offset 0.
Insert(name, weights, kCapacity, work_, compressed.CompressedSize(),
compressed.data(), 0, pool_);
}
void AddScales(float* scales, size_t len) {
template <typename MatT>
void Insert(const char* name, const float* weights, size_t weights_count,
CompressWorkingSet& work, size_t out_capacity, MatT* out,
size_t out_ofs, hwy::ThreadPool& pool) {
fprintf(stderr, "Regenerating %s (%zuM), please wait\n", name,
weights_count / (1000 * 1000));
Compress(weights, weights_count, work_, weights_count, out, 0, pool_);
writer_.Add(CacheKey<MatT>(name), out, out_capacity);
}
void AddScales(const float* scales, size_t len) {
if (len) {
writer_.Add(CacheKey<float>("scales"), scales, len * sizeof(scales[0]));
}

View File

@ -57,6 +57,13 @@ constexpr size_t CompressedArrayLen<NuqStream>(size_t capacity) {
}
} // namespace detail
// Byte footprint of a compressed array of type `MatT` holding `capacity`
// elements: the length reported by detail::CompressedArrayLen<MatT>(capacity)
// scaled by sizeof(MatT).
template <typename MatT>
constexpr size_t CompressedArraySize(size_t capacity) {
  const size_t packed_len = detail::CompressedArrayLen<MatT>(capacity);
  return sizeof(MatT) * packed_len;
}
// Compressed representation of floating-point elements. The array length may
// differ from the number of elements. Associated operations such as Dot are
// implemented in SIMD code and are thus non-member functions.