Rename-only: remove Allocator2 etc suffixes now that refactoring is complete

PiperOrigin-RevId: 755397220
Jan Wassenberg 2025-05-06 09:12:05 -07:00 committed by Copybara-Service
parent 8d0882b966
commit 275135d7e8
39 changed files with 215 additions and 216 deletions

View File

@@ -66,14 +66,14 @@ hwy::ThreadPool& ThreadHostileGetPool() {
 // can safely call `SetArgs` only once, because it would assert otherwise.
 // This is preferable to calling `ThreadHostileInvalidate`, because we would
 // repeat the topology initialization for every test.
-if (!ThreadingContext2::IsInitialized()) {
+if (!ThreadingContext::IsInitialized()) {
 gcpp::ThreadingArgs threading_args;
 threading_args.max_packages = 1;
 threading_args.max_clusters = 8;
 threading_args.pin = Tristate::kFalse;
-ThreadingContext2::SetArgs(threading_args);
+ThreadingContext::SetArgs(threading_args);
 }
-return ThreadingContext2::Get().pools.Pool();
+return ThreadingContext::Get().pools.Pool();
 }
 void TestMatMulVJP() {
@@ -203,7 +203,7 @@ void TestEndToEnd() {
 std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
 RowVectorBatch<float> inv_timescale = CreateInvTimescale(
-ThreadingContext2::Get().allocator, config.layer_configs[0].qkv_dim,
+ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim,
 config.layer_configs[0].post_qk == PostQKType::HalfRope);
 for (const Prompt& prompt : batch) {
 ReverseSequenceSampler::LogPrompt(prompt);

View File

@@ -45,9 +45,9 @@ TEST(OptimizeTest, GradientDescent) {
 threading_args.max_packages = 1;
 threading_args.max_clusters = 1;
 threading_args.pin = Tristate::kFalse;
-ThreadingContext2::SetArgs(threading_args);
-MatMulEnv env(ThreadingContext2::Get());
-const Allocator2& allocator = env.ctx.allocator;
+ThreadingContext::SetArgs(threading_args);
+MatMulEnv env(ThreadingContext::Get());
+const Allocator& allocator = env.ctx.allocator;
 hwy::ThreadPool& pool = env.ctx.pools.Pool();
 std::mt19937 gen(42);

View File

@@ -67,7 +67,7 @@ template <typename T>
 class WeightsWrapper {
 public:
 explicit WeightsWrapper(const ModelConfig& config) : weights_(config) {
-hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
+hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
 weights_.AllocateForTest(owners_, pool);
 }

View File

@@ -35,7 +35,7 @@
 namespace gcpp {
 // Aborts if any keys differ, because then blobs are not comparable.
-void CompareKeys(const BlobReader2& reader1, const BlobReader2& reader2) {
+void CompareKeys(const BlobReader& reader1, const BlobReader& reader2) {
 if (reader1.Keys().size() != reader2.Keys().size()) {
 HWY_ABORT("#keys mismatch: %zu vs %zu\n", reader1.Keys().size(),
 reader2.Keys().size());
@@ -49,13 +49,13 @@ void CompareKeys(const BlobReader2& reader1, const BlobReader2& reader2) {
 }
 using KeyVec = std::vector<std::string>;
-using RangeVec = std::vector<BlobRange2>;
-RangeVec AllRanges(const KeyVec& keys, const BlobReader2& reader) {
+using RangeVec = std::vector<BlobRange>;
+RangeVec AllRanges(const KeyVec& keys, const BlobReader& reader) {
 RangeVec ranges;
 ranges.reserve(keys.size());
 for (const std::string& key : keys) {
-const BlobRange2* range = reader.Find(key);
+const BlobRange* range = reader.Find(key);
 if (!range) {
 HWY_ABORT("Key %s not found, but was in KeyVec\n", key.c_str());
 }
@@ -82,7 +82,7 @@ void CompareRangeSizes(const KeyVec& keys, const RangeVec& ranges1,
 // Total amount to allocate for all blobs.
 size_t TotalBytes(const RangeVec& ranges) {
 size_t total_bytes = 0;
-for (const BlobRange2& range : ranges) {
+for (const BlobRange& range : ranges) {
 total_bytes += range.bytes;
 }
 return total_bytes;
@@ -95,7 +95,7 @@ using BlobVec = std::vector<ByteSpan>; // in order of keys
 // Assigns pointers within the single allocation and updates `pos`.
 BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
 BlobVec blobs;
-for (const BlobRange2& range : ranges) {
+for (const BlobRange& range : ranges) {
 blobs.push_back(ByteSpan(all_blobs.get() + pos, range.bytes));
 pos += range.bytes;
 }
@@ -104,7 +104,7 @@ BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
 // Reads one set of blobs in parallel (helpful if in disk cache).
 // Aborts on error.
-void ReadBlobs(BlobReader2& reader, const RangeVec& ranges, BlobVec& blobs,
+void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
 hwy::ThreadPool& pool) {
 HWY_ASSERT(reader.Keys().size() == blobs.size());
 HWY_ASSERT(ranges.size() == blobs.size());
@@ -116,7 +116,7 @@ void ReadBlobs(BlobReader2& reader, const RangeVec& ranges, BlobVec& blobs,
 }
 // Parallelizes ReadBlobs across (two) packages, if available.
-void ReadBothBlobs(BlobReader2& reader1, BlobReader2& reader2,
+void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
 const RangeVec& ranges1, const RangeVec& ranges2,
 size_t total_bytes, BlobVec& blobs1, BlobVec& blobs2,
 NestedPools& pools) {
@@ -215,8 +215,8 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
 // Compares two sbs files, including blob order.
 void ReadAndCompareBlobs(const char* path1, const char* path2) {
 const Tristate map = Tristate::kFalse;
-std::unique_ptr<BlobReader2> reader1 = BlobReader2::Make(Path(path1), map);
-std::unique_ptr<BlobReader2> reader2 = BlobReader2::Make(Path(path2), map);
+std::unique_ptr<BlobReader> reader1 = BlobReader::Make(Path(path1), map);
+std::unique_ptr<BlobReader> reader2 = BlobReader::Make(Path(path2), map);
 if (!reader1 || !reader2) {
 HWY_ABORT(
 "Failed to create readers for files %s %s, see error messages above.\n",
@@ -235,7 +235,7 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
 BlobVec blobs1 = ReserveMemory(ranges1, all_blobs, pos);
 BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos);
-NestedPools& pools = ThreadingContext2::Get().pools;
+NestedPools& pools = ThreadingContext::Get().pools;
 ReadBothBlobs(*reader1, *reader2, ranges1, ranges2, total_bytes, blobs1,
 blobs2, pools);

View File

@@ -94,7 +94,7 @@ static_assert(sizeof(Header) == 16);
 // Additional data may be added only inside new blobs. Changes to the blob
 // contents or type should be handled by renaming keys.
 //
-// This class is for internal use by `BlobReader2` and `BlobWriter2`. Its
+// This class is for internal use by `BlobReader` and `BlobWriter`. Its
 // interface is more low-level: fixed-size keys instead of strings.
 class BlobStore {
 static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
@@ -182,7 +182,7 @@ class BlobStore {
 padded_dir_bytes - 2 * num_blobs * kU128Bytes);
 // We already zero-initialized the directory padding;
-// `BlobWriter2::WriteAll` takes care of padding after each blob via an
+// `BlobWriter::WriteAll` takes care of padding after each blob via an
 // additional I/O.
 for (size_t i = 0; i < num_blobs; ++i) {
 HWY_ASSERT(blobs[i].data() != nullptr);
@@ -242,12 +242,12 @@ class BlobStore {
 void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const {
 const size_t key_idx = 0; // not actually associated with a key/blob
 writes.emplace_back(
-BlobRange2{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
+BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
 // members are const and BlobIO2 requires non-const pointers, and they
 // are not modified by file writes.
 const_cast<Header*>(&header_));
 writes.emplace_back(
-BlobRange2{.offset = sizeof(header_),
+BlobRange{.offset = sizeof(header_),
 .bytes = PaddedDirEnd(NumBlobs()) - sizeof(header_),
 .key_idx = key_idx},
 const_cast<hwy::uint128_t*>(directory_.data()));
@@ -289,8 +289,8 @@ class BlobStore {
 std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`.
 }; // BlobStore
-BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
-const BlobStore& bs, BlobReader2::Mode mode)
+BlobReader::BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
+const BlobStore& bs, BlobReader::Mode mode)
 : file_(std::move(file)), file_bytes_(file_bytes), mode_(mode) {
 HWY_ASSERT(file_ && file_bytes_ != 0);
@@ -306,12 +306,12 @@ BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
 size_t bytes;
 bs.GetRange(key_idx, offset, bytes);
 ranges_.emplace_back(
-BlobRange2{.offset = offset, .bytes = bytes, .key_idx = key_idx});
+BlobRange{.offset = offset, .bytes = bytes, .key_idx = key_idx});
 key_idx_for_key_[keys_[key_idx]] = key_idx;
 }
 if (mode_ == Mode::kMap) {
-const Allocator2& allocator = ThreadingContext2::Get().allocator;
+const Allocator& allocator = ThreadingContext::Get().allocator;
 // Verify `kEndAlign` is an upper bound on the page size.
 if (kEndAlign % allocator.BasePageBytes() != 0) {
 HWY_ABORT("Please raise an issue about kEndAlign %zu %% page size %zu.",
@@ -338,12 +338,12 @@ BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
 }
 }
-void BlobReader2::Enqueue(const BlobRange2& range, void* data) {
+void BlobReader::Enqueue(const BlobRange& range, void* data) {
 // Debug-only because there may be many I/O requests (per row).
 if constexpr (HWY_IS_DEBUG_BUILD) {
 HWY_DASSERT(!IsMapped());
 HWY_DASSERT(range.offset != 0 && range.bytes != 0 && data != nullptr);
-const BlobRange2& blob_range = Range(range.key_idx);
+const BlobRange& blob_range = Range(range.key_idx);
 HWY_DASSERT(blob_range.End() <= file_bytes_);
 if (range.End() > blob_range.End()) {
 HWY_ABORT(
@@ -362,15 +362,15 @@ void BlobReader2::Enqueue(const BlobRange2& range, void* data) {
 // TODO: use preadv for per-tensor batches of sysconf(_SC_IOV_MAX) / IOV_MAX.
 // - O_DIRECT seems undesirable because we do want to use the OS cache
 // between consecutive runs.
-void BlobReader2::ReadAll(hwy::ThreadPool& pool) const {
+void BlobReader::ReadAll(hwy::ThreadPool& pool) const {
 PROFILER_ZONE("Startup.ReadAll");
 HWY_ASSERT(!IsMapped());
 // >5x speedup from parallel reads when cached.
 pool.Run(0, requests_.size(), [this](uint64_t i, size_t /*thread*/) {
-const BlobRange2& range = requests_[i].range;
+const BlobRange& range = requests_[i].range;
 const uint64_t end = range.End();
 const std::string& key = keys_[range.key_idx];
-const BlobRange2& blob_range = Range(range.key_idx);
+const BlobRange& blob_range = Range(range.key_idx);
 HWY_ASSERT(blob_range.End() <= file_bytes_);
 if (end > blob_range.End()) {
 HWY_ABORT(
@@ -387,11 +387,11 @@ void BlobReader2::ReadAll(hwy::ThreadPool& pool) const {
 }
 // Decides whether to read or map the file.
-static BlobReader2::Mode ChooseMode(uint64_t file_mib, Tristate map) {
-const Allocator2& allocator = ThreadingContext2::Get().allocator;
+static BlobReader::Mode ChooseMode(uint64_t file_mib, Tristate map) {
+const Allocator& allocator = ThreadingContext::Get().allocator;
 // User has explicitly requested a map or read via args.
-if (map == Tristate::kTrue) return BlobReader2::Mode::kMap;
-if (map == Tristate::kFalse) return BlobReader2::Mode::kRead;
+if (map == Tristate::kTrue) return BlobReader::Mode::kMap;
+if (map == Tristate::kFalse) return BlobReader::Mode::kRead;
 // Else: use heuristics to choose. Note that `FreeMiB` is generally low
 // because idle memory is used as cache, so do not use it to decide.
 const size_t total_mib = allocator.TotalMiB();
@@ -400,13 +400,13 @@ static BlobReader2::Mode ChooseMode(uint64_t file_mib, Tristate map) {
 static_cast<size_t>(file_mib), total_mib);
 }
 // Large fraction of total.
-if (file_mib >= total_mib / 3) return BlobReader2::Mode::kMap;
+if (file_mib >= total_mib / 3) return BlobReader::Mode::kMap;
 // Big enough that even parallel loading wouldn't be quick.
-if (file_mib > 50 * 1024) return BlobReader2::Mode::kMap;
-return BlobReader2::Mode::kRead;
+if (file_mib > 50 * 1024) return BlobReader::Mode::kMap;
+return BlobReader::Mode::kRead;
 }
-std::unique_ptr<BlobReader2> BlobReader2::Make(const Path& blob_path,
+std::unique_ptr<BlobReader> BlobReader::Make(const Path& blob_path,
 const Tristate map) {
 if (blob_path.Empty()) HWY_ABORT("No --weights specified.");
 std::unique_ptr<File> file = OpenFileOrNull(blob_path, "r");
@@ -417,10 +417,10 @@ std::unique_ptr<BlobReader2> BlobReader2::Make(const Path& blob_path,
 // Even if `kMap`, read the directory via the `kRead` mode for simplicity.
 BlobStore bs(*file);
 if (!bs.IsValid(file_bytes)) {
-return std::unique_ptr<BlobReader2>(); // IsValid already printed a warning
+return std::unique_ptr<BlobReader>(); // IsValid already printed a warning
 }
-return std::unique_ptr<BlobReader2>(new BlobReader2(
+return std::unique_ptr<BlobReader>(new BlobReader(
 std::move(file), file_bytes, bs, ChooseMode(file_bytes >> 20, map)));
 }
@@ -434,14 +434,13 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
 for (; offset <= end - kChunkBytes;
 offset += kChunkBytes, data += kChunkBytes) {
 writes.emplace_back(
-BlobRange2{
-.offset = offset, .bytes = kChunkBytes, .key_idx = key_idx},
+BlobRange{.offset = offset, .bytes = kChunkBytes, .key_idx = key_idx},
 data);
 }
 }
 if (offset != end) {
 writes.emplace_back(
-BlobRange2{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
+BlobRange{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
 data);
 }
 }
@@ -472,7 +471,7 @@ static void EnqueueWritesForBlobs(const BlobStore& bs,
 if (padding != 0) {
 HWY_ASSERT(padding <= kBlobAlign);
 writes.emplace_back(
-BlobRange2{
+BlobRange{
 .offset = offset + bytes, .bytes = padding, .key_idx = key_idx},
 const_cast<uint8_t*>(kZeros));
 }
@@ -484,19 +483,19 @@ static void EnqueueWritesForBlobs(const BlobStore& bs,
 // remain alive until the last I/O is done.
 zeros.resize(padding);
 writes.emplace_back(
-BlobRange2{.offset = file_end, .bytes = padding, .key_idx = 0},
+BlobRange{.offset = file_end, .bytes = padding, .key_idx = 0},
 zeros.data());
 }
 }
-void BlobWriter2::Add(const std::string& key, const void* data, size_t bytes) {
+void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
 HWY_ASSERT(data != nullptr);
 HWY_ASSERT(bytes != 0);
 keys_.push_back(KeyFromString(key.c_str()));
 blobs_.emplace_back(static_cast<const uint8_t*>(data), bytes);
 }
-void BlobWriter2::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
+void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
 const size_t num_blobs = keys_.size();
 HWY_ASSERT(num_blobs != 0);
 HWY_ASSERT(num_blobs == blobs_.size());
@@ -516,7 +515,7 @@ void BlobWriter2::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
 pool.Run(0, writes.size(),
 [this, &file, &writes](uint64_t i, size_t /*thread*/) {
-const BlobRange2& range = writes[i].range;
+const BlobRange& range = writes[i].range;
 if (!file->Write(writes[i].data, range.bytes, range.offset)) {
 const std::string& key = StringFromKey(keys_[range.key_idx]);

View File

@@ -35,20 +35,20 @@
 namespace gcpp {
 // One blob's extents within the file.
-struct BlobRange2 {
+struct BlobRange {
 uint64_t End() const { return offset + bytes; }
 uint64_t offset = 0;
 size_t bytes = 0; // We check blobs are not zero-sized.
-// Index within `BlobReader2::Keys()` for error reporting.
+// Index within `BlobReader::Keys()` for error reporting.
 size_t key_idx;
 };
 // A read or write I/O request, each serviced by one thread in a pool.
 struct BlobIO2 {
-BlobIO2(BlobRange2 range, void* data) : range(range), data(data) {}
-BlobRange2 range;
+BlobIO2(BlobRange range, void* data) : range(range), data(data) {}
+BlobRange range;
 void* data; // Modified only if a read request. Read-only for writes.
 };
@@ -59,7 +59,7 @@ class BlobStore;
 // Thread-safe: it is safe to concurrently call all methods except `Enqueue`,
 // because they are const.
 // TODO(janwas): split into header and reader/mapper classes.
-class BlobReader2 {
+class BlobReader {
 public:
 // Parallel I/O into allocated memory, or mapped view of file. The latter is
 // better when the file is huge, but page faults add noise to measurements.
@@ -67,26 +67,26 @@ class BlobReader2 {
 // Acquires ownership of `file` (which must be non-null) and reads its header.
 // Factory function instead of ctor because this can fail (return null).
-static std::unique_ptr<BlobReader2> Make(const Path& blob_path,
+static std::unique_ptr<BlobReader> Make(const Path& blob_path,
 Tristate map = Tristate::kDefault);
-~BlobReader2() = default;
+~BlobReader() = default;
 // Returns true if the mode passed to ctor was `kMap` and mapping succeeded.
 bool IsMapped() const { return mode_ == Mode::kMap; }
 const std::vector<std::string>& Keys() const { return keys_; }
-const BlobRange2& Range(size_t key_idx) const {
+const BlobRange& Range(size_t key_idx) const {
 HWY_ASSERT(key_idx < keys_.size());
 return ranges_[key_idx];
 }
 // Returns nullptr if not found. O(1).
-const BlobRange2* Find(const std::string& key) const {
+const BlobRange* Find(const std::string& key) const {
 auto it = key_idx_for_key_.find(key);
 if (it == key_idx_for_key_.end()) return nullptr;
-const BlobRange2& range = Range(it->second);
+const BlobRange& range = Range(it->second);
 HWY_ASSERT(range.offset != 0 && range.bytes != 0);
 HWY_ASSERT(range.End() <= file_bytes_);
 return &range;
@@ -95,7 +95,7 @@ class BlobReader2 {
 // Only if `IsMapped()`: returns blob as a read-only span of `T`. Note that
 // everything else except `CallWithSpan` is in units of bytes.
 template <typename T>
-hwy::Span<const T> MappedSpan(const BlobRange2& range) const {
+hwy::Span<const T> MappedSpan(const BlobRange& range) const {
 HWY_ASSERT(IsMapped());
 HWY_ASSERT(range.bytes % sizeof(T) == 0);
 return hwy::Span<const T>(
@@ -108,7 +108,7 @@ class BlobReader2 {
 // which an aligned allocation is unnecessary.
 template <typename T, class Func>
 bool CallWithSpan(const std::string& key, const Func& func) const {
-const BlobRange2* range = Find(key);
+const BlobRange* range = Find(key);
 if (!range) {
 HWY_WARN("Blob %s not found, sizeof T=%zu", key.c_str(), sizeof(T));
 return false;
@@ -134,7 +134,7 @@ class BlobReader2 {
 // The following methods must only be called if `!IsMapped()`.
 // Enqueues a BlobIO2 for `ReadAll` to execute.
-void Enqueue(const BlobRange2& range, void* data);
+void Enqueue(const BlobRange& range, void* data);
 // Reads in parallel all enqueued requests to the specified destinations.
 // Aborts on error.
@@ -142,7 +142,7 @@ class BlobReader2 {
 private:
 // Only for use by `Make`.
-BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
+BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
 const BlobStore& bs, Mode mode);
 const std::unique_ptr<File> file_;
@@ -150,7 +150,7 @@ class BlobReader2 {
 Mode mode_;
 std::vector<std::string> keys_;
-std::vector<BlobRange2> ranges_;
+std::vector<BlobRange> ranges_;
 std::unordered_map<std::string, size_t> key_idx_for_key_;
 MapPtr mapped_; // only if `kMap`
@@ -160,7 +160,7 @@ class BlobReader2 {
 // Collects references to blobs and writes them all at once with parallel I/O.
 // Thread-compatible: independent instances can be used concurrently, but it
 // does not make sense to call the methods concurrently.
-class BlobWriter2 {
+class BlobWriter {
 public:
 void Add(const std::string& key, const void* data, size_t bytes);

View File

@@ -37,7 +37,7 @@ class BlobStoreTest : public testing::Test {};
 #endif
 void TestWithMapped(Tristate map) {
-hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
+hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
 static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
@@ -51,7 +51,7 @@ void TestWithMapped(Tristate map) {
 const std::string keyA("0123456789abcdef"); // max 16 characters
 const std::string keyB("q");
-BlobWriter2 writer;
+BlobWriter writer;
 writer.Add(keyA, "DATA", 5);
 writer.Add(keyB, buffer.data(), sizeof(buffer));
 writer.WriteAll(pool, path);
@@ -59,14 +59,14 @@ void TestWithMapped(Tristate map) {
 std::fill(buffer.begin(), buffer.end(), 0);
-std::unique_ptr<BlobReader2> reader = BlobReader2::Make(path, map);
+std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
 HWY_ASSERT(reader);
 HWY_ASSERT_EQ(reader->Keys().size(), 2);
 HWY_ASSERT_STRING_EQ(reader->Keys()[0].c_str(), keyA.c_str());
 HWY_ASSERT_STRING_EQ(reader->Keys()[1].c_str(), keyB.c_str());
-const BlobRange2* range = reader->Find(keyA);
+const BlobRange* range = reader->Find(keyA);
 HWY_ASSERT(range);
 const uint64_t offsetA = range->offset;
 HWY_ASSERT_EQ(offsetA, 256); // kBlobAlign
@@ -80,9 +80,9 @@ void TestWithMapped(Tristate map) {
 if (!reader->IsMapped()) {
 char str[5];
 reader->Enqueue(
-BlobRange2{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
+BlobRange{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
 reader->Enqueue(
-BlobRange2{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
+BlobRange{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
 buffer.data());
 reader->ReadAll(pool);
 HWY_ASSERT_STRING_EQ("DATA", str);
@@ -111,7 +111,7 @@ TEST(BlobStoreTest, TestReadWrite) {
 // Ensures padding works for any number of random-sized blobs.
 TEST(BlobStoreTest, TestNumBlobs) {
-hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
+hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
 hwy::RandomState rng;
 for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {
@@ -121,7 +121,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
 HWY_ASSERT(fd > 0);
 const Path path(path_str);
-BlobWriter2 writer;
+BlobWriter writer;
 std::vector<std::string> keys;
 keys.reserve(num_blobs);
 std::vector<std::vector<uint8_t>> blobs;
@@ -144,13 +144,13 @@ TEST(BlobStoreTest, TestNumBlobs) {
 writer.WriteAll(pool, path);
 const Tristate map = Tristate::kFalse;
-std::unique_ptr<BlobReader2> reader = BlobReader2::Make(path, map);
+std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
 HWY_ASSERT(reader);
 HWY_ASSERT_EQ(reader->Keys().size(), num_blobs);
 pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
 HWY_ASSERT_STRING_EQ(reader->Keys()[i].c_str(),
 std::to_string(i).c_str());
-const BlobRange2* range = reader->Find(keys[i]);
+const BlobRange* range = reader->Find(keys[i]);
 HWY_ASSERT(range);
 HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
 HWY_ASSERT(reader->CallWithSpan<uint8_t>(

View File

@@ -22,7 +22,7 @@
 #include <string>
 #include <vector>
-#include "compression/blob_store.h" // BlobWriter2
+#include "compression/blob_store.h" // BlobWriter
 #include "compression/compress.h" // ScaleWeights
 #include "compression/io.h" // Path
 #include "gemma/configs.h" // ModelConfig
@@ -88,7 +88,7 @@ class SbsWriterImpl : public ISbsWriter {
 }
 public:
-SbsWriterImpl() : pool_(ThreadingContext2::Get().pools.Pool()) {}
+SbsWriterImpl() : pool_(ThreadingContext::Get().pools.Pool()) {}
 void Insert(const char* name, F32Span weights, Type type,
 const TensorInfo& tensor_info) override {
@@ -123,7 +123,7 @@ class SbsWriterImpl : public ISbsWriter {
 hwy::ThreadPool& pool_;
 MatOwners mat_owners_;
 CompressWorkingSet working_set_;
-BlobWriter2 writer_;
+BlobWriter writer_;
 std::vector<uint32_t> serialized_mat_ptrs_;
 };
@@ -141,7 +141,7 @@ HWY_EXPORT(NewSbsWriter);
 SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
 SbsReader::SbsReader(const std::string& path)
-: reader_(gcpp::BlobReader2::Make(Path(path))), model_(*reader_) {}
+: reader_(gcpp::BlobReader::Make(Path(path))), model_(*reader_) {}
 } // namespace gcpp
 #endif // HWY_ONCE

View File

@@ -77,8 +77,8 @@ class SbsReader {
 const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); }
 private:
-std::unique_ptr<gcpp::BlobReader2> reader_;
-gcpp::ModelStore2 model_;
+std::unique_ptr<gcpp::BlobReader> reader_;
+gcpp::ModelStore model_;
 };
 } // namespace gcpp

View File

@@ -240,7 +240,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
 char* dt = ctime(&now); // NOLINT
 char cpu100[100] = "unknown";
 (void)hwy::platform::GetCpuString(cpu100);
-const ThreadingContext2& ctx = ThreadingContext2::Get();
+const ThreadingContext& ctx = ThreadingContext::Get();
 fprintf(stderr,
 "Date & Time : %s" // dt includes \n

View File

@@ -50,7 +50,7 @@ class GemmaEnv {
 GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
 const InferenceArgs& inference);
 // Avoid memory leaks in test.
-~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); }
+~GemmaEnv() { ThreadingContext::ThreadHostileInvalidate(); }
 MatMulEnv& Env() { return env_; }

View File

@@ -72,7 +72,7 @@ struct Activations {
 size_t cache_pos_size = 0;
 void Allocate(size_t batch_size, MatMulEnv* env) {
-const Allocator2& allocator = env->ctx.allocator;
+const Allocator& allocator = env->ctx.allocator;
 post_qk = layer_config.post_qk;
 const size_t model_dim = weights_config.model_dim;

View File

@@ -561,7 +561,7 @@ class GemmaAttention {
 const LayerWeightsPtrs<T>& layer_weights_;
 const hwy::Divisor& div_seq_len_;
 const KVCaches& kv_caches_;
-const Allocator2& allocator_;
+const Allocator& allocator_;
 hwy::ThreadPool& pool_;
 };
@@ -749,7 +749,7 @@ class VitAttention {
 Activations& activations_;
 const LayerWeightsPtrs<T>& layer_weights_;
 const LayerConfig& layer_config_;
-const Allocator2& allocator_;
+const Allocator& allocator_;
 hwy::ThreadPool& pool_;
 };
@@ -789,7 +789,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
 const auto x =
 ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
-const Allocator2& allocator = activations.env->ctx.allocator;
+const Allocator& allocator = activations.env->ctx.allocator;
 auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
 auto multiplier = RowPtrFromBatch(allocator, activations.C2);
 auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
@@ -847,7 +847,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
 const auto x =
 ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
-const Allocator2& allocator = activations.env->ctx.allocator;
+const Allocator& allocator = activations.env->ctx.allocator;
 auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
 auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
@@ -1416,7 +1416,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
 //
 // `kv_caches` is for the batch, size must match `queries_prompt`.
 template <typename T>
-void GenerateT(const ModelStore2& model, const ModelWeightsPtrs<T>& weights,
+void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
 Activations& activations, const RuntimeConfig& runtime_config,
 const QueriesPromptTokens& queries_prompt,
 const QueriesPos& queries_pos_in,
@@ -1508,7 +1508,7 @@ void GenerateT(const ModelStore2& model, const ModelWeightsPtrs<T>& weights,
 }
 template <typename T>
-void GenerateSingleT(const ModelStore2& model,
+void GenerateSingleT(const ModelStore& model,
 const ModelWeightsPtrs<T>& weights,
 const RuntimeConfig& runtime_config,
 const PromptTokens& prompt, size_t pos, size_t prefix_end,
@@ -1532,7 +1532,7 @@ void GenerateSingleT(const ModelStore2& model,
 }
 template <typename T>
-void GenerateBatchT(const ModelStore2& model,
+void GenerateBatchT(const ModelStore& model,
 const ModelWeightsPtrs<T>& weights,
 const RuntimeConfig& runtime_config,
 const QueriesPromptTokens& queries_prompt,
@@ -1573,7 +1573,7 @@ void GenerateBatchT(const ModelStore2& model,
 }
 template <typename T>
-void GenerateImageTokensT(const ModelStore2& model,
+void GenerateImageTokensT(const ModelStore& model,
 const ModelWeightsPtrs<T>& weights,
 const RuntimeConfig& runtime_config,
 const Image& image, ImageTokens& image_tokens,
@@ -1599,7 +1599,7 @@ void GenerateImageTokensT(const ModelStore2& model,
 // These are extern functions defined by instantiations/*.cc, which include this
 // 'header' after defining `GEMMA_TYPE`.
 void GenerateSingle( // NOLINT(misc-definitions-in-headers)
-const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
+const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
 const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
 size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
 TimingInfo& timing_info) {
@@ -1609,7 +1609,7 @@ void GenerateSingle( // NOLINT(misc-definitions-in-headers)
 }
 void GenerateBatch( // NOLINT(misc-definitions-in-headers)
-const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
+const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
 const RuntimeConfig& runtime_config,
 const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
 const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
@@ -1620,7 +1620,7 @@ void GenerateBatch( // NOLINT(misc-definitions-in-headers)
 }
 void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
-const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
+const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
 const RuntimeConfig& runtime_config, const Image& image,
 ImageTokens& image_tokens, MatMulEnv* env) {
 HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>)

View File

@@ -47,13 +47,13 @@ namespace gcpp {
 MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
 // Placeholder for internal init, do not modify.
-ThreadingContext2::SetArgs(threading_args);
-return MatMulEnv(ThreadingContext2::Get());
+ThreadingContext::SetArgs(threading_args);
+return MatMulEnv(ThreadingContext::Get());
 }
 Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
 : env_(env),
-reader_(BlobReader2::Make(loader.weights, loader.map)),
+reader_(BlobReader::Make(loader.weights, loader.map)),
 model_(*reader_, loader.tokenizer, loader.wrapping),
 weights_(model_.Config().weight),
 chat_template_(model_.Tokenizer(), model_.Config().model) {
@@ -74,7 +74,7 @@ Gemma::Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer,
 Gemma::~Gemma() = default;
 void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
-BlobWriter2 writer;
+BlobWriter writer;
 const std::vector<uint32_t> serialized_mat_ptrs =
 weights_.AddTensorDataToWriter(writer);
 WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
@@ -90,17 +90,17 @@ void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
 // instead of `WeightsPtrs<T>`.
 #define GEMMA_DECLARE(WEIGHT_TYPE) \
 extern void GenerateSingle( \
-const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
+const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
 const RuntimeConfig& runtime_config, const PromptTokens& prompt, \
 size_t pos, size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, \
 TimingInfo& timing_info); \
 extern void GenerateBatch( \
-const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
+const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
 const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
 const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
 const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \
 extern void GenerateImageTokens( \
-const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
+const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
 const RuntimeConfig& runtime_config, const Image& image, \
 ImageTokens& image_tokens, MatMulEnv* env);
 GEMMA_DECLARE(float)

View File

@@ -160,8 +160,8 @@ class Gemma {
 private:
 MatMulEnv& env_;
-std::unique_ptr<BlobReader2> reader_; // null for second ctor
-ModelStore2 model_;
+std::unique_ptr<BlobReader> reader_; // null for second ctor
+ModelStore model_;
 WeightsOwner weights_;
 GemmaChatTemplate chat_template_;
 };

View File

@@ -56,7 +56,7 @@ static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {
 // Returns the serialized tokenizer (std::string is required for proto).
 // Reads it from a blob or from a separate file if pre-2025.
-static std::string ReadTokenizer(BlobReader2& reader,
+static std::string ReadTokenizer(BlobReader& reader,
 const Path& tokenizer_path) {
 std::string tokenizer;
 // Check prevents `CallWithSpan` from printing a warning.
@@ -107,7 +107,7 @@ class TypePrefix {
 }
 }
-TypePrefix(const KeyVec& keys, const BlobReader2& reader) {
+TypePrefix(const KeyVec& keys, const BlobReader& reader) {
 for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
 const std::string& key = keys[key_idx];
 const Type type = TypeFromChar(key[0]);
@@ -200,7 +200,7 @@ static int DeduceLayerTypes(const KeyVec& keys) {
 // `wrapping_override` is forwarded from the command line. For pre-2025 files
 // without `ModelConfig`, it is the only way to force PT.
-static ModelConfig ReadOrDeduceConfig(BlobReader2& reader,
+static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
 Tristate wrapping_override) {
 const TypePrefix type_prefix(reader.Keys(), reader);
 Type deduced_weight = Type::kUnknown;
@@ -244,7 +244,7 @@ static ModelConfig ReadOrDeduceConfig(BlobReader2& reader,
 ChooseWrapping(config.model, wrapping_override));
 }
-static std::vector<float> ReadScales(BlobReader2& reader,
+static std::vector<float> ReadScales(BlobReader& reader,
 const ModelConfig& config) {
 std::vector<float> scales;
 // Check first to prevent `CallWithSpan` from printing a warning. This blob is
@@ -260,7 +260,7 @@ static std::vector<float> ReadScales(BlobReader2& reader,
 }
 // Single-file format: reads `MatPtr` from the blob; returns false if not found.
-bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
+bool ModelStore::ReadMatPtrs(BlobReader& reader) {
 // Check first to prevent `CallWithSpan` from printing a warning.
 if (!reader.Find(kMatPtrsName)) return false;
@@ -282,7 +282,7 @@ bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
 // Retrieve actual key index because a writer may have written other
 // blobs before the tensor data.
-const BlobRange2* range = reader.Find(mat.Name());
+const BlobRange* range = reader.Find(mat.Name());
 HWY_ASSERT(range);
 const size_t key_idx = range->key_idx;
 AddMatPtr(key_idx, mat);
@@ -302,7 +302,7 @@ bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
 }
 // Pre-2025 format: synthesizes `MatPtr` from the blob names if `!ReadMatPtrs`.
-void ModelStore2::CreateMatPtrs(BlobReader2& reader) {
+void ModelStore::CreateMatPtrs(BlobReader& reader) {
 const TensorInfoRegistry tensors(config_);
 const KeyVec& keys = reader.Keys();
@@ -329,7 +329,7 @@ void ModelStore2::CreateMatPtrs(BlobReader2& reader) {
 HWY_ASSERT(mat_ptrs_.size() == key_idx_.size());
 }
-ModelStore2::ModelStore2(BlobReader2& reader, const Path& tokenizer_path,
+ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
 Tristate wrapping)
 : config_(ReadOrDeduceConfig(reader, wrapping)),
 tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
@@ -348,12 +348,12 @@ ModelStore2::ModelStore2(BlobReader2& reader, const Path& tokenizer_path,
 HWY_ASSERT(key_idx_.size() == mat_ptrs_.size());
 }
-ModelStore2::~ModelStore2() {
+ModelStore::~ModelStore() {
 // Sanity check: ensure all scales were consumed.
 HWY_ASSERT(scales_consumed_ == scales_.size());
 }
-const MatPtr* ModelStore2::FindMat(const char* name) const {
+const MatPtr* ModelStore::FindMat(const char* name) const {
 auto it = mat_idx_for_name_.find(name);
 if (it == mat_idx_for_name_.end()) return nullptr;
 const size_t mat_idx = it->second;
@@ -362,7 +362,7 @@ const MatPtr* ModelStore2::FindMat(const char* name) const {
 return file_mat;
 }
-bool ModelStore2::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
+bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
 const MatPtr* file_mat = FindMat(mat.Name());
 if (!file_mat) return false;
 if (file_mat->Rows() != mat.Rows() || file_mat->Cols() != mat.Cols()) {
@@ -390,14 +390,14 @@ bool ModelStore2::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
 }
 static void AddBlob(const char* name, const std::vector<uint32_t>& data,
-BlobWriter2& writer) {
+BlobWriter& writer) {
 HWY_ASSERT(!data.empty());
 writer.Add(name, data.data(), data.size() * sizeof(data[0]));
 }
 void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
 const std::vector<uint32_t>& serialized_mat_ptrs,
-BlobWriter2& writer, hwy::ThreadPool& pool,
+BlobWriter& writer, hwy::ThreadPool& pool,
 const Path& path) {
 HWY_ASSERT(config.model != Model::UNKNOWN);
 HWY_ASSERT(config.weight != Type::kUnknown);

View File

@ -48,16 +48,16 @@ namespace gcpp {
// tokenizer in a separate file, encoded tensor type in a prefix of the blob // tokenizer in a separate file, encoded tensor type in a prefix of the blob
// name, and had a blob for tensor scaling factors. We still support reading // name, and had a blob for tensor scaling factors. We still support reading
// both, but only write single-file format. // both, but only write single-file format.
class ModelStore2 { class ModelStore {
public: public:
// Reads from file(s) or aborts on error. The latter two arguments are only // Reads from file(s) or aborts on error. The latter two arguments are only
// used for pre-2025 files. // used for pre-2025 files.
ModelStore2(BlobReader2& reader, const Path& tokenizer_path = Path(), ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(),
Tristate wrapping = Tristate::kDefault); Tristate wrapping = Tristate::kDefault);
// For optimize_test.cc. // For optimize_test.cc.
ModelStore2(const ModelConfig& config, GemmaTokenizer&& tokenizer) ModelStore(const ModelConfig& config, GemmaTokenizer&& tokenizer)
: config_(config), tokenizer_(std::move(tokenizer)) {} : config_(config), tokenizer_(std::move(tokenizer)) {}
~ModelStore2(); ~ModelStore();
const ModelConfig& Config() const { const ModelConfig& Config() const {
HWY_ASSERT(config_.model != Model::UNKNOWN); HWY_ASSERT(config_.model != Model::UNKNOWN);
@ -72,7 +72,7 @@ class ModelStore2 {
// Returns false if `mat` is not available for loading, otherwise updates // Returns false if `mat` is not available for loading, otherwise updates
// `mat` with metadata from the file and sets `key_idx` for use by // `mat` with metadata from the file and sets `key_idx` for use by
// `BlobReader2`. Called via `ReadOrAllocate` in `weights.cc`. // `BlobReader`. Called via `ReadOrAllocate` in `weights.cc`.
bool FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const; bool FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const;
private: private:
@ -83,15 +83,15 @@ class ModelStore2 {
key_idx_.push_back(key_idx); key_idx_.push_back(key_idx);
} }
bool ReadMatPtrs(BlobReader2& reader); bool ReadMatPtrs(BlobReader& reader);
void CreateMatPtrs(BlobReader2& reader); // Aborts on error. void CreateMatPtrs(BlobReader& reader); // Aborts on error.
ModelConfig config_; ModelConfig config_;
GemmaTokenizer tokenizer_; GemmaTokenizer tokenizer_;
// All `MatPtr` present in the `BlobStore`, see `ReadMatPtrs`/`CreateMatPtrs`. // All `MatPtr` present in the `BlobStore`, see `ReadMatPtrs`/`CreateMatPtrs`.
std::vector<MatPtr> mat_ptrs_; std::vector<MatPtr> mat_ptrs_;
// For each of `mat_ptrs_`, the index within `BlobReader2::Keys()`. This is // For each of `mat_ptrs_`, the index within `BlobReader::Keys()`. This is
// not necessarily iota because some blobs are not tensors, and callers may // not necessarily iota because some blobs are not tensors, and callers may
// have added blobs before ours. // have added blobs before ours.
std::vector<size_t> key_idx_; std::vector<size_t> key_idx_;
@ -108,7 +108,7 @@ class ModelStore2 {
// produces a single BlobStore file holding everything required for inference. // produces a single BlobStore file holding everything required for inference.
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer, void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs, const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter2& writer, hwy::ThreadPool& pool, BlobWriter& writer, hwy::ThreadPool& pool,
const Path& path); const Path& path);
} // namespace gcpp } // namespace gcpp
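For context on how the renamed pieces fit together, a minimal read-path sketch follows. It uses only declarations visible in this commit (the `ModelStore` constructor and `Config()` above, the headers included in weights.h below); the `BlobReader` constructor taking a `Path` is an assumption, and the tensors themselves would then be loaded via `WeightsOwner::ReadOrAllocate(model, reader, pool)`, shown later in this diff.

#include "compression/blob_store.h"  // BlobReader
#include "gemma/model_store.h"       // ModelStore

// Sketch (hypothetical call site): open a single-file BlobStore and read its config.
void LoadConfigSketch(const gcpp::Path& weights_path) {
  gcpp::BlobReader reader(weights_path);   // constructor assumed, not shown in this diff
  gcpp::ModelStore model(reader);          // single-file format: no tokenizer path needed
  const gcpp::ModelConfig& config = model.Config();
  (void)config;
}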

View File

@ -84,8 +84,8 @@ void LayerWeightsPtrs<NuqStream>::Reshape() {
} }
// Aborts on error. // Aborts on error.
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader, static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange2>& ranges, const std::vector<BlobRange>& ranges,
MatOwners& mat_owners, const MatPadding padding, MatOwners& mat_owners, const MatPadding padding,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
HWY_ASSERT(mats.size() == ranges.size()); HWY_ASSERT(mats.size() == ranges.size());
@ -121,7 +121,7 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes(); const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
uint8_t* row = mats[i]->RowT<uint8_t>(0); uint8_t* row = mats[i]->RowT<uint8_t>(0);
for (size_t r = 0; r < mats[i]->Rows(); ++r) { for (size_t r = 0; r < mats[i]->Rows(); ++r) {
reader.Enqueue(BlobRange2{.offset = offset, reader.Enqueue(BlobRange{.offset = offset,
.bytes = file_bytes_per_row, .bytes = file_bytes_per_row,
.key_idx = ranges[i].key_idx}, .key_idx = ranges[i].key_idx},
row); row);
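To make the units in this loop explicit (editorial note: the numbers below are hypothetical, and the advancement of `offset`/`row` is inferred from the variable names because the hunk ends here):

// Hypothetical float tensor with 1000 columns:
//   file_bytes_per_row = 1000 * sizeof(float) = 4000 bytes   (rows are packed in the file)
//   mem_stride_bytes   = Stride() * sizeof(float) >= 4000     (rows are padded in memory)
// Each Enqueue copies file_bytes_per_row bytes into the current `row`; presumably
// `offset` then advances by file_bytes_per_row while `row` advances by
// mem_stride_bytes, so the per-row padding is skipped rather than read from the file.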
@ -134,11 +134,11 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
reader.ReadAll(pool); reader.ReadAll(pool);
} }
void WeightsOwner::ReadOrAllocate(const ModelStore2& model, BlobReader2& reader, void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool) { hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from. // List of tensors to read/map, and where from.
std::vector<MatPtr*> mats; std::vector<MatPtr*> mats;
std::vector<BlobRange2> ranges; std::vector<BlobRange> ranges;
// Padding is inserted when reading row by row, except for NUQ tensors. // Padding is inserted when reading row by row, except for NUQ tensors.
const MatPadding padding = MatPadding::kOdd; const MatPadding padding = MatPadding::kOdd;
@ -244,7 +244,7 @@ void WeightsOwner::Reshape(hwy::ThreadPool& pool) {
} }
std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter( std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
BlobWriter2& writer) const { BlobWriter& writer) const {
std::vector<uint32_t> serialized_mat_ptrs; std::vector<uint32_t> serialized_mat_ptrs;
CallT([&](const auto& weights) { CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) { weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {

View File

@ -25,7 +25,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "compression/blob_store.h" // BlobWriter2 #include "compression/blob_store.h" // BlobWriter
#include "compression/shared.h" // IsF32 #include "compression/shared.h" // IsF32
#include "gemma/configs.h" // ModelConfig #include "gemma/configs.h" // ModelConfig
#include "gemma/model_store.h" // ModelStore #include "gemma/model_store.h" // ModelStore
@ -519,7 +519,7 @@ class WeightsOwner {
// Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`, // Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`,
// allocates memory and reshapes. Aborts on error. // allocates memory and reshapes. Aborts on error.
void ReadOrAllocate(const ModelStore2& model, BlobReader2& reader, void ReadOrAllocate(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool); hwy::ThreadPool& pool);
// Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically // Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
@ -541,7 +541,7 @@ class WeightsOwner {
// For writers: // For writers:
// Adds one blob for each tensor's data and returns all serialized MatPtr. // Adds one blob for each tensor's data and returns all serialized MatPtr.
std::vector<uint32_t> AddTensorDataToWriter(BlobWriter2& writer) const; std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;
// For backprop/: // For backprop/:
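The write side pairs `AddTensorDataToWriter` with `WriteSingleFile` from model_store.h above. A minimal sketch, assuming a `BlobWriter` that can be default-constructed (its interface is not shown in this diff):

// Sketch (hypothetical call site): export weights to a single BlobStore file.
void ExportSketch(const gcpp::WeightsOwner& weights, const gcpp::ModelConfig& config,
                  const gcpp::GemmaTokenizer& tokenizer, hwy::ThreadPool& pool,
                  const gcpp::Path& out_path) {
  gcpp::BlobWriter writer;  // construction assumed
  const std::vector<uint32_t> serialized_mat_ptrs =
      weights.AddTensorDataToWriter(writer);
  gcpp::WriteSingleFile(config, tokenizer, serialized_mat_ptrs, writer, pool, out_path);
}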

View File

@ -79,7 +79,7 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// M = A rows, K = A cols, N = C cols. // M = A rows, K = A cols, N = C cols.
template <typename TA, typename TB = TA, typename TC = float> template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
const Allocator2& allocator = env.ctx.allocator; const Allocator& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool(0); hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
if (env.print_config || env.print_measurement) { if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n"); fprintf(stderr, "\n");
@ -160,7 +160,7 @@ void BenchAllMatMul() {
return; return;
} }
ThreadingContext2& ctx = ThreadingContext2::Get(); ThreadingContext& ctx = ThreadingContext::Get();
fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(), fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(),
ctx.pools.PinString()); ctx.pools.PinString());

View File

@ -999,7 +999,7 @@ struct TestShortDotsT {
const size_t N = hn::Lanes(d); const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df; // for CallDot const hn::ScalableTag<float> df; // for CallDot
const Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator; const Allocator& allocator = gcpp::ThreadingContext::Get().allocator;
CompressWorkingSet work; CompressWorkingSet work;
std::mt19937 rng; std::mt19937 rng;
rng.seed(12345); rng.seed(12345);
@ -1099,14 +1099,14 @@ void TestAllDot() {
constexpr size_t kMaxWorkers = 15; constexpr size_t kMaxWorkers = 15;
// Reset with cap on workers because we only support `kMaxWorkers`. // Reset with cap on workers because we only support `kMaxWorkers`.
ThreadingContext2::ThreadHostileInvalidate(); ThreadingContext::ThreadHostileInvalidate();
ThreadingArgs threading_args; ThreadingArgs threading_args;
threading_args.max_packages = 1; threading_args.max_packages = 1;
threading_args.max_clusters = 1; threading_args.max_clusters = 1;
threading_args.max_lps = kMaxWorkers - 1; threading_args.max_lps = kMaxWorkers - 1;
ThreadingContext2::SetArgs(threading_args); ThreadingContext::SetArgs(threading_args);
ThreadingContext2& ctx = ThreadingContext2::Get(); ThreadingContext& ctx = ThreadingContext::Get();
const Allocator2& allocator = ctx.allocator; const Allocator& allocator = ctx.allocator;
{ // ensure no profiler zones are active { // ensure no profiler zones are active
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;

View File

@ -909,7 +909,7 @@ class MMPerPackage {
static constexpr size_t B_stride_max_ = static constexpr size_t B_stride_max_ =
MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC); MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
static constexpr size_t B_storage_max_ = static constexpr size_t B_storage_max_ =
kNR * B_stride_max_ + Allocator2::MaxQuantum<BF16>(); kNR * B_stride_max_ + Allocator::MaxQuantum<BF16>();
// Granularity of `ForNP`. B rows produce C columns, so we // Granularity of `ForNP`. B rows produce C columns, so we
// want a multiple of the line size to prevent false sharing. // want a multiple of the line size to prevent false sharing.
@ -1175,7 +1175,7 @@ class MMPerPackage {
// Autotuning wrapper for `DoDecompressA`. // Autotuning wrapper for `DoDecompressA`.
template <typename TA> template <typename TA>
HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const { HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const {
const Allocator2& allocator = args_.env->ctx.allocator; const Allocator& allocator = args_.env->ctx.allocator;
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_]; MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
// If already BF16, maybe return a view: // If already BF16, maybe return a view:
if constexpr (hwy::IsSame<TA, BF16>()) { if constexpr (hwy::IsSame<TA, BF16>()) {
@ -1316,7 +1316,7 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B, HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env, const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<TC>& C) { const RowPtr<TC>& C) {
const Allocator2& allocator = env.ctx.allocator; const Allocator& allocator = env.ctx.allocator;
const size_t M = A.Extents().rows; const size_t M = A.Extents().rows;
const size_t K = A.Extents().cols; const size_t K = A.Extents().cols;
const size_t N = B.Extents().rows; const size_t N = B.Extents().rows;
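A call-site sketch for this signature; `ConstMatFromWeights` is declared later in this header, and the `MatPtrT` operands are hypothetical:

// Sketch: C = A * B (no bias), using the renamed MatMulEnv/ThreadingContext.
MMPerKey* MatMulSketch(const MatPtrT<BF16>& a, const MatPtrT<BF16>& b,
                       const RowPtr<float>& C, MatMulEnv& env) {
  const ConstMat<BF16> A = ConstMatFromWeights(a);
  const ConstMat<BF16> B = ConstMatFromWeights(b);
  return MatMul(A, B, /*add=*/nullptr, env, C);
}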

View File

@ -60,7 +60,7 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
// and holds most of their arguments in member variables. // and holds most of their arguments in member variables.
class GenerateCandidates { class GenerateCandidates {
public: public:
GenerateCandidates(const Allocator2& allocator, size_t M, size_t K, size_t N, GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr, size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config) const IndexRangePartition& ranges_np, bool print_config)
: allocator_(allocator), : allocator_(allocator),
@ -352,7 +352,7 @@ class GenerateCandidates {
} }
} }
const Allocator2& allocator_; const Allocator& allocator_;
const size_t M_; const size_t M_;
const size_t K_; const size_t K_;
const size_t N_; const size_t N_;
@ -372,7 +372,7 @@ class GenerateCandidates {
} // namespace } // namespace
// Facade to avoid exposing `GenerateCandidates` in the header. // Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M, std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC, size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, const IndexRangePartition& ranges_np,
@ -384,7 +384,7 @@ std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package // memory accesses or false sharing, unless there are insufficient per-package
// rows for that. // rows for that.
static size_t NPMultiple(const Allocator2& allocator, size_t N, static size_t NPMultiple(const Allocator& allocator, size_t N,
size_t sizeof_TC, size_t nr, size_t num_packages) { size_t sizeof_TC, size_t nr, size_t num_packages) {
size_t np_multiple = allocator.QuantumBytes() / sizeof_TC; size_t np_multiple = allocator.QuantumBytes() / sizeof_TC;
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
@ -417,7 +417,7 @@ IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages)); NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
} }
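A worked instance of the starting value chosen above (editorial; the 4 KiB quantum is the "typical" figure from the comment, not a guarantee):

// allocator.QuantumBytes() == 4096 (typical when binding)  and  sizeof_TC == 4 (float C)
// => np_multiple = 4096 / 4 = 1024, i.e. B rows / C columns are assigned to packages
//    in multiples of 1024 unless there are too few rows per package.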
MatMulEnv::MatMulEnv(ThreadingContext2& ctx) MatMulEnv::MatMulEnv(ThreadingContext& ctx)
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) { : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
char cpu100[100]; char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100); have_timer_stop = hwy::platform::HaveTimerStop(cpu100);

View File

@ -50,7 +50,7 @@ class MMParallel {
static constexpr size_t kMaxPackages = 4; static constexpr size_t kMaxPackages = 4;
// `ctx` must outlive this object. // `ctx` must outlive this object.
MMParallel(ThreadingContext2& ctx) : ctx_(ctx) { MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages); HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
} }
@ -164,11 +164,11 @@ class MMParallel {
} }
private: private:
ThreadingContext2& ctx_; ThreadingContext& ctx_;
}; };
template <typename TC> // BF16/float for C, double for partial template <typename TC> // BF16/float for C, double for partial
void BindC(const Allocator2& allocator, size_t M, const RowPtr<TC>& C, void BindC(const Allocator& allocator, size_t M, const RowPtr<TC>& C,
MMParallel& parallel) { MMParallel& parallel) {
if (!allocator.ShouldBind()) return; if (!allocator.ShouldBind()) return;
@ -207,7 +207,7 @@ class MMStorage {
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
static constexpr size_t kMaxKC = 8 * 1024; static constexpr size_t kMaxKC = 8 * 1024;
MMStorage(const Allocator2& allocator, MMParallel& parallel) MMStorage(const Allocator& allocator, MMParallel& parallel)
// Per-worker copies of `partial` would be wasteful. We instead allocate // Per-worker copies of `partial` would be wasteful. We instead allocate
// one instance of the maximum matrix extents because threads write at // one instance of the maximum matrix extents because threads write at
// false-sharing-free granularity. // false-sharing-free granularity.
@ -236,7 +236,7 @@ class MMStorage {
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is // Returns per-package matrix view. Non-const so that `RowVectorBatch` is
// non-const, because `RowPtr` requires a non-const pointer. // non-const, because `RowPtr` requires a non-const pointer.
RowPtrBF A(const Allocator2& allocator, size_t pkg_idx, RowPtrBF A(const Allocator& allocator, size_t pkg_idx,
const Extents2D& extents) { const Extents2D& extents) {
HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK); HWY_DASSERT(extents.cols <= kMaxK);
@ -430,7 +430,7 @@ class MMConfig {
static_assert(sizeof(MMConfig) == 32); // for faster indexing static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop) #pragma pack(pop)
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M, std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC, size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, const IndexRangePartition& ranges_np,
@ -561,7 +561,7 @@ class MMKeys {
} }
// Must only be called if not already present in `Keys()`. // Must only be called if not already present in `Keys()`.
void Append(Key key, const Allocator2& allocator) { void Append(Key key, const Allocator& allocator) {
// Dynamic allocation because the test checks many more dimensions than // Dynamic allocation because the test checks many more dimensions than
// would be reasonable to pre-allocate. DIY for alignment and padding. // would be reasonable to pre-allocate. DIY for alignment and padding.
if (HWY_UNLIKELY(num_unique_ >= capacity_)) { if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
@ -608,9 +608,9 @@ struct MMPerKey {
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
// `MatMulEnv`. // `MatMulEnv`.
struct MatMulEnv { struct MatMulEnv {
explicit MatMulEnv(ThreadingContext2& ctx); explicit MatMulEnv(ThreadingContext& ctx);
ThreadingContext2& ctx; ThreadingContext& ctx;
bool have_timer_stop = false; bool have_timer_stop = false;
// Whether `MMCandidates()` should print the set of parameters. // Whether `MMCandidates()` should print the set of parameters.
@ -753,7 +753,7 @@ ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
} }
template <typename TB> template <typename TB>
void BindB(const Allocator2& allocator, size_t N, size_t sizeof_TC, void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
const ConstMat<TB>& B, MMParallel& parallel) { const ConstMat<TB>& B, MMParallel& parallel) {
if (!allocator.ShouldBind()) return; if (!allocator.ShouldBind()) return;

View File

@ -86,7 +86,7 @@ float MaxAbs(const RowVectorBatch<float>& a) {
template <typename TA, typename TB, typename TC> template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B, void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) { const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
const Allocator2& allocator = ThreadingContext2::Get().allocator; const Allocator& allocator = ThreadingContext::Get().allocator;
const hn::ScalableTag<float> df; const hn::ScalableTag<float> df;
const size_t cols = A.extents.cols; const size_t cols = A.extents.cols;
const size_t B_rows = B.extents.rows; const size_t B_rows = B.extents.rows;
@ -210,7 +210,7 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
template <typename TA, typename TB = TA, typename TC = float> template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env, int line) { MatMulEnv& env, int line) {
const Allocator2& allocator = env.ctx.allocator; const Allocator& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool(); hwy::ThreadPool& pool = env.ctx.pools.Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(), rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
@ -259,12 +259,12 @@ void TestTiny() {
if (HWY_TARGET != first_target) return; if (HWY_TARGET != first_target) return;
for (size_t max_packages : {1, 2}) { for (size_t max_packages : {1, 2}) {
ThreadingContext2::ThreadHostileInvalidate(); ThreadingContext::ThreadHostileInvalidate();
ThreadingArgs threading_args; ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue; threading_args.bind = Tristate::kTrue;
threading_args.max_packages = max_packages; threading_args.max_packages = max_packages;
ThreadingContext2::SetArgs(threading_args); ThreadingContext::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get()); MatMulEnv env(ThreadingContext::Get());
NestedPools& pools = env.ctx.pools; NestedPools& pools = env.ctx.pools;
#if GEMMA_DISABLE_TOPOLOGY #if GEMMA_DISABLE_TOPOLOGY
@ -296,11 +296,11 @@ void TestAllMatMul() {
return; return;
} }
ThreadingContext2::ThreadHostileInvalidate(); ThreadingContext::ThreadHostileInvalidate();
ThreadingArgs threading_args; ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue; threading_args.bind = Tristate::kTrue;
ThreadingContext2::SetArgs(threading_args); ThreadingContext::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get()); MatMulEnv env(ThreadingContext::Get());
NestedPools& pools = env.ctx.pools; NestedPools& pools = env.ctx.pools;
pools.MaybeStartSpinning(threading_args.spin); pools.MaybeStartSpinning(threading_args.spin);

View File

@ -808,7 +808,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
// Each output row is the average of a 4x4 block of input rows // Each output row is the average of a 4x4 block of input rows
template <typename T> template <typename T>
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) { RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
const Allocator2& allocator = ThreadingContext2::Get().allocator; const Allocator& allocator = ThreadingContext::Get().allocator;
const Extents2D extents = input.Extents(); const Extents2D extents = input.Extents();
// Input validation // Input validation
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows

View File

@ -27,7 +27,7 @@
namespace gcpp { namespace gcpp {
static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale( static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
const Allocator2& allocator, size_t qkv_dim, bool half_rope, const Allocator& allocator, size_t qkv_dim, bool half_rope,
double base_frequency = 10000.0) { double base_frequency = 10000.0) {
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim; const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2)); RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2));
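The hunk ends before the fill loop; presumably it computes the standard RoPE inverse timescales, roughly as sketched below (not the verbatim body; requires <cmath>):

  for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
    const double exponent = static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
    inv_timescale.All()[dim] = static_cast<float>(1.0 / std::pow(base_frequency, exponent));
  }
  return inv_timescale;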

View File

@ -386,7 +386,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
} }
void TestRopeAndMulBy() { void TestRopeAndMulBy() {
const Allocator2& allocator = ThreadingContext2::Get().allocator; const Allocator& allocator = ThreadingContext::Get().allocator;
ModelConfig config(Model::GEMMA2_9B, Type::kSFP, ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B)); ChooseWrapping(Model::GEMMA2_9B));

View File

@ -47,7 +47,7 @@ class PaliGemmaTest : public ::testing::Test {
void PaliGemmaTest::InitVit(const std::string& path) { void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetGemma(), nullptr); ASSERT_NE(s_env->GetGemma(), nullptr);
const Allocator2& allocator = s_env->Env().ctx.allocator; const Allocator& allocator = s_env->Env().ctx.allocator;
Gemma& gemma = *(s_env->GetGemma()); Gemma& gemma = *(s_env->GetGemma());
image_tokens_ = ImageTokens( image_tokens_ = ImageTokens(
allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len, allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,

View File

@ -168,7 +168,7 @@ class GemmaModel {
void SetImage(const py::array_t<float, py::array::c_style | void SetImage(const py::array_t<float, py::array::c_style |
py::array::forcecast>& image) { py::array::forcecast>& image) {
const gcpp::Gemma& gemma = *gemma_.GetGemma(); const gcpp::Gemma& gemma = *gemma_.GetGemma();
const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator; const gcpp::Allocator& allocator = gemma_.Env().ctx.allocator;
if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA && if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) { gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
throw std::invalid_argument("Not a PaliGemma model."); throw std::invalid_argument("Not a PaliGemma model.");

View File

@ -130,7 +130,7 @@ size_t DetectTotalMiB(size_t page_bytes) {
} // namespace } // namespace
Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) { Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
line_bytes_ = DetectLineBytes(); line_bytes_ = DetectLineBytes();
vector_bytes_ = hwy::VectorBytes(); vector_bytes_ = hwy::VectorBytes();
step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
@ -180,7 +180,7 @@ Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
quantum_step_mask_ = quantum_bytes_ / step_bytes_ - 1; quantum_step_mask_ = quantum_bytes_ / step_bytes_ - 1;
} }
size_t Allocator2::FreeMiB() const { size_t Allocator::FreeMiB() const {
#if HWY_OS_LINUX #if HWY_OS_LINUX
const long ret = sysconf(_SC_AVPHYS_PAGES); // NOLINT(runtime/int) const long ret = sysconf(_SC_AVPHYS_PAGES); // NOLINT(runtime/int)
HWY_ASSERT(ret != -1); HWY_ASSERT(ret != -1);
@ -201,7 +201,7 @@ size_t Allocator2::FreeMiB() const {
#endif #endif
} }
AlignedPtr2<uint8_t[]> Allocator2::AllocBytes(size_t bytes) const { AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
// If we are not binding, the Highway allocator is cheaper than `mmap`, and // If we are not binding, the Highway allocator is cheaper than `mmap`, and
// defends against 2K aliasing. // defends against 2K aliasing.
if (!should_bind_) { if (!should_bind_) {
@ -296,7 +296,7 @@ size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
return num_busy; return num_busy;
} }
bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const { bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) const {
HWY_DASSERT(should_bind_); HWY_DASSERT(should_bind_);
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough" constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
@ -353,7 +353,7 @@ bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
} }
#else #else
bool Allocator2::BindMemory(void*, size_t, size_t) const { return false; } bool Allocator::BindMemory(void*, size_t, size_t) const { return false; }
#endif // GEMMA_BIND && HWY_OS_LINUX #endif // GEMMA_BIND && HWY_OS_LINUX
} // namespace gcpp } // namespace gcpp

View File

@ -78,14 +78,14 @@ template <typename T>
using AlignedClassPtr2 = std::unique_ptr<T, DeleterDtor2>; using AlignedClassPtr2 = std::unique_ptr<T, DeleterDtor2>;
// Both allocation, binding, and row accessors depend on the sizes of memory // Both allocation, binding, and row accessors depend on the sizes of memory
// pages and cache lines. To avoid having to pass `Allocator2&` everywhere, we // pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
// wrap this in a singleton. A monostate requires explicit initialization, // wrap this in a singleton. A monostate requires explicit initialization,
// which we prefer to avoid because there are many main() functions. // which we prefer to avoid because there are many main() functions.
class Allocator2 { class Allocator {
public: public:
// Must be called at least once before any other function. Not thread-safe, // Must be called at least once before any other function. Not thread-safe,
// hence only call this from the main thread. // hence only call this from the main thread.
Allocator2(const BoundedTopology& topology, bool enable_bind); Allocator(const BoundedTopology& topology, bool enable_bind);
// Bytes per cache line, or a reasonable guess if unknown. Used to choose // Bytes per cache line, or a reasonable guess if unknown. Used to choose
// ranges such that there will be no false sharing. // ranges such that there will be no false sharing.
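Callers do not construct `Allocator` directly; they reach it through the `ThreadingContext` singleton, as the tests and MatMul code elsewhere in this commit do. A minimal sketch (assuming `AlignedPtr2` is, like `AlignedClassPtr2` above, a unique_ptr alias, and the NUMA node index is hypothetical):

// Sketch: typical access pattern after the rename.
const Allocator& allocator = ThreadingContext::Get().allocator;
AlignedPtr2<uint8_t[]> buf = allocator.AllocBytes(allocator.QuantumBytes());
if (allocator.ShouldBind()) {
  allocator.BindMemory(buf.get(), allocator.QuantumBytes(), /*node=*/0);
}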

View File

@ -102,7 +102,7 @@ static size_t RoundUpToOddLines(size_t num, size_t line_bytes,
return padded_num; return padded_num;
} }
static size_t Stride(const Allocator2& allocator, const MatPtr& mat, static size_t Stride(const Allocator& allocator, const MatPtr& mat,
MatPadding padding) { MatPadding padding) {
switch (padding) { switch (padding) {
case MatPadding::kPacked: case MatPadding::kPacked:
@ -119,7 +119,7 @@ static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) { void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked; if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
const Allocator2& allocator = ThreadingContext2::Get().allocator; const Allocator& allocator = ThreadingContext::Get().allocator;
const size_t stride = Stride(allocator, mat, padding); const size_t stride = Stride(allocator, mat, padding);
const size_t num = mat.Rows() * stride; const size_t num = mat.Rows() * stride;
// `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding` // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`

View File

@ -282,11 +282,11 @@ void ZeroInit(MatPtr& mat);
void RandInit(MatPtr& mat, float stddev, std::mt19937& gen); void RandInit(MatPtr& mat, float stddev, std::mt19937& gen);
// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If // Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
// `Allocator2::ShouldBind()`, `Allocator2::QuantumBytes()` is typically 4KiB. // `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is typically 4KiB.
// To avoid remote accesses, we would thus pad each row to that, which results // To avoid remote accesses, we would thus pad each row to that, which results
// in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that // in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that
// by pulling rows forward by a cyclic offset, which is still a multiple of the // by pulling rows forward by a cyclic offset, which is still a multiple of the
// cache line size. This requires an additional `Allocator2::QuantumBytes()` of // cache line size. This requires an additional `Allocator::QuantumBytes()` of
// padding after also rounding up to that, which considerably increases size for // padding after also rounding up to that, which considerably increases size for
// tall and skinny tensors. // tall and skinny tensors.
static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) { static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
@ -295,7 +295,7 @@ static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
// Constexpr version (upper bound) for allocating storage in MatMul. // Constexpr version (upper bound) for allocating storage in MatMul.
template <typename T> template <typename T>
constexpr size_t MaxStrideForCyclicOffsets(size_t cols) { constexpr size_t MaxStrideForCyclicOffsets(size_t cols) {
constexpr size_t quantum = Allocator2::MaxQuantum<T>(); constexpr size_t quantum = Allocator::MaxQuantum<T>();
return hwy::RoundUpTo(cols, quantum) + quantum; return hwy::RoundUpTo(cols, quantum) + quantum;
} }
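A worked instance of this bound (editorial; the 4 KiB quantum is the typical value cited earlier, so `MaxQuantum<BF16>()` == 2048 elements is an assumption):

// MaxStrideForCyclicOffsets<BF16>(/*cols=*/3000)
//   = RoundUpTo(3000, 2048) + 2048 = 4096 + 2048 = 6144 elements per row,
// i.e. one extra quantum per row so rows can be pulled forward by a cyclic,
// line-aligned offset without overrunning the allocation.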
@ -387,7 +387,7 @@ MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
template <typename T> template <typename T>
class RowPtr { class RowPtr {
public: public:
RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols, RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols,
size_t stride) size_t stride)
: row0_(row0), : row0_(row0),
stride_(stride), stride_(stride),
@ -414,7 +414,7 @@ class RowPtr {
} }
} }
RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols) RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols)
: RowPtr(allocator, row0, cols, cols) {} : RowPtr(allocator, row0, cols, cols) {}
T* HWY_RESTRICT Row(size_t r) const { T* HWY_RESTRICT Row(size_t r) const {
@ -480,7 +480,7 @@ class RowVectorBatch {
// we default to tightly packed rows (`stride = cols`). // we default to tightly packed rows (`stride = cols`).
// WARNING: not all call sites support `stride` != cols. // WARNING: not all call sites support `stride` != cols.
// TODO: once they do, remove stride and behave like AllocateAlignedRows here. // TODO: once they do, remove stride and behave like AllocateAlignedRows here.
RowVectorBatch(const Allocator2& allocator, Extents2D extents, RowVectorBatch(const Allocator& allocator, Extents2D extents,
size_t stride = 0) size_t stride = 0)
: extents_(extents) { : extents_(extents) {
if (stride == 0) { if (stride == 0) {
@ -529,14 +529,14 @@ class RowVectorBatch {
}; };
template <typename T> template <typename T>
RowPtr<T> RowPtrFromBatch(const Allocator2& allocator, RowPtr<T> RowPtrFromBatch(const Allocator& allocator,
RowVectorBatch<T>& row_vectors) { RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(), return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
row_vectors.Stride()); row_vectors.Stride());
} }
template <typename T> template <typename T>
RowVectorBatch<T> AllocateAlignedRows(const Allocator2& allocator, RowVectorBatch<T> AllocateAlignedRows(const Allocator& allocator,
Extents2D extents) { Extents2D extents) {
return RowVectorBatch<T>( return RowVectorBatch<T>(
allocator, extents, allocator, extents,
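(The hunk above ends mid-call.) A usage sketch tying these helpers together, using only functions visible in this header:

// Sketch: allocate a padded 4 x 256 float batch and view it as rows.
const Allocator& allocator = ThreadingContext::Get().allocator;
RowVectorBatch<float> batch =
    AllocateAlignedRows<float>(allocator, Extents2D(4, 256));
RowPtr<float> rows = RowPtrFromBatch(allocator, batch);
float* row0 = rows.Row(0);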

View File

@ -109,7 +109,7 @@ static Pinning& GetPinning() {
return pinning; return pinning;
} }
static PoolPtr MakePool(const Allocator2& allocator, size_t num_workers, static PoolPtr MakePool(const Allocator& allocator, size_t num_workers,
std::optional<size_t> node = std::nullopt) { std::optional<size_t> node = std::nullopt) {
// `ThreadPool` expects the number of threads to create, which is one less // `ThreadPool` expects the number of threads to create, which is one less
// than the number of workers, but avoid underflow if zero. // than the number of workers, but avoid underflow if zero.
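The body is not shown in this hunk; presumably the clamp amounts to something like:

  const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;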
@ -136,7 +136,7 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) {
} }
NestedPools::NestedPools(const BoundedTopology& topology, NestedPools::NestedPools(const BoundedTopology& topology,
const Allocator2& allocator, size_t max_threads, const Allocator& allocator, size_t max_threads,
Tristate pin) { Tristate pin) {
GetPinning().SetPolicy(pin); GetPinning().SetPolicy(pin);
packages_.resize(topology.NumPackages()); packages_.resize(topology.NumPackages());
@ -175,7 +175,7 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
} }
NestedPools::Package::Package(const BoundedTopology& topology, NestedPools::Package::Package(const BoundedTopology& topology,
const Allocator2& allocator, size_t pkg_idx, const Allocator& allocator, size_t pkg_idx,
size_t max_workers_per_package) { size_t max_workers_per_package) {
// Pre-allocate because elements are set concurrently. // Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(pkg_idx)); clusters_.resize(topology.NumClusters(pkg_idx));

View File

@ -74,7 +74,7 @@ class NestedPools {
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments // would cause huge slowdowns when spinning, the `BoundedSlice` arguments
// only impose upper bounds on the number of detected packages and clusters // only impose upper bounds on the number of detected packages and clusters
// rather than defining the actual number of threads. // rather than defining the actual number of threads.
NestedPools(const BoundedTopology& topology, const Allocator2& allocator, NestedPools(const BoundedTopology& topology, const Allocator& allocator,
size_t max_threads = 0, Tristate pin = Tristate::kDefault); size_t max_threads = 0, Tristate pin = Tristate::kDefault);
bool AllPinned() const { return all_pinned_; } bool AllPinned() const { return all_pinned_; }
@ -148,7 +148,7 @@ class NestedPools {
class Package { class Package {
public: public:
Package() = default; // for vector Package() = default; // for vector
Package(const BoundedTopology& topology, const Allocator2& allocator, Package(const BoundedTopology& topology, const Allocator& allocator,
size_t pkg_idx, size_t max_workers_per_package); size_t pkg_idx, size_t max_workers_per_package);
size_t NumClusters() const { return clusters_.size(); } size_t NumClusters() const { return clusters_.size(); }

View File

@ -26,37 +26,37 @@ namespace gcpp {
static ThreadingArgs s_args; static ThreadingArgs s_args;
// Cannot use magic static because that does not support `Invalidate`, hence // Cannot use magic static because that does not support `Invalidate`, hence
// allocate manually. // allocate manually.
static std::unique_ptr<ThreadingContext2> s_ctx; static std::unique_ptr<ThreadingContext> s_ctx;
static std::mutex s_ctx_mutex; static std::mutex s_ctx_mutex;
/*static*/ void ThreadingContext2::SetArgs(const ThreadingArgs& args) { /*static*/ void ThreadingContext::SetArgs(const ThreadingArgs& args) {
s_ctx_mutex.lock(); s_ctx_mutex.lock();
HWY_ASSERT(!s_ctx); // Ensure not already initialized, else this is too late. HWY_ASSERT(!s_ctx); // Ensure not already initialized, else this is too late.
s_args = args; s_args = args;
s_ctx_mutex.unlock(); s_ctx_mutex.unlock();
} }
/*static*/ bool ThreadingContext2::IsInitialized() { /*static*/ bool ThreadingContext::IsInitialized() {
s_ctx_mutex.lock(); s_ctx_mutex.lock();
const bool initialized = !!s_ctx; const bool initialized = !!s_ctx;
s_ctx_mutex.unlock(); s_ctx_mutex.unlock();
return initialized; return initialized;
} }
/*static*/ ThreadingContext2& ThreadingContext2::Get() { /*static*/ ThreadingContext& ThreadingContext::Get() {
PROFILER_FUNC; PROFILER_FUNC;
// We do not bother with double-checked locking because it requires an // We do not bother with double-checked locking because it requires an
// atomic pointer, but we prefer to use unique_ptr for simplicity. Also, // atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
// callers can cache the result and call less often. // callers can cache the result and call less often.
s_ctx_mutex.lock(); s_ctx_mutex.lock();
if (HWY_UNLIKELY(!s_ctx)) { if (HWY_UNLIKELY(!s_ctx)) {
s_ctx = std::make_unique<ThreadingContext2>(PrivateToken()); s_ctx = std::make_unique<ThreadingContext>(PrivateToken());
} }
s_ctx_mutex.unlock(); s_ctx_mutex.unlock();
return *s_ctx; return *s_ctx;
} }
/*static*/ void ThreadingContext2::ThreadHostileInvalidate() { /*static*/ void ThreadingContext::ThreadHostileInvalidate() {
// Deliberately avoid taking the lock so that tsan can warn if this is // Deliberately avoid taking the lock so that tsan can warn if this is
// called concurrently with other calls to `Get`. // called concurrently with other calls to `Get`.
s_ctx.reset(); s_ctx.reset();
@ -64,7 +64,7 @@ static std::mutex s_ctx_mutex;
// WARNING: called with `s_ctx_mutex` held. Calling `SetArgs` or `Get` would // WARNING: called with `s_ctx_mutex` held. Calling `SetArgs` or `Get` would
// deadlock. // deadlock.
ThreadingContext2::ThreadingContext2(ThreadingContext2::PrivateToken) ThreadingContext::ThreadingContext(ThreadingContext::PrivateToken)
: topology(BoundedSlice(s_args.skip_packages, s_args.max_packages), : topology(BoundedSlice(s_args.skip_packages, s_args.max_packages),
BoundedSlice(s_args.skip_clusters, s_args.max_clusters), BoundedSlice(s_args.skip_clusters, s_args.max_clusters),
BoundedSlice(s_args.skip_lps, s_args.max_lps)), BoundedSlice(s_args.skip_lps, s_args.max_lps)),

View File

@ -87,7 +87,7 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
// Lazily-initialized singleton with support for passing in arguments from // Lazily-initialized singleton with support for passing in arguments from
// `ThreadingArgs` and re-initializing with different arguments. // `ThreadingArgs` and re-initializing with different arguments.
class ThreadingContext2 { class ThreadingContext {
struct PrivateToken {}; // avoids constructing directly struct PrivateToken {}; // avoids constructing directly
public: public:
@ -112,7 +112,7 @@ class ThreadingContext2 {
// hence we prefer not to pull `std::shared_ptr` into the interface. // hence we prefer not to pull `std::shared_ptr` into the interface.
// //
// To reduce overhead, callers should cache the result and call less often. // To reduce overhead, callers should cache the result and call less often.
static ThreadingContext2& Get(); static ThreadingContext& Get();
// Invalidates the singleton before or after a call to `Get`. This allows // Invalidates the singleton before or after a call to `Get`. This allows
// changing the arguments between tests. Callers must again call `Get` // changing the arguments between tests. Callers must again call `Get`
@ -121,10 +121,10 @@ class ThreadingContext2 {
// Also useful to suppress memory leak warnings in tests. // Also useful to suppress memory leak warnings in tests.
static void ThreadHostileInvalidate(); static void ThreadHostileInvalidate();
explicit ThreadingContext2(PrivateToken); // only called via `Get`. explicit ThreadingContext(PrivateToken); // only called via `Get`.
BoundedTopology topology; BoundedTopology topology;
Allocator2 allocator; Allocator allocator;
NestedPools pools; NestedPools pools;
}; };
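Typical lifecycle after the rename, mirroring the tests in this commit (`SetArgs` before the first `Get`, `ThreadHostileInvalidate` between configurations):

// Sketch: configure, use, then reset the singleton (e.g. between tests).
ThreadingArgs args;
args.max_packages = 1;
args.pin = Tristate::kFalse;
ThreadingContext::SetArgs(args);              // must precede the first Get()
ThreadingContext& ctx = ThreadingContext::Get();
MatMulEnv env(ctx);                           // ctx must outlive env
// ... run work via env.ctx.pools and env.ctx.allocator ...
ThreadingContext::ThreadHostileInvalidate();  // only when no other thread calls Get()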

View File

@ -383,7 +383,7 @@ TEST(ThreadingTest, BenchJoin) {
} }
}; };
NestedPools& pools = ThreadingContext2::Get().pools; NestedPools& pools = ThreadingContext::Get().pools;
// Use last package because the main thread has been pinned to it. // Use last package because the main thread has been pinned to it.
const size_t pkg_idx = pools.NumPackages() - 1; const size_t pkg_idx = pools.NumPackages() - 1;