Rename-only: remove Allocator2 etc suffixes now that refactoring is complete

PiperOrigin-RevId: 755397220
This commit is contained in:
Jan Wassenberg 2025-05-06 09:12:05 -07:00 committed by Copybara-Service
parent 8d0882b966
commit 275135d7e8
39 changed files with 215 additions and 216 deletions

View File

@ -66,14 +66,14 @@ hwy::ThreadPool& ThreadHostileGetPool() {
// can safely call `SetArgs` only once, because it would assert otherwise.
// This is preferable to calling `ThreadHostileInvalidate`, because we would
// repeat the topology initialization for every test.
if (!ThreadingContext2::IsInitialized()) {
if (!ThreadingContext::IsInitialized()) {
gcpp::ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 8;
threading_args.pin = Tristate::kFalse;
ThreadingContext2::SetArgs(threading_args);
ThreadingContext::SetArgs(threading_args);
}
return ThreadingContext2::Get().pools.Pool();
return ThreadingContext::Get().pools.Pool();
}
void TestMatMulVJP() {
@ -203,7 +203,7 @@ void TestEndToEnd() {
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
ThreadingContext2::Get().allocator, config.layer_configs[0].qkv_dim,
ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);

View File

@ -45,9 +45,9 @@ TEST(OptimizeTest, GradientDescent) {
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.pin = Tristate::kFalse;
ThreadingContext2::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get());
const Allocator2& allocator = env.ctx.allocator;
ThreadingContext::SetArgs(threading_args);
MatMulEnv env(ThreadingContext::Get());
const Allocator& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool();
std::mt19937 gen(42);

View File

@ -67,7 +67,7 @@ template <typename T>
class WeightsWrapper {
public:
explicit WeightsWrapper(const ModelConfig& config) : weights_(config) {
hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
weights_.AllocateForTest(owners_, pool);
}

View File

@ -35,7 +35,7 @@
namespace gcpp {
// Aborts if any keys differ, because then blobs are not comparable.
void CompareKeys(const BlobReader2& reader1, const BlobReader2& reader2) {
void CompareKeys(const BlobReader& reader1, const BlobReader& reader2) {
if (reader1.Keys().size() != reader2.Keys().size()) {
HWY_ABORT("#keys mismatch: %zu vs %zu\n", reader1.Keys().size(),
reader2.Keys().size());
@ -49,13 +49,13 @@ void CompareKeys(const BlobReader2& reader1, const BlobReader2& reader2) {
}
using KeyVec = std::vector<std::string>;
using RangeVec = std::vector<BlobRange2>;
using RangeVec = std::vector<BlobRange>;
RangeVec AllRanges(const KeyVec& keys, const BlobReader2& reader) {
RangeVec AllRanges(const KeyVec& keys, const BlobReader& reader) {
RangeVec ranges;
ranges.reserve(keys.size());
for (const std::string& key : keys) {
const BlobRange2* range = reader.Find(key);
const BlobRange* range = reader.Find(key);
if (!range) {
HWY_ABORT("Key %s not found, but was in KeyVec\n", key.c_str());
}
@ -82,7 +82,7 @@ void CompareRangeSizes(const KeyVec& keys, const RangeVec& ranges1,
// Total amount to allocate for all blobs.
size_t TotalBytes(const RangeVec& ranges) {
size_t total_bytes = 0;
for (const BlobRange2& range : ranges) {
for (const BlobRange& range : ranges) {
total_bytes += range.bytes;
}
return total_bytes;
@ -95,7 +95,7 @@ using BlobVec = std::vector<ByteSpan>; // in order of keys
// Assigns pointers within the single allocation and updates `pos`.
BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
BlobVec blobs;
for (const BlobRange2& range : ranges) {
for (const BlobRange& range : ranges) {
blobs.push_back(ByteSpan(all_blobs.get() + pos, range.bytes));
pos += range.bytes;
}
@ -104,7 +104,7 @@ BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
// Reads one set of blobs in parallel (helpful if in disk cache).
// Aborts on error.
void ReadBlobs(BlobReader2& reader, const RangeVec& ranges, BlobVec& blobs,
void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
hwy::ThreadPool& pool) {
HWY_ASSERT(reader.Keys().size() == blobs.size());
HWY_ASSERT(ranges.size() == blobs.size());
@ -116,7 +116,7 @@ void ReadBlobs(BlobReader2& reader, const RangeVec& ranges, BlobVec& blobs,
}
// Parallelizes ReadBlobs across (two) packages, if available.
void ReadBothBlobs(BlobReader2& reader1, BlobReader2& reader2,
void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
const RangeVec& ranges1, const RangeVec& ranges2,
size_t total_bytes, BlobVec& blobs1, BlobVec& blobs2,
NestedPools& pools) {
@ -215,8 +215,8 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
// Compares two sbs files, including blob order.
void ReadAndCompareBlobs(const char* path1, const char* path2) {
const Tristate map = Tristate::kFalse;
std::unique_ptr<BlobReader2> reader1 = BlobReader2::Make(Path(path1), map);
std::unique_ptr<BlobReader2> reader2 = BlobReader2::Make(Path(path2), map);
std::unique_ptr<BlobReader> reader1 = BlobReader::Make(Path(path1), map);
std::unique_ptr<BlobReader> reader2 = BlobReader::Make(Path(path2), map);
if (!reader1 || !reader2) {
HWY_ABORT(
"Failed to create readers for files %s %s, see error messages above.\n",
@ -235,7 +235,7 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
BlobVec blobs1 = ReserveMemory(ranges1, all_blobs, pos);
BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos);
NestedPools& pools = ThreadingContext2::Get().pools;
NestedPools& pools = ThreadingContext::Get().pools;
ReadBothBlobs(*reader1, *reader2, ranges1, ranges2, total_bytes, blobs1,
blobs2, pools);

View File

@ -94,7 +94,7 @@ static_assert(sizeof(Header) == 16);
// Additional data may be added only inside new blobs. Changes to the blob
// contents or type should be handled by renaming keys.
//
// This class is for internal use by `BlobReader2` and `BlobWriter2`. Its
// This class is for internal use by `BlobReader` and `BlobWriter`. Its
// interface is more low-level: fixed-size keys instead of strings.
class BlobStore {
static constexpr uint32_t kMagic = 0x0A534253; // SBS\n
@ -182,7 +182,7 @@ class BlobStore {
padded_dir_bytes - 2 * num_blobs * kU128Bytes);
// We already zero-initialized the directory padding;
// `BlobWriter2::WriteAll` takes care of padding after each blob via an
// `BlobWriter::WriteAll` takes care of padding after each blob via an
// additional I/O.
for (size_t i = 0; i < num_blobs; ++i) {
HWY_ASSERT(blobs[i].data() != nullptr);
@ -242,12 +242,12 @@ class BlobStore {
void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const {
const size_t key_idx = 0; // not actually associated with a key/blob
writes.emplace_back(
BlobRange2{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
// members are const and BlobIO2 requires non-const pointers, and they
// are not modified by file writes.
const_cast<Header*>(&header_));
writes.emplace_back(
BlobRange2{.offset = sizeof(header_),
BlobRange{.offset = sizeof(header_),
.bytes = PaddedDirEnd(NumBlobs()) - sizeof(header_),
.key_idx = key_idx},
const_cast<hwy::uint128_t*>(directory_.data()));
@ -289,8 +289,8 @@ class BlobStore {
std::vector<hwy::uint128_t> directory_; // two per blob, see `SetRange`.
}; // BlobStore
BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
const BlobStore& bs, BlobReader2::Mode mode)
BlobReader::BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
const BlobStore& bs, BlobReader::Mode mode)
: file_(std::move(file)), file_bytes_(file_bytes), mode_(mode) {
HWY_ASSERT(file_ && file_bytes_ != 0);
@ -306,12 +306,12 @@ BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
size_t bytes;
bs.GetRange(key_idx, offset, bytes);
ranges_.emplace_back(
BlobRange2{.offset = offset, .bytes = bytes, .key_idx = key_idx});
BlobRange{.offset = offset, .bytes = bytes, .key_idx = key_idx});
key_idx_for_key_[keys_[key_idx]] = key_idx;
}
if (mode_ == Mode::kMap) {
const Allocator2& allocator = ThreadingContext2::Get().allocator;
const Allocator& allocator = ThreadingContext::Get().allocator;
// Verify `kEndAlign` is an upper bound on the page size.
if (kEndAlign % allocator.BasePageBytes() != 0) {
HWY_ABORT("Please raise an issue about kEndAlign %zu %% page size %zu.",
@ -338,12 +338,12 @@ BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
}
}
void BlobReader2::Enqueue(const BlobRange2& range, void* data) {
void BlobReader::Enqueue(const BlobRange& range, void* data) {
// Debug-only because there may be many I/O requests (per row).
if constexpr (HWY_IS_DEBUG_BUILD) {
HWY_DASSERT(!IsMapped());
HWY_DASSERT(range.offset != 0 && range.bytes != 0 && data != nullptr);
const BlobRange2& blob_range = Range(range.key_idx);
const BlobRange& blob_range = Range(range.key_idx);
HWY_DASSERT(blob_range.End() <= file_bytes_);
if (range.End() > blob_range.End()) {
HWY_ABORT(
@ -362,15 +362,15 @@ void BlobReader2::Enqueue(const BlobRange2& range, void* data) {
// TODO: use preadv for per-tensor batches of sysconf(_SC_IOV_MAX) / IOV_MAX.
// - O_DIRECT seems undesirable because we do want to use the OS cache
// between consecutive runs.
void BlobReader2::ReadAll(hwy::ThreadPool& pool) const {
void BlobReader::ReadAll(hwy::ThreadPool& pool) const {
PROFILER_ZONE("Startup.ReadAll");
HWY_ASSERT(!IsMapped());
// >5x speedup from parallel reads when cached.
pool.Run(0, requests_.size(), [this](uint64_t i, size_t /*thread*/) {
const BlobRange2& range = requests_[i].range;
const BlobRange& range = requests_[i].range;
const uint64_t end = range.End();
const std::string& key = keys_[range.key_idx];
const BlobRange2& blob_range = Range(range.key_idx);
const BlobRange& blob_range = Range(range.key_idx);
HWY_ASSERT(blob_range.End() <= file_bytes_);
if (end > blob_range.End()) {
HWY_ABORT(
@ -387,11 +387,11 @@ void BlobReader2::ReadAll(hwy::ThreadPool& pool) const {
}
// Decides whether to read or map the file.
static BlobReader2::Mode ChooseMode(uint64_t file_mib, Tristate map) {
const Allocator2& allocator = ThreadingContext2::Get().allocator;
static BlobReader::Mode ChooseMode(uint64_t file_mib, Tristate map) {
const Allocator& allocator = ThreadingContext::Get().allocator;
// User has explicitly requested a map or read via args.
if (map == Tristate::kTrue) return BlobReader2::Mode::kMap;
if (map == Tristate::kFalse) return BlobReader2::Mode::kRead;
if (map == Tristate::kTrue) return BlobReader::Mode::kMap;
if (map == Tristate::kFalse) return BlobReader::Mode::kRead;
// Else: use heuristics to choose. Note that `FreeMiB` is generally low
// because idle memory is used as cache, so do not use it to decide.
const size_t total_mib = allocator.TotalMiB();
@ -400,13 +400,13 @@ static BlobReader2::Mode ChooseMode(uint64_t file_mib, Tristate map) {
static_cast<size_t>(file_mib), total_mib);
}
// Large fraction of total.
if (file_mib >= total_mib / 3) return BlobReader2::Mode::kMap;
if (file_mib >= total_mib / 3) return BlobReader::Mode::kMap;
// Big enough that even parallel loading wouldn't be quick.
if (file_mib > 50 * 1024) return BlobReader2::Mode::kMap;
return BlobReader2::Mode::kRead;
if (file_mib > 50 * 1024) return BlobReader::Mode::kMap;
return BlobReader::Mode::kRead;
}
std::unique_ptr<BlobReader2> BlobReader2::Make(const Path& blob_path,
std::unique_ptr<BlobReader> BlobReader::Make(const Path& blob_path,
const Tristate map) {
if (blob_path.Empty()) HWY_ABORT("No --weights specified.");
std::unique_ptr<File> file = OpenFileOrNull(blob_path, "r");
@ -417,10 +417,10 @@ std::unique_ptr<BlobReader2> BlobReader2::Make(const Path& blob_path,
// Even if `kMap`, read the directory via the `kRead` mode for simplicity.
BlobStore bs(*file);
if (!bs.IsValid(file_bytes)) {
return std::unique_ptr<BlobReader2>(); // IsValid already printed a warning
return std::unique_ptr<BlobReader>(); // IsValid already printed a warning
}
return std::unique_ptr<BlobReader2>(new BlobReader2(
return std::unique_ptr<BlobReader>(new BlobReader(
std::move(file), file_bytes, bs, ChooseMode(file_bytes >> 20, map)));
}
@ -434,14 +434,13 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
for (; offset <= end - kChunkBytes;
offset += kChunkBytes, data += kChunkBytes) {
writes.emplace_back(
BlobRange2{
.offset = offset, .bytes = kChunkBytes, .key_idx = key_idx},
BlobRange{.offset = offset, .bytes = kChunkBytes, .key_idx = key_idx},
data);
}
}
if (offset != end) {
writes.emplace_back(
BlobRange2{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
BlobRange{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
data);
}
}
@ -472,7 +471,7 @@ static void EnqueueWritesForBlobs(const BlobStore& bs,
if (padding != 0) {
HWY_ASSERT(padding <= kBlobAlign);
writes.emplace_back(
BlobRange2{
BlobRange{
.offset = offset + bytes, .bytes = padding, .key_idx = key_idx},
const_cast<uint8_t*>(kZeros));
}
@ -484,19 +483,19 @@ static void EnqueueWritesForBlobs(const BlobStore& bs,
// remain alive until the last I/O is done.
zeros.resize(padding);
writes.emplace_back(
BlobRange2{.offset = file_end, .bytes = padding, .key_idx = 0},
BlobRange{.offset = file_end, .bytes = padding, .key_idx = 0},
zeros.data());
}
}
void BlobWriter2::Add(const std::string& key, const void* data, size_t bytes) {
void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
HWY_ASSERT(data != nullptr);
HWY_ASSERT(bytes != 0);
keys_.push_back(KeyFromString(key.c_str()));
blobs_.emplace_back(static_cast<const uint8_t*>(data), bytes);
}
void BlobWriter2::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
const size_t num_blobs = keys_.size();
HWY_ASSERT(num_blobs != 0);
HWY_ASSERT(num_blobs == blobs_.size());
@ -516,7 +515,7 @@ void BlobWriter2::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
pool.Run(0, writes.size(),
[this, &file, &writes](uint64_t i, size_t /*thread*/) {
const BlobRange2& range = writes[i].range;
const BlobRange& range = writes[i].range;
if (!file->Write(writes[i].data, range.bytes, range.offset)) {
const std::string& key = StringFromKey(keys_[range.key_idx]);

View File

@ -35,20 +35,20 @@
namespace gcpp {
// One blob's extents within the file.
struct BlobRange2 {
struct BlobRange {
uint64_t End() const { return offset + bytes; }
uint64_t offset = 0;
size_t bytes = 0; // We check blobs are not zero-sized.
// Index within `BlobReader2::Keys()` for error reporting.
// Index within `BlobReader::Keys()` for error reporting.
size_t key_idx;
};
// A read or write I/O request, each serviced by one thread in a pool.
struct BlobIO2 {
BlobIO2(BlobRange2 range, void* data) : range(range), data(data) {}
BlobIO2(BlobRange range, void* data) : range(range), data(data) {}
BlobRange2 range;
BlobRange range;
void* data; // Modified only if a read request. Read-only for writes.
};
@ -59,7 +59,7 @@ class BlobStore;
// Thread-safe: it is safe to concurrently call all methods except `Enqueue`,
// because they are const.
// TODO(janwas): split into header and reader/mapper classes.
class BlobReader2 {
class BlobReader {
public:
// Parallel I/O into allocated memory, or mapped view of file. The latter is
// better when the file is huge, but page faults add noise to measurements.
@ -67,26 +67,26 @@ class BlobReader2 {
// Acquires ownership of `file` (which must be non-null) and reads its header.
// Factory function instead of ctor because this can fail (return null).
static std::unique_ptr<BlobReader2> Make(const Path& blob_path,
static std::unique_ptr<BlobReader> Make(const Path& blob_path,
Tristate map = Tristate::kDefault);
~BlobReader2() = default;
~BlobReader() = default;
// Returns true if the mode passed to ctor was `kMap` and mapping succeeded.
bool IsMapped() const { return mode_ == Mode::kMap; }
const std::vector<std::string>& Keys() const { return keys_; }
const BlobRange2& Range(size_t key_idx) const {
const BlobRange& Range(size_t key_idx) const {
HWY_ASSERT(key_idx < keys_.size());
return ranges_[key_idx];
}
// Returns nullptr if not found. O(1).
const BlobRange2* Find(const std::string& key) const {
const BlobRange* Find(const std::string& key) const {
auto it = key_idx_for_key_.find(key);
if (it == key_idx_for_key_.end()) return nullptr;
const BlobRange2& range = Range(it->second);
const BlobRange& range = Range(it->second);
HWY_ASSERT(range.offset != 0 && range.bytes != 0);
HWY_ASSERT(range.End() <= file_bytes_);
return &range;
@ -95,7 +95,7 @@ class BlobReader2 {
// Only if `IsMapped()`: returns blob as a read-only span of `T`. Note that
// everything else except `CallWithSpan` is in units of bytes.
template <typename T>
hwy::Span<const T> MappedSpan(const BlobRange2& range) const {
hwy::Span<const T> MappedSpan(const BlobRange& range) const {
HWY_ASSERT(IsMapped());
HWY_ASSERT(range.bytes % sizeof(T) == 0);
return hwy::Span<const T>(
@ -108,7 +108,7 @@ class BlobReader2 {
// which an aligned allocation is unnecessary.
template <typename T, class Func>
bool CallWithSpan(const std::string& key, const Func& func) const {
const BlobRange2* range = Find(key);
const BlobRange* range = Find(key);
if (!range) {
HWY_WARN("Blob %s not found, sizeof T=%zu", key.c_str(), sizeof(T));
return false;
@ -134,7 +134,7 @@ class BlobReader2 {
// The following methods must only be called if `!IsMapped()`.
// Enqueues a BlobIO2 for `ReadAll` to execute.
void Enqueue(const BlobRange2& range, void* data);
void Enqueue(const BlobRange& range, void* data);
// Reads in parallel all enqueued requests to the specified destinations.
// Aborts on error.
@ -142,7 +142,7 @@ class BlobReader2 {
private:
// Only for use by `Make`.
BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
const BlobStore& bs, Mode mode);
const std::unique_ptr<File> file_;
@ -150,7 +150,7 @@ class BlobReader2 {
Mode mode_;
std::vector<std::string> keys_;
std::vector<BlobRange2> ranges_;
std::vector<BlobRange> ranges_;
std::unordered_map<std::string, size_t> key_idx_for_key_;
MapPtr mapped_; // only if `kMap`
@ -160,7 +160,7 @@ class BlobReader2 {
// Collects references to blobs and writes them all at once with parallel I/O.
// Thread-compatible: independent instances can be used concurrently, but it
// does not make sense to call the methods concurrently.
class BlobWriter2 {
class BlobWriter {
public:
void Add(const std::string& key, const void* data, size_t bytes);

View File

@ -37,7 +37,7 @@ class BlobStoreTest : public testing::Test {};
#endif
void TestWithMapped(Tristate map) {
hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};
@ -51,7 +51,7 @@ void TestWithMapped(Tristate map) {
const std::string keyA("0123456789abcdef"); // max 16 characters
const std::string keyB("q");
BlobWriter2 writer;
BlobWriter writer;
writer.Add(keyA, "DATA", 5);
writer.Add(keyB, buffer.data(), sizeof(buffer));
writer.WriteAll(pool, path);
@ -59,14 +59,14 @@ void TestWithMapped(Tristate map) {
std::fill(buffer.begin(), buffer.end(), 0);
std::unique_ptr<BlobReader2> reader = BlobReader2::Make(path, map);
std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
HWY_ASSERT(reader);
HWY_ASSERT_EQ(reader->Keys().size(), 2);
HWY_ASSERT_STRING_EQ(reader->Keys()[0].c_str(), keyA.c_str());
HWY_ASSERT_STRING_EQ(reader->Keys()[1].c_str(), keyB.c_str());
const BlobRange2* range = reader->Find(keyA);
const BlobRange* range = reader->Find(keyA);
HWY_ASSERT(range);
const uint64_t offsetA = range->offset;
HWY_ASSERT_EQ(offsetA, 256); // kBlobAlign
@ -80,9 +80,9 @@ void TestWithMapped(Tristate map) {
if (!reader->IsMapped()) {
char str[5];
reader->Enqueue(
BlobRange2{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
BlobRange{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
reader->Enqueue(
BlobRange2{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
BlobRange{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
buffer.data());
reader->ReadAll(pool);
HWY_ASSERT_STRING_EQ("DATA", str);
@ -111,7 +111,7 @@ TEST(BlobStoreTest, TestReadWrite) {
// Ensures padding works for any number of random-sized blobs.
TEST(BlobStoreTest, TestNumBlobs) {
hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
hwy::RandomState rng;
for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {
@ -121,7 +121,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
HWY_ASSERT(fd > 0);
const Path path(path_str);
BlobWriter2 writer;
BlobWriter writer;
std::vector<std::string> keys;
keys.reserve(num_blobs);
std::vector<std::vector<uint8_t>> blobs;
@ -144,13 +144,13 @@ TEST(BlobStoreTest, TestNumBlobs) {
writer.WriteAll(pool, path);
const Tristate map = Tristate::kFalse;
std::unique_ptr<BlobReader2> reader = BlobReader2::Make(path, map);
std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
HWY_ASSERT(reader);
HWY_ASSERT_EQ(reader->Keys().size(), num_blobs);
pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
HWY_ASSERT_STRING_EQ(reader->Keys()[i].c_str(),
std::to_string(i).c_str());
const BlobRange2* range = reader->Find(keys[i]);
const BlobRange* range = reader->Find(keys[i]);
HWY_ASSERT(range);
HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
HWY_ASSERT(reader->CallWithSpan<uint8_t>(

View File

@ -22,7 +22,7 @@
#include <string>
#include <vector>
#include "compression/blob_store.h" // BlobWriter2
#include "compression/blob_store.h" // BlobWriter
#include "compression/compress.h" // ScaleWeights
#include "compression/io.h" // Path
#include "gemma/configs.h" // ModelConfig
@ -88,7 +88,7 @@ class SbsWriterImpl : public ISbsWriter {
}
public:
SbsWriterImpl() : pool_(ThreadingContext2::Get().pools.Pool()) {}
SbsWriterImpl() : pool_(ThreadingContext::Get().pools.Pool()) {}
void Insert(const char* name, F32Span weights, Type type,
const TensorInfo& tensor_info) override {
@ -123,7 +123,7 @@ class SbsWriterImpl : public ISbsWriter {
hwy::ThreadPool& pool_;
MatOwners mat_owners_;
CompressWorkingSet working_set_;
BlobWriter2 writer_;
BlobWriter writer_;
std::vector<uint32_t> serialized_mat_ptrs_;
};
@ -141,7 +141,7 @@ HWY_EXPORT(NewSbsWriter);
SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}
SbsReader::SbsReader(const std::string& path)
: reader_(gcpp::BlobReader2::Make(Path(path))), model_(*reader_) {}
: reader_(gcpp::BlobReader::Make(Path(path))), model_(*reader_) {}
} // namespace gcpp
#endif // HWY_ONCE

View File

@ -77,8 +77,8 @@ class SbsReader {
const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); }
private:
std::unique_ptr<gcpp::BlobReader2> reader_;
gcpp::ModelStore2 model_;
std::unique_ptr<gcpp::BlobReader> reader_;
gcpp::ModelStore model_;
};
} // namespace gcpp

View File

@ -240,7 +240,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown";
(void)hwy::platform::GetCpuString(cpu100);
const ThreadingContext2& ctx = ThreadingContext2::Get();
const ThreadingContext& ctx = ThreadingContext::Get();
fprintf(stderr,
"Date & Time : %s" // dt includes \n

View File

@ -50,7 +50,7 @@ class GemmaEnv {
GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
const InferenceArgs& inference);
// Avoid memory leaks in test.
~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); }
~GemmaEnv() { ThreadingContext::ThreadHostileInvalidate(); }
MatMulEnv& Env() { return env_; }

View File

@ -72,7 +72,7 @@ struct Activations {
size_t cache_pos_size = 0;
void Allocate(size_t batch_size, MatMulEnv* env) {
const Allocator2& allocator = env->ctx.allocator;
const Allocator& allocator = env->ctx.allocator;
post_qk = layer_config.post_qk;
const size_t model_dim = weights_config.model_dim;

View File

@ -561,7 +561,7 @@ class GemmaAttention {
const LayerWeightsPtrs<T>& layer_weights_;
const hwy::Divisor& div_seq_len_;
const KVCaches& kv_caches_;
const Allocator2& allocator_;
const Allocator& allocator_;
hwy::ThreadPool& pool_;
};
@ -749,7 +749,7 @@ class VitAttention {
Activations& activations_;
const LayerWeightsPtrs<T>& layer_weights_;
const LayerConfig& layer_config_;
const Allocator2& allocator_;
const Allocator& allocator_;
hwy::ThreadPool& pool_;
};
@ -789,7 +789,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
const auto x =
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
const Allocator2& allocator = activations.env->ctx.allocator;
const Allocator& allocator = activations.env->ctx.allocator;
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
auto multiplier = RowPtrFromBatch(allocator, activations.C2);
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
@ -847,7 +847,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
const auto x =
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
const Allocator2& allocator = activations.env->ctx.allocator;
const Allocator& allocator = activations.env->ctx.allocator;
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
@ -1416,7 +1416,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
//
// `kv_caches` is for the batch, size must match `queries_prompt`.
template <typename T>
void GenerateT(const ModelStore2& model, const ModelWeightsPtrs<T>& weights,
void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
Activations& activations, const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
const QueriesPos& queries_pos_in,
@ -1508,7 +1508,7 @@ void GenerateT(const ModelStore2& model, const ModelWeightsPtrs<T>& weights,
}
template <typename T>
void GenerateSingleT(const ModelStore2& model,
void GenerateSingleT(const ModelStore& model,
const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
@ -1532,7 +1532,7 @@ void GenerateSingleT(const ModelStore2& model,
}
template <typename T>
void GenerateBatchT(const ModelStore2& model,
void GenerateBatchT(const ModelStore& model,
const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt,
@ -1573,7 +1573,7 @@ void GenerateBatchT(const ModelStore2& model,
}
template <typename T>
void GenerateImageTokensT(const ModelStore2& model,
void GenerateImageTokensT(const ModelStore& model,
const ModelWeightsPtrs<T>& weights,
const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens,
@ -1599,7 +1599,7 @@ void GenerateImageTokensT(const ModelStore2& model,
// These are extern functions defined by instantiations/*.cc, which include this
// 'header' after defining `GEMMA_TYPE`.
void GenerateSingle( // NOLINT(misc-definitions-in-headers)
const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
TimingInfo& timing_info) {
@ -1609,7 +1609,7 @@ void GenerateSingle( // NOLINT(misc-definitions-in-headers)
}
void GenerateBatch( // NOLINT(misc-definitions-in-headers)
const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const RuntimeConfig& runtime_config,
const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,
@ -1620,7 +1620,7 @@ void GenerateBatch( // NOLINT(misc-definitions-in-headers)
}
void GenerateImageTokens( // NOLINT(misc-definitions-in-headers)
const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
const RuntimeConfig& runtime_config, const Image& image,
ImageTokens& image_tokens, MatMulEnv* env) {
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>)

View File

@ -47,13 +47,13 @@ namespace gcpp {
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
// Placeholder for internal init, do not modify.
ThreadingContext2::SetArgs(threading_args);
return MatMulEnv(ThreadingContext2::Get());
ThreadingContext::SetArgs(threading_args);
return MatMulEnv(ThreadingContext::Get());
}
Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
: env_(env),
reader_(BlobReader2::Make(loader.weights, loader.map)),
reader_(BlobReader::Make(loader.weights, loader.map)),
model_(*reader_, loader.tokenizer, loader.wrapping),
weights_(model_.Config().weight),
chat_template_(model_.Tokenizer(), model_.Config().model) {
@ -74,7 +74,7 @@ Gemma::Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer,
Gemma::~Gemma() = default;
void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
BlobWriter2 writer;
BlobWriter writer;
const std::vector<uint32_t> serialized_mat_ptrs =
weights_.AddTensorDataToWriter(writer);
WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,
@ -90,17 +90,17 @@ void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
// instead of `WeightsPtrs<T>`.
#define GEMMA_DECLARE(WEIGHT_TYPE) \
extern void GenerateSingle( \
const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const RuntimeConfig& runtime_config, const PromptTokens& prompt, \
size_t pos, size_t prefix_end, KVCache& kv_cache, MatMulEnv* env, \
TimingInfo& timing_info); \
extern void GenerateBatch( \
const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end, \
const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info); \
extern void GenerateImageTokens( \
const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
const RuntimeConfig& runtime_config, const Image& image, \
ImageTokens& image_tokens, MatMulEnv* env);
GEMMA_DECLARE(float)

View File

@ -160,8 +160,8 @@ class Gemma {
private:
MatMulEnv& env_;
std::unique_ptr<BlobReader2> reader_; // null for second ctor
ModelStore2 model_;
std::unique_ptr<BlobReader> reader_; // null for second ctor
ModelStore model_;
WeightsOwner weights_;
GemmaChatTemplate chat_template_;
};

View File

@ -56,7 +56,7 @@ static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {
// Returns the serialized tokenizer (std::string is required for proto).
// Reads it from a blob or from a separate file if pre-2025.
static std::string ReadTokenizer(BlobReader2& reader,
static std::string ReadTokenizer(BlobReader& reader,
const Path& tokenizer_path) {
std::string tokenizer;
// Check prevents `CallWithSpan` from printing a warning.
@ -107,7 +107,7 @@ class TypePrefix {
}
}
TypePrefix(const KeyVec& keys, const BlobReader2& reader) {
TypePrefix(const KeyVec& keys, const BlobReader& reader) {
for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
const std::string& key = keys[key_idx];
const Type type = TypeFromChar(key[0]);
@ -200,7 +200,7 @@ static int DeduceLayerTypes(const KeyVec& keys) {
// `wrapping_override` is forwarded from the command line. For pre-2025 files
// without `ModelConfig`, it is the only way to force PT.
static ModelConfig ReadOrDeduceConfig(BlobReader2& reader,
static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
Tristate wrapping_override) {
const TypePrefix type_prefix(reader.Keys(), reader);
Type deduced_weight = Type::kUnknown;
@ -244,7 +244,7 @@ static ModelConfig ReadOrDeduceConfig(BlobReader2& reader,
ChooseWrapping(config.model, wrapping_override));
}
static std::vector<float> ReadScales(BlobReader2& reader,
static std::vector<float> ReadScales(BlobReader& reader,
const ModelConfig& config) {
std::vector<float> scales;
// Check first to prevent `CallWithSpan` from printing a warning. This blob is
@ -260,7 +260,7 @@ static std::vector<float> ReadScales(BlobReader2& reader,
}
// Single-file format: reads `MatPtr` from the blob; returns false if not found.
bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
bool ModelStore::ReadMatPtrs(BlobReader& reader) {
// Check first to prevent `CallWithSpan` from printing a warning.
if (!reader.Find(kMatPtrsName)) return false;
@ -282,7 +282,7 @@ bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
// Retrieve actual key index because a writer may have written other
// blobs before the tensor data.
const BlobRange2* range = reader.Find(mat.Name());
const BlobRange* range = reader.Find(mat.Name());
HWY_ASSERT(range);
const size_t key_idx = range->key_idx;
AddMatPtr(key_idx, mat);
@ -302,7 +302,7 @@ bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
}
// Pre-2025 format: synthesizes `MatPtr` from the blob names if `!ReadMatPtrs`.
void ModelStore2::CreateMatPtrs(BlobReader2& reader) {
void ModelStore::CreateMatPtrs(BlobReader& reader) {
const TensorInfoRegistry tensors(config_);
const KeyVec& keys = reader.Keys();
@ -329,7 +329,7 @@ void ModelStore2::CreateMatPtrs(BlobReader2& reader) {
HWY_ASSERT(mat_ptrs_.size() == key_idx_.size());
}
ModelStore2::ModelStore2(BlobReader2& reader, const Path& tokenizer_path,
ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
Tristate wrapping)
: config_(ReadOrDeduceConfig(reader, wrapping)),
tokenizer_(ReadTokenizer(reader, tokenizer_path)) {
@ -348,12 +348,12 @@ ModelStore2::ModelStore2(BlobReader2& reader, const Path& tokenizer_path,
HWY_ASSERT(key_idx_.size() == mat_ptrs_.size());
}
ModelStore2::~ModelStore2() {
ModelStore::~ModelStore() {
// Sanity check: ensure all scales were consumed.
HWY_ASSERT(scales_consumed_ == scales_.size());
}
const MatPtr* ModelStore2::FindMat(const char* name) const {
const MatPtr* ModelStore::FindMat(const char* name) const {
auto it = mat_idx_for_name_.find(name);
if (it == mat_idx_for_name_.end()) return nullptr;
const size_t mat_idx = it->second;
@ -362,7 +362,7 @@ const MatPtr* ModelStore2::FindMat(const char* name) const {
return file_mat;
}
bool ModelStore2::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
const MatPtr* file_mat = FindMat(mat.Name());
if (!file_mat) return false;
if (file_mat->Rows() != mat.Rows() || file_mat->Cols() != mat.Cols()) {
@ -390,14 +390,14 @@ bool ModelStore2::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
}
static void AddBlob(const char* name, const std::vector<uint32_t>& data,
BlobWriter2& writer) {
BlobWriter& writer) {
HWY_ASSERT(!data.empty());
writer.Add(name, data.data(), data.size() * sizeof(data[0]));
}
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter2& writer, hwy::ThreadPool& pool,
BlobWriter& writer, hwy::ThreadPool& pool,
const Path& path) {
HWY_ASSERT(config.model != Model::UNKNOWN);
HWY_ASSERT(config.weight != Type::kUnknown);

View File

@ -48,16 +48,16 @@ namespace gcpp {
// tokenizer in a separate file, encoded tensor type in a prefix of the blob
// name, and had a blob for tensor scaling factors. We still support reading
// both, but only write single-file format.
class ModelStore2 {
class ModelStore {
public:
// Reads from file(s) or aborts on error. The latter two arguments are only
// used for pre-2025 files.
ModelStore2(BlobReader2& reader, const Path& tokenizer_path = Path(),
ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(),
Tristate wrapping = Tristate::kDefault);
// For optimize_test.cc.
ModelStore2(const ModelConfig& config, GemmaTokenizer&& tokenizer)
ModelStore(const ModelConfig& config, GemmaTokenizer&& tokenizer)
: config_(config), tokenizer_(std::move(tokenizer)) {}
~ModelStore2();
~ModelStore();
const ModelConfig& Config() const {
HWY_ASSERT(config_.model != Model::UNKNOWN);
@ -72,7 +72,7 @@ class ModelStore2 {
// Returns false if `mat` is not available for loading, otherwise updates
// `mat` with metadata from the file and sets `key_idx` for use by
// `BlobReader2`. Called via `ReadOrAllocate` in `weights.cc`.
// `BlobReader`. Called via `ReadOrAllocate` in `weights.cc`.
bool FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const;
private:
@ -83,15 +83,15 @@ class ModelStore2 {
key_idx_.push_back(key_idx);
}
bool ReadMatPtrs(BlobReader2& reader);
void CreateMatPtrs(BlobReader2& reader); // Aborts on error.
bool ReadMatPtrs(BlobReader& reader);
void CreateMatPtrs(BlobReader& reader); // Aborts on error.
ModelConfig config_;
GemmaTokenizer tokenizer_;
// All `MatPtr` present in the `BlobStore`, see `ReadMatPtrs`/`CreateMatPtrs`.
std::vector<MatPtr> mat_ptrs_;
// For each of `mat_ptrs_`, the index within `BlobReader2::Keys()`. This is
// For each of `mat_ptrs_`, the index within `BlobReader::Keys()`. This is
// not necessarily iota because some blobs are not tensors, and callers may
// have added blobs before ours.
std::vector<size_t> key_idx_;
@ -108,7 +108,7 @@ class ModelStore2 {
// produces a single BlobStore file holding everything required for inference.
void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
const std::vector<uint32_t>& serialized_mat_ptrs,
BlobWriter2& writer, hwy::ThreadPool& pool,
BlobWriter& writer, hwy::ThreadPool& pool,
const Path& path);
} // namespace gcpp

View File

@ -84,8 +84,8 @@ void LayerWeightsPtrs<NuqStream>::Reshape() {
}
// Aborts on error.
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
const std::vector<BlobRange2>& ranges,
static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
const std::vector<BlobRange>& ranges,
MatOwners& mat_owners, const MatPadding padding,
hwy::ThreadPool& pool) {
HWY_ASSERT(mats.size() == ranges.size());
@ -121,7 +121,7 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
uint8_t* row = mats[i]->RowT<uint8_t>(0);
for (size_t r = 0; r < mats[i]->Rows(); ++r) {
reader.Enqueue(BlobRange2{.offset = offset,
reader.Enqueue(BlobRange{.offset = offset,
.bytes = file_bytes_per_row,
.key_idx = ranges[i].key_idx},
row);
@ -134,11 +134,11 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
reader.ReadAll(pool);
}
void WeightsOwner::ReadOrAllocate(const ModelStore2& model, BlobReader2& reader,
void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool) {
// List of tensors to read/map, and where from.
std::vector<MatPtr*> mats;
std::vector<BlobRange2> ranges;
std::vector<BlobRange> ranges;
// Padding is inserted when reading row by row, except for NUQ tensors.
const MatPadding padding = MatPadding::kOdd;
@ -244,7 +244,7 @@ void WeightsOwner::Reshape(hwy::ThreadPool& pool) {
}
std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
BlobWriter2& writer) const {
BlobWriter& writer) const {
std::vector<uint32_t> serialized_mat_ptrs;
CallT([&](const auto& weights) {
weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {

View File

@ -25,7 +25,7 @@
#include <string>
#include <vector>
#include "compression/blob_store.h" // BlobWriter2
#include "compression/blob_store.h" // BlobWriter
#include "compression/shared.h" // IsF32
#include "gemma/configs.h" // ModelConfig
#include "gemma/model_store.h" // ModelStore
@ -519,7 +519,7 @@ class WeightsOwner {
// Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`,
// allocates memory and reshapes. Aborts on error.
void ReadOrAllocate(const ModelStore2& model, BlobReader2& reader,
void ReadOrAllocate(const ModelStore& model, BlobReader& reader,
hwy::ThreadPool& pool);
// Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically
@ -541,7 +541,7 @@ class WeightsOwner {
// For writers:
// Adds one blob for each tensor's data and returns all serialized MatPtr.
std::vector<uint32_t> AddTensorDataToWriter(BlobWriter2& writer) const;
std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;
// For backprop/:

View File

@ -79,7 +79,7 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// M = A rows, K = A cols, N = C cols.
template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
const Allocator2& allocator = env.ctx.allocator;
const Allocator& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n");
@ -160,7 +160,7 @@ void BenchAllMatMul() {
return;
}
ThreadingContext2& ctx = ThreadingContext2::Get();
ThreadingContext& ctx = ThreadingContext::Get();
fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(),
ctx.pools.PinString());

View File

@ -999,7 +999,7 @@ struct TestShortDotsT {
const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df; // for CallDot
const Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
const Allocator& allocator = gcpp::ThreadingContext::Get().allocator;
CompressWorkingSet work;
std::mt19937 rng;
rng.seed(12345);
@ -1099,14 +1099,14 @@ void TestAllDot() {
constexpr size_t kMaxWorkers = 15;
// Reset with cap on workers because we only support `kMaxWorkers`.
ThreadingContext2::ThreadHostileInvalidate();
ThreadingContext::ThreadHostileInvalidate();
ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.max_lps = kMaxWorkers - 1;
ThreadingContext2::SetArgs(threading_args);
ThreadingContext2& ctx = ThreadingContext2::Get();
const Allocator2& allocator = ctx.allocator;
ThreadingContext::SetArgs(threading_args);
ThreadingContext& ctx = ThreadingContext::Get();
const Allocator& allocator = ctx.allocator;
{ // ensure no profiler zones are active
const hn::ScalableTag<float> df;

View File

@ -909,7 +909,7 @@ class MMPerPackage {
static constexpr size_t B_stride_max_ =
MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
static constexpr size_t B_storage_max_ =
kNR * B_stride_max_ + Allocator2::MaxQuantum<BF16>();
kNR * B_stride_max_ + Allocator::MaxQuantum<BF16>();
// Granularity of `ForNP`. B rows produce C columns, so we
// want a multiple of the line size to prevent false sharing.
@ -1175,7 +1175,7 @@ class MMPerPackage {
// Autotuning wrapper for `DoDecompressA`.
template <typename TA>
HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const {
const Allocator2& allocator = args_.env->ctx.allocator;
const Allocator& allocator = args_.env->ctx.allocator;
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
// If already BF16, maybe return a view:
if constexpr (hwy::IsSame<TA, BF16>()) {
@ -1316,7 +1316,7 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<TC>& C) {
const Allocator2& allocator = env.ctx.allocator;
const Allocator& allocator = env.ctx.allocator;
const size_t M = A.Extents().rows;
const size_t K = A.Extents().cols;
const size_t N = B.Extents().rows;

View File

@ -60,7 +60,7 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
// and holds most of their arguments in member variables.
class GenerateCandidates {
public:
GenerateCandidates(const Allocator2& allocator, size_t M, size_t K, size_t N,
GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config)
: allocator_(allocator),
@ -352,7 +352,7 @@ class GenerateCandidates {
}
}
const Allocator2& allocator_;
const Allocator& allocator_;
const size_t M_;
const size_t K_;
const size_t N_;
@ -372,7 +372,7 @@ class GenerateCandidates {
} // namespace
// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
@ -384,7 +384,7 @@ std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package
// rows for that.
static size_t NPMultiple(const Allocator2& allocator, size_t N,
static size_t NPMultiple(const Allocator& allocator, size_t N,
size_t sizeof_TC, size_t nr, size_t num_packages) {
size_t np_multiple = allocator.QuantumBytes() / sizeof_TC;
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
@ -417,7 +417,7 @@ IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
}
MatMulEnv::MatMulEnv(ThreadingContext2& ctx)
MatMulEnv::MatMulEnv(ThreadingContext& ctx)
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);

View File

@ -50,7 +50,7 @@ class MMParallel {
static constexpr size_t kMaxPackages = 4;
// `ctx` must outlive this object.
MMParallel(ThreadingContext2& ctx) : ctx_(ctx) {
MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
}
@ -164,11 +164,11 @@ class MMParallel {
}
private:
ThreadingContext2& ctx_;
ThreadingContext& ctx_;
};
template <typename TC> // BF16/float for C, double for partial
void BindC(const Allocator2& allocator, size_t M, const RowPtr<TC>& C,
void BindC(const Allocator& allocator, size_t M, const RowPtr<TC>& C,
MMParallel& parallel) {
if (!allocator.ShouldBind()) return;
@ -207,7 +207,7 @@ class MMStorage {
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
static constexpr size_t kMaxKC = 8 * 1024;
MMStorage(const Allocator2& allocator, MMParallel& parallel)
MMStorage(const Allocator& allocator, MMParallel& parallel)
// Per-worker copies of `partial` would be wasteful. We instead allocate
// one instance of the maximum matrix extents because threads write at
// false-sharing-free granularity.
@ -236,7 +236,7 @@ class MMStorage {
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is
// non-const, because `RowPtr` requires a non-const pointer.
RowPtrBF A(const Allocator2& allocator, size_t pkg_idx,
RowPtrBF A(const Allocator& allocator, size_t pkg_idx,
const Extents2D& extents) {
HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK);
@ -430,7 +430,7 @@ class MMConfig {
static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop)
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
@ -561,7 +561,7 @@ class MMKeys {
}
// Must only be called if not already present in `Keys()`.
void Append(Key key, const Allocator2& allocator) {
void Append(Key key, const Allocator& allocator) {
// Dynamic allocation because the test checks many more dimensions than
// would be reasonable to pre-allocate. DIY for alignment and padding.
if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
@ -608,9 +608,9 @@ struct MMPerKey {
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
// `MatMulEnv`.
struct MatMulEnv {
explicit MatMulEnv(ThreadingContext2& ctx);
explicit MatMulEnv(ThreadingContext& ctx);
ThreadingContext2& ctx;
ThreadingContext& ctx;
bool have_timer_stop = false;
// Whether `MMCandidates()` should print the set of parameters.
@ -753,7 +753,7 @@ ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
}
template <typename TB>
void BindB(const Allocator2& allocator, size_t N, size_t sizeof_TC,
void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
const ConstMat<TB>& B, MMParallel& parallel) {
if (!allocator.ShouldBind()) return;

View File

@ -86,7 +86,7 @@ float MaxAbs(const RowVectorBatch<float>& a) {
template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
const Allocator2& allocator = ThreadingContext2::Get().allocator;
const Allocator& allocator = ThreadingContext::Get().allocator;
const hn::ScalableTag<float> df;
const size_t cols = A.extents.cols;
const size_t B_rows = B.extents.rows;
@ -210,7 +210,7 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env, int line) {
const Allocator2& allocator = env.ctx.allocator;
const Allocator& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
@ -259,12 +259,12 @@ void TestTiny() {
if (HWY_TARGET != first_target) return;
for (size_t max_packages : {1, 2}) {
ThreadingContext2::ThreadHostileInvalidate();
ThreadingContext::ThreadHostileInvalidate();
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
threading_args.max_packages = max_packages;
ThreadingContext2::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get());
ThreadingContext::SetArgs(threading_args);
MatMulEnv env(ThreadingContext::Get());
NestedPools& pools = env.ctx.pools;
#if GEMMA_DISABLE_TOPOLOGY
@ -296,11 +296,11 @@ void TestAllMatMul() {
return;
}
ThreadingContext2::ThreadHostileInvalidate();
ThreadingContext::ThreadHostileInvalidate();
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
ThreadingContext2::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get());
ThreadingContext::SetArgs(threading_args);
MatMulEnv env(ThreadingContext::Get());
NestedPools& pools = env.ctx.pools;
pools.MaybeStartSpinning(threading_args.spin);

View File

@ -808,7 +808,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
// Each output row is the average of a 4x4 block of input rows
template <typename T>
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
const Allocator2& allocator = ThreadingContext2::Get().allocator;
const Allocator& allocator = ThreadingContext::Get().allocator;
const Extents2D extents = input.Extents();
// Input validation
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows

View File

@ -27,7 +27,7 @@
namespace gcpp {
static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
const Allocator2& allocator, size_t qkv_dim, bool half_rope,
const Allocator& allocator, size_t qkv_dim, bool half_rope,
double base_frequency = 10000.0) {
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2));

View File

@ -386,7 +386,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}
void TestRopeAndMulBy() {
const Allocator2& allocator = ThreadingContext2::Get().allocator;
const Allocator& allocator = ThreadingContext::Get().allocator;
ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
ChooseWrapping(Model::GEMMA2_9B));

View File

@ -47,7 +47,7 @@ class PaliGemmaTest : public ::testing::Test {
void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetGemma(), nullptr);
const Allocator2& allocator = s_env->Env().ctx.allocator;
const Allocator& allocator = s_env->Env().ctx.allocator;
Gemma& gemma = *(s_env->GetGemma());
image_tokens_ = ImageTokens(
allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,

View File

@ -168,7 +168,7 @@ class GemmaModel {
void SetImage(const py::array_t<float, py::array::c_style |
py::array::forcecast>& image) {
const gcpp::Gemma& gemma = *gemma_.GetGemma();
const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
const gcpp::Allocator& allocator = gemma_.Env().ctx.allocator;
if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
throw std::invalid_argument("Not a PaliGemma model.");

View File

@ -130,7 +130,7 @@ size_t DetectTotalMiB(size_t page_bytes) {
} // namespace
Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
line_bytes_ = DetectLineBytes();
vector_bytes_ = hwy::VectorBytes();
step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
@ -180,7 +180,7 @@ Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
quantum_step_mask_ = quantum_bytes_ / step_bytes_ - 1;
}
size_t Allocator2::FreeMiB() const {
size_t Allocator::FreeMiB() const {
#if HWY_OS_LINUX
const long ret = sysconf(_SC_AVPHYS_PAGES); // NOLINT(runtime/int)
HWY_ASSERT(ret != -1);
@ -201,7 +201,7 @@ size_t Allocator2::FreeMiB() const {
#endif
}
AlignedPtr2<uint8_t[]> Allocator2::AllocBytes(size_t bytes) const {
AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
// If we are not binding, the Highway allocator is cheaper than `mmap`, and
// defends against 2K aliasing.
if (!should_bind_) {
@ -296,7 +296,7 @@ size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
return num_busy;
}
bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) const {
HWY_DASSERT(should_bind_);
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
@ -353,7 +353,7 @@ bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
}
#else
bool Allocator2::BindMemory(void*, size_t, size_t) const { return false; }
bool Allocator::BindMemory(void*, size_t, size_t) const { return false; }
#endif // GEMMA_BIND && HWY_OS_LINUX
} // namespace gcpp

View File

@ -78,14 +78,14 @@ template <typename T>
using AlignedClassPtr2 = std::unique_ptr<T, DeleterDtor2>;
// Both allocation, binding, and row accessors depend on the sizes of memory
// pages and cache lines. To avoid having to pass `Allocator2&` everywhere, we
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
// wrap this in a singleton. A monostate requires explicit initialization,
// which we prefer to avoid because there are many main() functions.
class Allocator2 {
class Allocator {
public:
// Must be called at least once before any other function. Not thread-safe,
// hence only call this from the main thread.
Allocator2(const BoundedTopology& topology, bool enable_bind);
Allocator(const BoundedTopology& topology, bool enable_bind);
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
// ranges such that there will be no false sharing.

View File

@ -102,7 +102,7 @@ static size_t RoundUpToOddLines(size_t num, size_t line_bytes,
return padded_num;
}
static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
static size_t Stride(const Allocator& allocator, const MatPtr& mat,
MatPadding padding) {
switch (padding) {
case MatPadding::kPacked:
@ -119,7 +119,7 @@ static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
const Allocator2& allocator = ThreadingContext2::Get().allocator;
const Allocator& allocator = ThreadingContext::Get().allocator;
const size_t stride = Stride(allocator, mat, padding);
const size_t num = mat.Rows() * stride;
// `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`

View File

@ -282,11 +282,11 @@ void ZeroInit(MatPtr& mat);
void RandInit(MatPtr& mat, float stddev, std::mt19937& gen);
// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
// `Allocator2::ShouldBind()`, `Allocator2::QuantumBytes()` is typically 4KiB.
// `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is typically 4KiB.
// To avoid remote accesses, we would thus pad each row to that, which results
// in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that
// by pulling rows forward by a cyclic offset, which is still a multiple of the
// cache line size. This requires an additional `Allocator2::QuantumBytes()` of
// cache line size. This requires an additional `Allocator::QuantumBytes()` of
// padding after also rounding up to that, which considerably increases size for
// tall and skinny tensors.
static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
@ -295,7 +295,7 @@ static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
// Constexpr version (upper bound) for allocating storage in MatMul.
template <typename T>
constexpr size_t MaxStrideForCyclicOffsets(size_t cols) {
constexpr size_t quantum = Allocator2::MaxQuantum<T>();
constexpr size_t quantum = Allocator::MaxQuantum<T>();
return hwy::RoundUpTo(cols, quantum) + quantum;
}
@ -387,7 +387,7 @@ MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
template <typename T>
class RowPtr {
public:
RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols,
RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols,
size_t stride)
: row0_(row0),
stride_(stride),
@ -414,7 +414,7 @@ class RowPtr {
}
}
RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols)
RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols)
: RowPtr(allocator, row0, cols, cols) {}
T* HWY_RESTRICT Row(size_t r) const {
@ -480,7 +480,7 @@ class RowVectorBatch {
// we default to tightly packed rows (`stride = cols`).
// WARNING: not all call sites support `stride` != cols.
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
RowVectorBatch(const Allocator2& allocator, Extents2D extents,
RowVectorBatch(const Allocator& allocator, Extents2D extents,
size_t stride = 0)
: extents_(extents) {
if (stride == 0) {
@ -529,14 +529,14 @@ class RowVectorBatch {
};
template <typename T>
RowPtr<T> RowPtrFromBatch(const Allocator2& allocator,
RowPtr<T> RowPtrFromBatch(const Allocator& allocator,
RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
row_vectors.Stride());
}
template <typename T>
RowVectorBatch<T> AllocateAlignedRows(const Allocator2& allocator,
RowVectorBatch<T> AllocateAlignedRows(const Allocator& allocator,
Extents2D extents) {
return RowVectorBatch<T>(
allocator, extents,

View File

@ -109,7 +109,7 @@ static Pinning& GetPinning() {
return pinning;
}
static PoolPtr MakePool(const Allocator2& allocator, size_t num_workers,
static PoolPtr MakePool(const Allocator& allocator, size_t num_workers,
std::optional<size_t> node = std::nullopt) {
// `ThreadPool` expects the number of threads to create, which is one less
// than the number of workers, but avoid underflow if zero.
@ -136,7 +136,7 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) {
}
NestedPools::NestedPools(const BoundedTopology& topology,
const Allocator2& allocator, size_t max_threads,
const Allocator& allocator, size_t max_threads,
Tristate pin) {
GetPinning().SetPolicy(pin);
packages_.resize(topology.NumPackages());
@ -175,7 +175,7 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
}
NestedPools::Package::Package(const BoundedTopology& topology,
const Allocator2& allocator, size_t pkg_idx,
const Allocator& allocator, size_t pkg_idx,
size_t max_workers_per_package) {
// Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(pkg_idx));

View File

@ -74,7 +74,7 @@ class NestedPools {
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments
// only impose upper bounds on the number of detected packages and clusters
// rather than defining the actual number of threads.
NestedPools(const BoundedTopology& topology, const Allocator2& allocator,
NestedPools(const BoundedTopology& topology, const Allocator& allocator,
size_t max_threads = 0, Tristate pin = Tristate::kDefault);
bool AllPinned() const { return all_pinned_; }
@ -148,7 +148,7 @@ class NestedPools {
class Package {
public:
Package() = default; // for vector
Package(const BoundedTopology& topology, const Allocator2& allocator,
Package(const BoundedTopology& topology, const Allocator& allocator,
size_t pkg_idx, size_t max_workers_per_package);
size_t NumClusters() const { return clusters_.size(); }

View File

@ -26,37 +26,37 @@ namespace gcpp {
static ThreadingArgs s_args;
// Cannot use magic static because that does not support `Invalidate`, hence
// allocate manually.
static std::unique_ptr<ThreadingContext2> s_ctx;
static std::unique_ptr<ThreadingContext> s_ctx;
static std::mutex s_ctx_mutex;
/*static*/ void ThreadingContext2::SetArgs(const ThreadingArgs& args) {
/*static*/ void ThreadingContext::SetArgs(const ThreadingArgs& args) {
s_ctx_mutex.lock();
HWY_ASSERT(!s_ctx); // Ensure not already initialized, else this is too late.
s_args = args;
s_ctx_mutex.unlock();
}
/*static*/ bool ThreadingContext2::IsInitialized() {
/*static*/ bool ThreadingContext::IsInitialized() {
s_ctx_mutex.lock();
const bool initialized = !!s_ctx;
s_ctx_mutex.unlock();
return initialized;
}
/*static*/ ThreadingContext2& ThreadingContext2::Get() {
/*static*/ ThreadingContext& ThreadingContext::Get() {
PROFILER_FUNC;
// We do not bother with double-checked locking because it requires an
// atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
// callers can cache the result and call less often.
s_ctx_mutex.lock();
if (HWY_UNLIKELY(!s_ctx)) {
s_ctx = std::make_unique<ThreadingContext2>(PrivateToken());
s_ctx = std::make_unique<ThreadingContext>(PrivateToken());
}
s_ctx_mutex.unlock();
return *s_ctx;
}
/*static*/ void ThreadingContext2::ThreadHostileInvalidate() {
/*static*/ void ThreadingContext::ThreadHostileInvalidate() {
// Deliberately avoid taking the lock so that tsan can warn if this is
// called concurrently with other calls to `Get`.
s_ctx.reset();
@ -64,7 +64,7 @@ static std::mutex s_ctx_mutex;
// WARNING: called with `s_ctx_mutex` held. Calling `SetArgs` or `Get` would
// deadlock.
ThreadingContext2::ThreadingContext2(ThreadingContext2::PrivateToken)
ThreadingContext::ThreadingContext(ThreadingContext::PrivateToken)
: topology(BoundedSlice(s_args.skip_packages, s_args.max_packages),
BoundedSlice(s_args.skip_clusters, s_args.max_clusters),
BoundedSlice(s_args.skip_lps, s_args.max_lps)),

View File

@ -87,7 +87,7 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
// Lazily-initialized singleton with support for passing in arguments from
// `ThreadingArgs` and re-initializing with different arguments.
class ThreadingContext2 {
class ThreadingContext {
struct PrivateToken {}; // avoids constructing directly
public:
@ -112,7 +112,7 @@ class ThreadingContext2 {
// hence we prefer not to pull `std::shared_ptr` into the interface.
//
// To reduce overhead, callers should cache the result and call less often.
static ThreadingContext2& Get();
static ThreadingContext& Get();
// Invalidates the singleton before or after a call to `Get`. This allows
// changing the arguments between tests. Callers must again call `Get`
@ -121,10 +121,10 @@ class ThreadingContext2 {
// Also useful to suppress memory leak warnings in tests.
static void ThreadHostileInvalidate();
explicit ThreadingContext2(PrivateToken); // only called via `Get`.
explicit ThreadingContext(PrivateToken); // only called via `Get`.
BoundedTopology topology;
Allocator2 allocator;
Allocator allocator;
NestedPools pools;
};

View File

@ -383,7 +383,7 @@ TEST(ThreadingTest, BenchJoin) {
}
};
NestedPools& pools = ThreadingContext2::Get().pools;
NestedPools& pools = ThreadingContext::Get().pools;
// Use last package because the main thread has been pinned to it.
const size_t pkg_idx = pools.NumPackages() - 1;