mirror of https://github.com/google/gemma.cpp.git
Rename-only: remove Allocator2 etc suffixes now that refactoring is complete
PiperOrigin-RevId: 755397220
parent 8d0882b966
commit 275135d7e8
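The diff below is purely mechanical: the `2` suffixes left over from the refactoring (`ThreadingContext2`, `Allocator2`, `BlobReader2`, `BlobWriter2`, `BlobRange2`, `ModelStore2`) are removed, with no behavior change. As a hedged sketch of a post-rename call site (header paths are assumptions; the calls mirror the `MakeMatMulEnv` hunk in gemma.cc below):

#include "ops/matmul.h"              // MatMulEnv (assumed path)
#include "util/threading_context.h"  // ThreadingContext, ThreadingArgs (assumed path)

// Post-rename usage; before this commit the same calls were spelled
// ThreadingContext2::SetArgs / ThreadingContext2::Get.
gcpp::MatMulEnv MakeEnv(const gcpp::ThreadingArgs& args) {
  gcpp::ThreadingContext::SetArgs(args);  // per the diff, safe to call only once
  return gcpp::MatMulEnv(gcpp::ThreadingContext::Get());
}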
@@ -66,14 +66,14 @@ hwy::ThreadPool& ThreadHostileGetPool() {
   // can safely call `SetArgs` only once, because it would assert otherwise.
   // This is preferable to calling `ThreadHostileInvalidate`, because we would
   // repeat the topology initialization for every test.
-  if (!ThreadingContext2::IsInitialized()) {
+  if (!ThreadingContext::IsInitialized()) {
     gcpp::ThreadingArgs threading_args;
     threading_args.max_packages = 1;
     threading_args.max_clusters = 8;
     threading_args.pin = Tristate::kFalse;
-    ThreadingContext2::SetArgs(threading_args);
+    ThreadingContext::SetArgs(threading_args);
   }
-  return ThreadingContext2::Get().pools.Pool();
+  return ThreadingContext::Get().pools.Pool();
 }

 void TestMatMulVJP() {

@@ -203,7 +203,7 @@ void TestEndToEnd() {
   std::vector<Prompt> batch = training_task.SampleBatch(3, gen);

   RowVectorBatch<float> inv_timescale = CreateInvTimescale(
-      ThreadingContext2::Get().allocator, config.layer_configs[0].qkv_dim,
+      ThreadingContext::Get().allocator, config.layer_configs[0].qkv_dim,
       config.layer_configs[0].post_qk == PostQKType::HalfRope);
   for (const Prompt& prompt : batch) {
     ReverseSequenceSampler::LogPrompt(prompt);
@@ -45,9 +45,9 @@ TEST(OptimizeTest, GradientDescent) {
   threading_args.max_packages = 1;
   threading_args.max_clusters = 1;
   threading_args.pin = Tristate::kFalse;
-  ThreadingContext2::SetArgs(threading_args);
-  MatMulEnv env(ThreadingContext2::Get());
-  const Allocator2& allocator = env.ctx.allocator;
+  ThreadingContext::SetArgs(threading_args);
+  MatMulEnv env(ThreadingContext::Get());
+  const Allocator& allocator = env.ctx.allocator;
   hwy::ThreadPool& pool = env.ctx.pools.Pool();
   std::mt19937 gen(42);

@@ -67,7 +67,7 @@ template <typename T>
 class WeightsWrapper {
  public:
   explicit WeightsWrapper(const ModelConfig& config) : weights_(config) {
-    hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
+    hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
     weights_.AllocateForTest(owners_, pool);
   }
@@ -35,7 +35,7 @@
 namespace gcpp {

 // Aborts if any keys differ, because then blobs are not comparable.
-void CompareKeys(const BlobReader2& reader1, const BlobReader2& reader2) {
+void CompareKeys(const BlobReader& reader1, const BlobReader& reader2) {
   if (reader1.Keys().size() != reader2.Keys().size()) {
     HWY_ABORT("#keys mismatch: %zu vs %zu\n", reader1.Keys().size(),
               reader2.Keys().size());

@@ -49,13 +49,13 @@ void CompareKeys(const BlobReader2& reader1, const BlobReader2& reader2) {
 }

 using KeyVec = std::vector<std::string>;
-using RangeVec = std::vector<BlobRange2>;
+using RangeVec = std::vector<BlobRange>;

-RangeVec AllRanges(const KeyVec& keys, const BlobReader2& reader) {
+RangeVec AllRanges(const KeyVec& keys, const BlobReader& reader) {
   RangeVec ranges;
   ranges.reserve(keys.size());
   for (const std::string& key : keys) {
-    const BlobRange2* range = reader.Find(key);
+    const BlobRange* range = reader.Find(key);
     if (!range) {
       HWY_ABORT("Key %s not found, but was in KeyVec\n", key.c_str());
     }

@@ -82,7 +82,7 @@ void CompareRangeSizes(const KeyVec& keys, const RangeVec& ranges1,
 // Total amount to allocate for all blobs.
 size_t TotalBytes(const RangeVec& ranges) {
   size_t total_bytes = 0;
-  for (const BlobRange2& range : ranges) {
+  for (const BlobRange& range : ranges) {
     total_bytes += range.bytes;
   }
   return total_bytes;

@@ -95,7 +95,7 @@ using BlobVec = std::vector<ByteSpan>;  // in order of keys
 // Assigns pointers within the single allocation and updates `pos`.
 BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {
   BlobVec blobs;
-  for (const BlobRange2& range : ranges) {
+  for (const BlobRange& range : ranges) {
     blobs.push_back(ByteSpan(all_blobs.get() + pos, range.bytes));
     pos += range.bytes;
   }

@@ -104,7 +104,7 @@ BlobVec ReserveMemory(const RangeVec& ranges, BytePtr& all_blobs, size_t& pos) {

 // Reads one set of blobs in parallel (helpful if in disk cache).
 // Aborts on error.
-void ReadBlobs(BlobReader2& reader, const RangeVec& ranges, BlobVec& blobs,
+void ReadBlobs(BlobReader& reader, const RangeVec& ranges, BlobVec& blobs,
                hwy::ThreadPool& pool) {
   HWY_ASSERT(reader.Keys().size() == blobs.size());
   HWY_ASSERT(ranges.size() == blobs.size());

@@ -116,7 +116,7 @@ void ReadBlobs(BlobReader2& reader, const RangeVec& ranges, BlobVec& blobs,
 }

 // Parallelizes ReadBlobs across (two) packages, if available.
-void ReadBothBlobs(BlobReader2& reader1, BlobReader2& reader2,
+void ReadBothBlobs(BlobReader& reader1, BlobReader& reader2,
                    const RangeVec& ranges1, const RangeVec& ranges2,
                    size_t total_bytes, BlobVec& blobs1, BlobVec& blobs2,
                    NestedPools& pools) {

@@ -215,8 +215,8 @@ void CompareBlobs(const KeyVec& keys, BlobVec& blobs1, BlobVec& blobs2,
 // Compares two sbs files, including blob order.
 void ReadAndCompareBlobs(const char* path1, const char* path2) {
   const Tristate map = Tristate::kFalse;
-  std::unique_ptr<BlobReader2> reader1 = BlobReader2::Make(Path(path1), map);
-  std::unique_ptr<BlobReader2> reader2 = BlobReader2::Make(Path(path2), map);
+  std::unique_ptr<BlobReader> reader1 = BlobReader::Make(Path(path1), map);
+  std::unique_ptr<BlobReader> reader2 = BlobReader::Make(Path(path2), map);
   if (!reader1 || !reader2) {
     HWY_ABORT(
         "Failed to create readers for files %s %s, see error messages above.\n",

@@ -235,7 +235,7 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
   BlobVec blobs1 = ReserveMemory(ranges1, all_blobs, pos);
   BlobVec blobs2 = ReserveMemory(ranges2, all_blobs, pos);

-  NestedPools& pools = ThreadingContext2::Get().pools;
+  NestedPools& pools = ThreadingContext::Get().pools;
   ReadBothBlobs(*reader1, *reader2, ranges1, ranges2, total_bytes, blobs1,
                 blobs2, pools);
@@ -94,7 +94,7 @@ static_assert(sizeof(Header) == 16);
 // Additional data may be added only inside new blobs. Changes to the blob
 // contents or type should be handled by renaming keys.
 //
-// This class is for internal use by `BlobReader2` and `BlobWriter2`. Its
+// This class is for internal use by `BlobReader` and `BlobWriter`. Its
 // interface is more low-level: fixed-size keys instead of strings.
 class BlobStore {
   static constexpr uint32_t kMagic = 0x0A534253;  // SBS\n

@@ -182,7 +182,7 @@ class BlobStore {
                      padded_dir_bytes - 2 * num_blobs * kU128Bytes);

     // We already zero-initialized the directory padding;
-    // `BlobWriter2::WriteAll` takes care of padding after each blob via an
+    // `BlobWriter::WriteAll` takes care of padding after each blob via an
     // additional I/O.
     for (size_t i = 0; i < num_blobs; ++i) {
       HWY_ASSERT(blobs[i].data() != nullptr);

@@ -242,12 +242,12 @@ class BlobStore {
   void EnqueueWriteForHeaderAndDirectory(std::vector<BlobIO2>& writes) const {
     const size_t key_idx = 0;  // not actually associated with a key/blob
     writes.emplace_back(
-        BlobRange2{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
+        BlobRange{.offset = 0, .bytes = sizeof(header_), .key_idx = key_idx},
         // members are const and BlobIO2 requires non-const pointers, and they
         // are not modified by file writes.
         const_cast<Header*>(&header_));
     writes.emplace_back(
-        BlobRange2{.offset = sizeof(header_),
+        BlobRange{.offset = sizeof(header_),
                    .bytes = PaddedDirEnd(NumBlobs()) - sizeof(header_),
                    .key_idx = key_idx},
         const_cast<hwy::uint128_t*>(directory_.data()));

@@ -289,8 +289,8 @@ class BlobStore {
   std::vector<hwy::uint128_t> directory_;  // two per blob, see `SetRange`.
 };  // BlobStore

-BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
-                         const BlobStore& bs, BlobReader2::Mode mode)
+BlobReader::BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
+                       const BlobStore& bs, BlobReader::Mode mode)
     : file_(std::move(file)), file_bytes_(file_bytes), mode_(mode) {
   HWY_ASSERT(file_ && file_bytes_ != 0);

@@ -306,12 +306,12 @@ BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
     size_t bytes;
     bs.GetRange(key_idx, offset, bytes);
     ranges_.emplace_back(
-        BlobRange2{.offset = offset, .bytes = bytes, .key_idx = key_idx});
+        BlobRange{.offset = offset, .bytes = bytes, .key_idx = key_idx});
     key_idx_for_key_[keys_[key_idx]] = key_idx;
   }

   if (mode_ == Mode::kMap) {
-    const Allocator2& allocator = ThreadingContext2::Get().allocator;
+    const Allocator& allocator = ThreadingContext::Get().allocator;
     // Verify `kEndAlign` is an upper bound on the page size.
     if (kEndAlign % allocator.BasePageBytes() != 0) {
       HWY_ABORT("Please raise an issue about kEndAlign %zu %% page size %zu.",

@@ -338,12 +338,12 @@ BlobReader2::BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
   }
 }

-void BlobReader2::Enqueue(const BlobRange2& range, void* data) {
+void BlobReader::Enqueue(const BlobRange& range, void* data) {
   // Debug-only because there may be many I/O requests (per row).
   if constexpr (HWY_IS_DEBUG_BUILD) {
     HWY_DASSERT(!IsMapped());
     HWY_DASSERT(range.offset != 0 && range.bytes != 0 && data != nullptr);
-    const BlobRange2& blob_range = Range(range.key_idx);
+    const BlobRange& blob_range = Range(range.key_idx);
     HWY_DASSERT(blob_range.End() <= file_bytes_);
     if (range.End() > blob_range.End()) {
       HWY_ABORT(

@@ -362,15 +362,15 @@ void BlobReader2::Enqueue(const BlobRange2& range, void* data) {
 // TODO: use preadv for per-tensor batches of sysconf(_SC_IOV_MAX) / IOV_MAX.
 // - O_DIRECT seems undesirable because we do want to use the OS cache
 //   between consecutive runs.
-void BlobReader2::ReadAll(hwy::ThreadPool& pool) const {
+void BlobReader::ReadAll(hwy::ThreadPool& pool) const {
   PROFILER_ZONE("Startup.ReadAll");
   HWY_ASSERT(!IsMapped());
   // >5x speedup from parallel reads when cached.
   pool.Run(0, requests_.size(), [this](uint64_t i, size_t /*thread*/) {
-    const BlobRange2& range = requests_[i].range;
+    const BlobRange& range = requests_[i].range;
     const uint64_t end = range.End();
     const std::string& key = keys_[range.key_idx];
-    const BlobRange2& blob_range = Range(range.key_idx);
+    const BlobRange& blob_range = Range(range.key_idx);
     HWY_ASSERT(blob_range.End() <= file_bytes_);
     if (end > blob_range.End()) {
       HWY_ABORT(

@@ -387,11 +387,11 @@ void BlobReader2::ReadAll(hwy::ThreadPool& pool) const {
 }

 // Decides whether to read or map the file.
-static BlobReader2::Mode ChooseMode(uint64_t file_mib, Tristate map) {
-  const Allocator2& allocator = ThreadingContext2::Get().allocator;
+static BlobReader::Mode ChooseMode(uint64_t file_mib, Tristate map) {
+  const Allocator& allocator = ThreadingContext::Get().allocator;
   // User has explicitly requested a map or read via args.
-  if (map == Tristate::kTrue) return BlobReader2::Mode::kMap;
-  if (map == Tristate::kFalse) return BlobReader2::Mode::kRead;
+  if (map == Tristate::kTrue) return BlobReader::Mode::kMap;
+  if (map == Tristate::kFalse) return BlobReader::Mode::kRead;
   // Else: use heuristics to choose. Note that `FreeMiB` is generally low
   // because idle memory is used as cache, so do not use it to decide.
   const size_t total_mib = allocator.TotalMiB();

@@ -400,13 +400,13 @@ static BlobReader2::Mode ChooseMode(uint64_t file_mib, Tristate map) {
              static_cast<size_t>(file_mib), total_mib);
   }
   // Large fraction of total.
-  if (file_mib >= total_mib / 3) return BlobReader2::Mode::kMap;
+  if (file_mib >= total_mib / 3) return BlobReader::Mode::kMap;
   // Big enough that even parallel loading wouldn't be quick.
-  if (file_mib > 50 * 1024) return BlobReader2::Mode::kMap;
-  return BlobReader2::Mode::kRead;
+  if (file_mib > 50 * 1024) return BlobReader::Mode::kMap;
+  return BlobReader::Mode::kRead;
 }

-std::unique_ptr<BlobReader2> BlobReader2::Make(const Path& blob_path,
+std::unique_ptr<BlobReader> BlobReader::Make(const Path& blob_path,
                                                const Tristate map) {
   if (blob_path.Empty()) HWY_ABORT("No --weights specified.");
   std::unique_ptr<File> file = OpenFileOrNull(blob_path, "r");
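Aside from renames, the `ChooseMode` hunk above is where the read-vs-map policy lives. A self-contained sketch of that heuristic (hypothetical standalone function; thresholds copied from the code above):

#include <cstdint>

enum class Mode { kRead, kMap };
enum class Choice { kFalse, kTrue, kDefault };  // stands in for Tristate

// Mirrors ChooseMode above: an explicit user request wins; otherwise map when
// the file is a large fraction (>= 1/3) of total memory, or so big (> 50 GiB)
// that even parallel reads would not be quick.
Mode ChooseModeSketch(uint64_t file_mib, uint64_t total_mib, Choice map) {
  if (map == Choice::kTrue) return Mode::kMap;
  if (map == Choice::kFalse) return Mode::kRead;
  if (file_mib >= total_mib / 3) return Mode::kMap;
  if (file_mib > 50 * 1024) return Mode::kMap;  // 50 GiB expressed in MiB
  return Mode::kRead;
}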
@@ -417,10 +417,10 @@ std::unique_ptr<BlobReader2> BlobReader2::Make(const Path& blob_path,
   // Even if `kMap`, read the directory via the `kRead` mode for simplicity.
   BlobStore bs(*file);
   if (!bs.IsValid(file_bytes)) {
-    return std::unique_ptr<BlobReader2>();  // IsValid already printed a warning
+    return std::unique_ptr<BlobReader>();  // IsValid already printed a warning
   }

-  return std::unique_ptr<BlobReader2>(new BlobReader2(
+  return std::unique_ptr<BlobReader>(new BlobReader(
       std::move(file), file_bytes, bs, ChooseMode(file_bytes >> 20, map)));
 }

@@ -434,14 +434,13 @@ static void EnqueueChunks(size_t key_idx, uint64_t offset, uint64_t bytes,
     for (; offset <= end - kChunkBytes;
          offset += kChunkBytes, data += kChunkBytes) {
       writes.emplace_back(
-          BlobRange2{
-              .offset = offset, .bytes = kChunkBytes, .key_idx = key_idx},
+          BlobRange{.offset = offset, .bytes = kChunkBytes, .key_idx = key_idx},
           data);
     }
   }
   if (offset != end) {
     writes.emplace_back(
-        BlobRange2{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
+        BlobRange{.offset = offset, .bytes = end - offset, .key_idx = key_idx},
         data);
   }
 }

@@ -472,7 +471,7 @@ static void EnqueueWritesForBlobs(const BlobStore& bs,
     if (padding != 0) {
       HWY_ASSERT(padding <= kBlobAlign);
       writes.emplace_back(
-          BlobRange2{
+          BlobRange{
               .offset = offset + bytes, .bytes = padding, .key_idx = key_idx},
           const_cast<uint8_t*>(kZeros));
     }

@@ -484,19 +483,19 @@ static void EnqueueWritesForBlobs(const BlobStore& bs,
     // remain alive until the last I/O is done.
     zeros.resize(padding);
     writes.emplace_back(
-        BlobRange2{.offset = file_end, .bytes = padding, .key_idx = 0},
+        BlobRange{.offset = file_end, .bytes = padding, .key_idx = 0},
         zeros.data());
   }
 }

-void BlobWriter2::Add(const std::string& key, const void* data, size_t bytes) {
+void BlobWriter::Add(const std::string& key, const void* data, size_t bytes) {
   HWY_ASSERT(data != nullptr);
   HWY_ASSERT(bytes != 0);
   keys_.push_back(KeyFromString(key.c_str()));
   blobs_.emplace_back(static_cast<const uint8_t*>(data), bytes);
 }

-void BlobWriter2::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
+void BlobWriter::WriteAll(hwy::ThreadPool& pool, const Path& filename) {
   const size_t num_blobs = keys_.size();
   HWY_ASSERT(num_blobs != 0);
   HWY_ASSERT(num_blobs == blobs_.size());

@@ -516,7 +515,7 @@ void BlobWriter2::WriteAll(hwy::ThreadPool& pool, const Path& filename) {

   pool.Run(0, writes.size(),
            [this, &file, &writes](uint64_t i, size_t /*thread*/) {
-             const BlobRange2& range = writes[i].range;
+             const BlobRange& range = writes[i].range;

              if (!file->Write(writes[i].data, range.bytes, range.offset)) {
               const std::string& key = StringFromKey(keys_[range.key_idx]);
@@ -35,20 +35,20 @@
 namespace gcpp {

 // One blob's extents within the file.
-struct BlobRange2 {
+struct BlobRange {
   uint64_t End() const { return offset + bytes; }

   uint64_t offset = 0;
   size_t bytes = 0;  // We check blobs are not zero-sized.
-  // Index within `BlobReader2::Keys()` for error reporting.
+  // Index within `BlobReader::Keys()` for error reporting.
   size_t key_idx;
 };

 // A read or write I/O request, each serviced by one thread in a pool.
 struct BlobIO2 {
-  BlobIO2(BlobRange2 range, void* data) : range(range), data(data) {}
+  BlobIO2(BlobRange range, void* data) : range(range), data(data) {}

-  BlobRange2 range;
+  BlobRange range;
   void* data;  // Modified only if a read request. Read-only for writes.
 };

@@ -59,7 +59,7 @@ class BlobStore;
 // Thread-safe: it is safe to concurrently call all methods except `Enqueue`,
 // because they are const.
 // TODO(janwas): split into header and reader/mapper classes.
-class BlobReader2 {
+class BlobReader {
  public:
   // Parallel I/O into allocated memory, or mapped view of file. The latter is
   // better when the file is huge, but page faults add noise to measurements.

@@ -67,26 +67,26 @@ class BlobReader2 {

   // Acquires ownership of `file` (which must be non-null) and reads its header.
   // Factory function instead of ctor because this can fail (return null).
-  static std::unique_ptr<BlobReader2> Make(const Path& blob_path,
+  static std::unique_ptr<BlobReader> Make(const Path& blob_path,
                                            Tristate map = Tristate::kDefault);

-  ~BlobReader2() = default;
+  ~BlobReader() = default;

   // Returns true if the mode passed to ctor was `kMap` and mapping succeeded.
   bool IsMapped() const { return mode_ == Mode::kMap; }

   const std::vector<std::string>& Keys() const { return keys_; }

-  const BlobRange2& Range(size_t key_idx) const {
+  const BlobRange& Range(size_t key_idx) const {
     HWY_ASSERT(key_idx < keys_.size());
     return ranges_[key_idx];
   }

   // Returns nullptr if not found. O(1).
-  const BlobRange2* Find(const std::string& key) const {
+  const BlobRange* Find(const std::string& key) const {
     auto it = key_idx_for_key_.find(key);
     if (it == key_idx_for_key_.end()) return nullptr;
-    const BlobRange2& range = Range(it->second);
+    const BlobRange& range = Range(it->second);
     HWY_ASSERT(range.offset != 0 && range.bytes != 0);
     HWY_ASSERT(range.End() <= file_bytes_);
     return &range;

@@ -95,7 +95,7 @@ class BlobReader2 {
   // Only if `IsMapped()`: returns blob as a read-only span of `T`. Note that
   // everything else except `CallWithSpan` is in units of bytes.
   template <typename T>
-  hwy::Span<const T> MappedSpan(const BlobRange2& range) const {
+  hwy::Span<const T> MappedSpan(const BlobRange& range) const {
     HWY_ASSERT(IsMapped());
     HWY_ASSERT(range.bytes % sizeof(T) == 0);
     return hwy::Span<const T>(

@@ -108,7 +108,7 @@ class BlobReader2 {
   // which an aligned allocation is unnecessary.
   template <typename T, class Func>
   bool CallWithSpan(const std::string& key, const Func& func) const {
-    const BlobRange2* range = Find(key);
+    const BlobRange* range = Find(key);
     if (!range) {
       HWY_WARN("Blob %s not found, sizeof T=%zu", key.c_str(), sizeof(T));
       return false;

@@ -134,7 +134,7 @@ class BlobReader2 {
   // The following methods must only be called if `!IsMapped()`.

   // Enqueues a BlobIO2 for `ReadAll` to execute.
-  void Enqueue(const BlobRange2& range, void* data);
+  void Enqueue(const BlobRange& range, void* data);

   // Reads in parallel all enqueued requests to the specified destinations.
   // Aborts on error.

@@ -142,7 +142,7 @@ class BlobReader2 {

  private:
   // Only for use by `Make`.
-  BlobReader2(std::unique_ptr<File> file, uint64_t file_bytes,
+  BlobReader(std::unique_ptr<File> file, uint64_t file_bytes,
              const BlobStore& bs, Mode mode);

   const std::unique_ptr<File> file_;

@@ -150,7 +150,7 @@ class BlobReader2 {
   Mode mode_;

   std::vector<std::string> keys_;
-  std::vector<BlobRange2> ranges_;
+  std::vector<BlobRange> ranges_;
   std::unordered_map<std::string, size_t> key_idx_for_key_;

   MapPtr mapped_;  // only if `kMap`

@@ -160,7 +160,7 @@ class BlobReader2 {
 // Collects references to blobs and writes them all at once with parallel I/O.
 // Thread-compatible: independent instances can be used concurrently, but it
 // does not make sense to call the methods concurrently.
-class BlobWriter2 {
+class BlobWriter {
  public:
   void Add(const std::string& key, const void* data, size_t bytes);
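Putting the renamed API together, reading a single blob looks roughly like the sketch below (modeled on blob_store_test.cc further down; the key string, includes, and error handling are illustrative, and `Tristate` is assumed to come with the headers shown):

#include <cstdint>
#include <memory>
#include <vector>

#include "compression/blob_store.h"  // BlobReader, BlobRange (as renamed above)
#include "compression/io.h"          // Path (assumed path)
#include "hwy/contrib/thread_pool/thread_pool.h"

// Find one blob by key, enqueue a read for its full range, then execute all
// enqueued reads in parallel. Returns an empty vector on failure.
std::vector<uint8_t> ReadOneBlob(const gcpp::Path& path, hwy::ThreadPool& pool) {
  std::unique_ptr<gcpp::BlobReader> reader =
      gcpp::BlobReader::Make(path, gcpp::Tristate::kFalse);  // force kRead
  if (!reader) return {};                  // Make already printed a warning
  const gcpp::BlobRange* range = reader->Find("some_key");  // illustrative key
  if (!range) return {};
  std::vector<uint8_t> buffer(range->bytes);
  reader->Enqueue(*range, buffer.data());
  reader->ReadAll(pool);                   // aborts on I/O error
  return buffer;
}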
@@ -37,7 +37,7 @@ class BlobStoreTest : public testing::Test {};
 #endif

 void TestWithMapped(Tristate map) {
-  hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
+  hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();

   static const std::array<float, 4> kOriginalData = {-1, 0, 3.14159, 2.71828};

@@ -51,7 +51,7 @@ void TestWithMapped(Tristate map) {

   const std::string keyA("0123456789abcdef");  // max 16 characters
   const std::string keyB("q");
-  BlobWriter2 writer;
+  BlobWriter writer;
   writer.Add(keyA, "DATA", 5);
   writer.Add(keyB, buffer.data(), sizeof(buffer));
   writer.WriteAll(pool, path);

@@ -59,14 +59,14 @@ void TestWithMapped(Tristate map) {

   std::fill(buffer.begin(), buffer.end(), 0);

-  std::unique_ptr<BlobReader2> reader = BlobReader2::Make(path, map);
+  std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
   HWY_ASSERT(reader);

   HWY_ASSERT_EQ(reader->Keys().size(), 2);
   HWY_ASSERT_STRING_EQ(reader->Keys()[0].c_str(), keyA.c_str());
   HWY_ASSERT_STRING_EQ(reader->Keys()[1].c_str(), keyB.c_str());

-  const BlobRange2* range = reader->Find(keyA);
+  const BlobRange* range = reader->Find(keyA);
   HWY_ASSERT(range);
   const uint64_t offsetA = range->offset;
   HWY_ASSERT_EQ(offsetA, 256);  // kBlobAlign

@@ -80,9 +80,9 @@ void TestWithMapped(Tristate map) {
   if (!reader->IsMapped()) {
     char str[5];
     reader->Enqueue(
-        BlobRange2{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
+        BlobRange{.offset = offsetA, .bytes = sizeof(str), .key_idx = 0}, str);
     reader->Enqueue(
-        BlobRange2{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
+        BlobRange{.offset = offsetB, .bytes = sizeof(buffer), .key_idx = 1},
         buffer.data());
     reader->ReadAll(pool);
     HWY_ASSERT_STRING_EQ("DATA", str);

@@ -111,7 +111,7 @@ TEST(BlobStoreTest, TestReadWrite) {

 // Ensures padding works for any number of random-sized blobs.
 TEST(BlobStoreTest, TestNumBlobs) {
-  hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
+  hwy::ThreadPool& pool = ThreadingContext::Get().pools.Pool();
   hwy::RandomState rng;

   for (size_t num_blobs = 1; num_blobs <= 512; ++num_blobs) {

@@ -121,7 +121,7 @@ TEST(BlobStoreTest, TestNumBlobs) {
     HWY_ASSERT(fd > 0);
     const Path path(path_str);

-    BlobWriter2 writer;
+    BlobWriter writer;
     std::vector<std::string> keys;
     keys.reserve(num_blobs);
     std::vector<std::vector<uint8_t>> blobs;

@@ -144,13 +144,13 @@ TEST(BlobStoreTest, TestNumBlobs) {
     writer.WriteAll(pool, path);

     const Tristate map = Tristate::kFalse;
-    std::unique_ptr<BlobReader2> reader = BlobReader2::Make(path, map);
+    std::unique_ptr<BlobReader> reader = BlobReader::Make(path, map);
     HWY_ASSERT(reader);
     HWY_ASSERT_EQ(reader->Keys().size(), num_blobs);
     pool.Run(0, num_blobs, [&](uint64_t i, size_t /*thread*/) {
       HWY_ASSERT_STRING_EQ(reader->Keys()[i].c_str(),
                            std::to_string(i).c_str());
-      const BlobRange2* range = reader->Find(keys[i]);
+      const BlobRange* range = reader->Find(keys[i]);
       HWY_ASSERT(range);
       HWY_ASSERT_EQ(blobs[i].size(), range->bytes);
       HWY_ASSERT(reader->CallWithSpan<uint8_t>(
@@ -22,7 +22,7 @@
 #include <string>
 #include <vector>

-#include "compression/blob_store.h"  // BlobWriter2
+#include "compression/blob_store.h"  // BlobWriter
 #include "compression/compress.h"  // ScaleWeights
 #include "compression/io.h"  // Path
 #include "gemma/configs.h"  // ModelConfig

@@ -88,7 +88,7 @@ class SbsWriterImpl : public ISbsWriter {
   }

  public:
-  SbsWriterImpl() : pool_(ThreadingContext2::Get().pools.Pool()) {}
+  SbsWriterImpl() : pool_(ThreadingContext::Get().pools.Pool()) {}

   void Insert(const char* name, F32Span weights, Type type,
               const TensorInfo& tensor_info) override {

@@ -123,7 +123,7 @@ class SbsWriterImpl : public ISbsWriter {
   hwy::ThreadPool& pool_;
   MatOwners mat_owners_;
   CompressWorkingSet working_set_;
-  BlobWriter2 writer_;
+  BlobWriter writer_;
   std::vector<uint32_t> serialized_mat_ptrs_;
 };

@@ -141,7 +141,7 @@ HWY_EXPORT(NewSbsWriter);
 SbsWriter::SbsWriter() : impl_(HWY_DYNAMIC_DISPATCH(NewSbsWriter)()) {}

 SbsReader::SbsReader(const std::string& path)
-    : reader_(gcpp::BlobReader2::Make(Path(path))), model_(*reader_) {}
+    : reader_(gcpp::BlobReader::Make(Path(path))), model_(*reader_) {}

 }  // namespace gcpp
 #endif  // HWY_ONCE
@@ -77,8 +77,8 @@ class SbsReader {
   const MatPtr* FindMat(const char* name) const { return model_.FindMat(name); }

  private:
-  std::unique_ptr<gcpp::BlobReader2> reader_;
-  gcpp::ModelStore2 model_;
+  std::unique_ptr<gcpp::BlobReader> reader_;
+  gcpp::ModelStore model_;
 };

 }  // namespace gcpp
@@ -240,7 +240,7 @@ void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
   char* dt = ctime(&now);  // NOLINT
   char cpu100[100] = "unknown";
   (void)hwy::platform::GetCpuString(cpu100);
-  const ThreadingContext2& ctx = ThreadingContext2::Get();
+  const ThreadingContext& ctx = ThreadingContext::Get();

   fprintf(stderr,
           "Date & Time : %s"  // dt includes \n
@@ -50,7 +50,7 @@ class GemmaEnv {
   GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
            const InferenceArgs& inference);
   // Avoid memory leaks in test.
-  ~GemmaEnv() { ThreadingContext2::ThreadHostileInvalidate(); }
+  ~GemmaEnv() { ThreadingContext::ThreadHostileInvalidate(); }

   MatMulEnv& Env() { return env_; }
@@ -72,7 +72,7 @@ struct Activations {
   size_t cache_pos_size = 0;

   void Allocate(size_t batch_size, MatMulEnv* env) {
-    const Allocator2& allocator = env->ctx.allocator;
+    const Allocator& allocator = env->ctx.allocator;

     post_qk = layer_config.post_qk;
     const size_t model_dim = weights_config.model_dim;

@@ -561,7 +561,7 @@ class GemmaAttention {
   const LayerWeightsPtrs<T>& layer_weights_;
   const hwy::Divisor& div_seq_len_;
   const KVCaches& kv_caches_;
-  const Allocator2& allocator_;
+  const Allocator& allocator_;
   hwy::ThreadPool& pool_;
 };

@@ -749,7 +749,7 @@ class VitAttention {
   Activations& activations_;
   const LayerWeightsPtrs<T>& layer_weights_;
   const LayerConfig& layer_config_;
-  const Allocator2& allocator_;
+  const Allocator& allocator_;
   hwy::ThreadPool& pool_;
 };

@@ -789,7 +789,7 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
   const auto x =
       ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);

-  const Allocator2& allocator = activations.env->ctx.allocator;
+  const Allocator& allocator = activations.env->ctx.allocator;
   auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
   auto multiplier = RowPtrFromBatch(allocator, activations.C2);
   auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);

@@ -847,7 +847,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
   const auto x =
       ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);

-  const Allocator2& allocator = activations.env->ctx.allocator;
+  const Allocator& allocator = activations.env->ctx.allocator;
   auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
   auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);

@@ -1416,7 +1416,7 @@ bool DecodeStepT(const ModelConfig& config, const ModelWeightsPtrs<T>& weights,
 //
 // `kv_caches` is for the batch, size must match `queries_prompt`.
 template <typename T>
-void GenerateT(const ModelStore2& model, const ModelWeightsPtrs<T>& weights,
+void GenerateT(const ModelStore& model, const ModelWeightsPtrs<T>& weights,
                Activations& activations, const RuntimeConfig& runtime_config,
                const QueriesPromptTokens& queries_prompt,
                const QueriesPos& queries_pos_in,

@@ -1508,7 +1508,7 @@ void GenerateT(const ModelStore2& model, const ModelWeightsPtrs<T>& weights,
 }

 template <typename T>
-void GenerateSingleT(const ModelStore2& model,
+void GenerateSingleT(const ModelStore& model,
                      const ModelWeightsPtrs<T>& weights,
                      const RuntimeConfig& runtime_config,
                      const PromptTokens& prompt, size_t pos, size_t prefix_end,

@@ -1532,7 +1532,7 @@ void GenerateSingleT(const ModelStore2& model,
 }

 template <typename T>
-void GenerateBatchT(const ModelStore2& model,
+void GenerateBatchT(const ModelStore& model,
                     const ModelWeightsPtrs<T>& weights,
                     const RuntimeConfig& runtime_config,
                     const QueriesPromptTokens& queries_prompt,

@@ -1573,7 +1573,7 @@ void GenerateBatchT(const ModelStore2& model,
 }

 template <typename T>
-void GenerateImageTokensT(const ModelStore2& model,
+void GenerateImageTokensT(const ModelStore& model,
                           const ModelWeightsPtrs<T>& weights,
                           const RuntimeConfig& runtime_config,
                           const Image& image, ImageTokens& image_tokens,

@@ -1599,7 +1599,7 @@ void GenerateImageTokensT(const ModelStore2& model,
 // These are extern functions defined by instantiations/*.cc, which include this
 // 'header' after defining `GEMMA_TYPE`.
 void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
-    const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
+    const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
     const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos,
     size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,
     TimingInfo& timing_info) {

@@ -1609,7 +1609,7 @@ void GenerateSingle(  // NOLINT(misc-definitions-in-headers)
 }

 void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
-    const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
+    const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
     const RuntimeConfig& runtime_config,
     const QueriesPromptTokens& queries_prompt, const QueriesPos& queries_pos,
     const QueriesPos& queries_prefix_end, const KVCaches& kv_caches,

@@ -1620,7 +1620,7 @@ void GenerateBatch(  // NOLINT(misc-definitions-in-headers)
 }

 void GenerateImageTokens(  // NOLINT(misc-definitions-in-headers)
-    const ModelStore2& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
+    const ModelStore& model, const ModelWeightsPtrs<GEMMA_TYPE>& weights,
     const RuntimeConfig& runtime_config, const Image& image,
     ImageTokens& image_tokens, MatMulEnv* env) {
   HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(GenerateImageTokensT<GEMMA_TYPE>)
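The comments in the hunks above say that `GenerateSingle`/`GenerateBatch`/`GenerateImageTokens` are defined once per weight type by instantiations/*.cc. A hedged sketch of one such translation unit (file name and include path are assumptions based on that comment):

// Hypothetical instantiations/f32.cc: define the weight type, then include
// the 'header' so the extern functions above are defined for
// ModelWeightsPtrs<float>.
#define GEMMA_TYPE float
#include "gemma/gemma-inl.h"  // assumed include path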
@@ -47,13 +47,13 @@ namespace gcpp {
 MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
   // Placeholder for internal init, do not modify.

-  ThreadingContext2::SetArgs(threading_args);
-  return MatMulEnv(ThreadingContext2::Get());
+  ThreadingContext::SetArgs(threading_args);
+  return MatMulEnv(ThreadingContext::Get());
 }

 Gemma::Gemma(const LoaderArgs& loader, MatMulEnv& env)
     : env_(env),
-      reader_(BlobReader2::Make(loader.weights, loader.map)),
+      reader_(BlobReader::Make(loader.weights, loader.map)),
       model_(*reader_, loader.tokenizer, loader.wrapping),
       weights_(model_.Config().weight),
       chat_template_(model_.Tokenizer(), model_.Config().model) {

@@ -74,7 +74,7 @@ Gemma::Gemma(const ModelConfig& config, GemmaTokenizer&& tokenizer,
 Gemma::~Gemma() = default;

 void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
-  BlobWriter2 writer;
+  BlobWriter writer;
   const std::vector<uint32_t> serialized_mat_ptrs =
       weights_.AddTensorDataToWriter(writer);
   WriteSingleFile(model_.Config(), model_.Tokenizer(), serialized_mat_ptrs,

@@ -90,17 +90,17 @@ void Gemma::Save(const Path& weights_path, hwy::ThreadPool& pool) const {
 // instead of `WeightsPtrs<T>`.
 #define GEMMA_DECLARE(WEIGHT_TYPE)                                            \
   extern void GenerateSingle(                                                 \
-      const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
+      const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights,  \
       const RuntimeConfig& runtime_config, const PromptTokens& prompt,        \
       size_t pos, size_t prefix_end, KVCache& kv_cache, MatMulEnv* env,       \
       TimingInfo& timing_info);                                               \
   extern void GenerateBatch(                                                  \
-      const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
+      const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights,  \
      const RuntimeConfig& runtime_config, const QueriesPromptTokens& prompts, \
      const QueriesPos& queries_pos, const QueriesPos& queries_prefix_end,     \
      const KVCaches& kv_caches, MatMulEnv* env, TimingInfo& timing_info);     \
   extern void GenerateImageTokens(                                            \
-      const ModelStore2& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights, \
+      const ModelStore& model, const ModelWeightsPtrs<WEIGHT_TYPE>& weights,  \
       const RuntimeConfig& runtime_config, const Image& image,                \
       ImageTokens& image_tokens, MatMulEnv* env);
 GEMMA_DECLARE(float)

@@ -160,8 +160,8 @@ class Gemma {

  private:
   MatMulEnv& env_;
-  std::unique_ptr<BlobReader2> reader_;  // null for second ctor
-  ModelStore2 model_;
+  std::unique_ptr<BlobReader> reader_;  // null for second ctor
+  ModelStore model_;
   WeightsOwner weights_;
   GemmaChatTemplate chat_template_;
 };
@@ -56,7 +56,7 @@ static void WarnIfExtra(const IFields::ReadResult& result, const char* name) {

 // Returns the serialized tokenizer (std::string is required for proto).
 // Reads it from a blob or from a separate file if pre-2025.
-static std::string ReadTokenizer(BlobReader2& reader,
+static std::string ReadTokenizer(BlobReader& reader,
                                  const Path& tokenizer_path) {
   std::string tokenizer;
   // Check prevents `CallWithSpan` from printing a warning.

@@ -107,7 +107,7 @@ class TypePrefix {
     }
   }

-  TypePrefix(const KeyVec& keys, const BlobReader2& reader) {
+  TypePrefix(const KeyVec& keys, const BlobReader& reader) {
     for (size_t key_idx = 0; key_idx < keys.size(); ++key_idx) {
       const std::string& key = keys[key_idx];
       const Type type = TypeFromChar(key[0]);

@@ -200,7 +200,7 @@ static int DeduceLayerTypes(const KeyVec& keys) {

 // `wrapping_override` is forwarded from the command line. For pre-2025 files
 // without `ModelConfig`, it is the only way to force PT.
-static ModelConfig ReadOrDeduceConfig(BlobReader2& reader,
+static ModelConfig ReadOrDeduceConfig(BlobReader& reader,
                                       Tristate wrapping_override) {
   const TypePrefix type_prefix(reader.Keys(), reader);
   Type deduced_weight = Type::kUnknown;

@@ -244,7 +244,7 @@ static ModelConfig ReadOrDeduceConfig(BlobReader2& reader,
                  ChooseWrapping(config.model, wrapping_override));
 }

-static std::vector<float> ReadScales(BlobReader2& reader,
+static std::vector<float> ReadScales(BlobReader& reader,
                                      const ModelConfig& config) {
   std::vector<float> scales;
   // Check first to prevent `CallWithSpan` from printing a warning. This blob is

@@ -260,7 +260,7 @@ static std::vector<float> ReadScales(BlobReader2& reader,
 }

 // Single-file format: reads `MatPtr` from the blob; returns false if not found.
-bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
+bool ModelStore::ReadMatPtrs(BlobReader& reader) {
   // Check first to prevent `CallWithSpan` from printing a warning.
   if (!reader.Find(kMatPtrsName)) return false;

@@ -282,7 +282,7 @@ bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {

       // Retrieve actual key index because a writer may have written other
       // blobs before the tensor data.
-      const BlobRange2* range = reader.Find(mat.Name());
+      const BlobRange* range = reader.Find(mat.Name());
       HWY_ASSERT(range);
       const size_t key_idx = range->key_idx;
       AddMatPtr(key_idx, mat);

@@ -302,7 +302,7 @@ bool ModelStore2::ReadMatPtrs(BlobReader2& reader) {
 }

 // Pre-2025 format: synthesizes `MatPtr` from the blob names if `!ReadMatPtrs`.
-void ModelStore2::CreateMatPtrs(BlobReader2& reader) {
+void ModelStore::CreateMatPtrs(BlobReader& reader) {
   const TensorInfoRegistry tensors(config_);

   const KeyVec& keys = reader.Keys();

@@ -329,7 +329,7 @@ void ModelStore2::CreateMatPtrs(BlobReader2& reader) {
   HWY_ASSERT(mat_ptrs_.size() == key_idx_.size());
 }

-ModelStore2::ModelStore2(BlobReader2& reader, const Path& tokenizer_path,
+ModelStore::ModelStore(BlobReader& reader, const Path& tokenizer_path,
                          Tristate wrapping)
     : config_(ReadOrDeduceConfig(reader, wrapping)),
       tokenizer_(ReadTokenizer(reader, tokenizer_path)) {

@@ -348,12 +348,12 @@ ModelStore2::ModelStore2(BlobReader2& reader, const Path& tokenizer_path,
   HWY_ASSERT(key_idx_.size() == mat_ptrs_.size());
 }

-ModelStore2::~ModelStore2() {
+ModelStore::~ModelStore() {
   // Sanity check: ensure all scales were consumed.
   HWY_ASSERT(scales_consumed_ == scales_.size());
 }

-const MatPtr* ModelStore2::FindMat(const char* name) const {
+const MatPtr* ModelStore::FindMat(const char* name) const {
   auto it = mat_idx_for_name_.find(name);
   if (it == mat_idx_for_name_.end()) return nullptr;
   const size_t mat_idx = it->second;

@@ -362,7 +362,7 @@ const MatPtr* ModelStore2::FindMat(const char* name) const {
   return file_mat;
 }

-bool ModelStore2::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
+bool ModelStore::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
   const MatPtr* file_mat = FindMat(mat.Name());
   if (!file_mat) return false;
   if (file_mat->Rows() != mat.Rows() || file_mat->Cols() != mat.Cols()) {

@@ -390,14 +390,14 @@ bool ModelStore2::FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const {
 }

 static void AddBlob(const char* name, const std::vector<uint32_t>& data,
-                    BlobWriter2& writer) {
+                    BlobWriter& writer) {
   HWY_ASSERT(!data.empty());
   writer.Add(name, data.data(), data.size() * sizeof(data[0]));
 }

 void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
                      const std::vector<uint32_t>& serialized_mat_ptrs,
-                     BlobWriter2& writer, hwy::ThreadPool& pool,
+                     BlobWriter& writer, hwy::ThreadPool& pool,
                      const Path& path) {
   HWY_ASSERT(config.model != Model::UNKNOWN);
   HWY_ASSERT(config.weight != Type::kUnknown);
@@ -48,16 +48,16 @@ namespace gcpp {
 // tokenizer in a separate file, encoded tensor type in a prefix of the blob
 // name, and had a blob for tensor scaling factors. We still support reading
 // both, but only write single-file format.
-class ModelStore2 {
+class ModelStore {
  public:
   // Reads from file(s) or aborts on error. The latter two arguments are only
   // used for pre-2025 files.
-  ModelStore2(BlobReader2& reader, const Path& tokenizer_path = Path(),
+  ModelStore(BlobReader& reader, const Path& tokenizer_path = Path(),
              Tristate wrapping = Tristate::kDefault);
   // For optimize_test.cc.
-  ModelStore2(const ModelConfig& config, GemmaTokenizer&& tokenizer)
+  ModelStore(const ModelConfig& config, GemmaTokenizer&& tokenizer)
       : config_(config), tokenizer_(std::move(tokenizer)) {}
-  ~ModelStore2();
+  ~ModelStore();

   const ModelConfig& Config() const {
     HWY_ASSERT(config_.model != Model::UNKNOWN);

@@ -72,7 +72,7 @@ class ModelStore2 {

   // Returns false if `mat` is not available for loading, otherwise updates
   // `mat` with metadata from the file and sets `key_idx` for use by
-  // `BlobReader2`. Called via `ReadOrAllocate` in `weights.cc`.
+  // `BlobReader`. Called via `ReadOrAllocate` in `weights.cc`.
   bool FindAndUpdateMatPtr(MatPtr& mat, size_t& key_idx) const;

  private:

@@ -83,15 +83,15 @@ class ModelStore2 {
     key_idx_.push_back(key_idx);
   }

-  bool ReadMatPtrs(BlobReader2& reader);
-  void CreateMatPtrs(BlobReader2& reader);  // Aborts on error.
+  bool ReadMatPtrs(BlobReader& reader);
+  void CreateMatPtrs(BlobReader& reader);  // Aborts on error.

   ModelConfig config_;
   GemmaTokenizer tokenizer_;

   // All `MatPtr` present in the `BlobStore`, see `ReadMatPtrs`/`CreateMatPtrs`.
   std::vector<MatPtr> mat_ptrs_;
-  // For each of `mat_ptrs_`, the index within `BlobReader2::Keys()`. This is
+  // For each of `mat_ptrs_`, the index within `BlobReader::Keys()`. This is
   // not necessarily iota because some blobs are not tensors, and callers may
   // have added blobs before ours.
   std::vector<size_t> key_idx_;

@@ -108,7 +108,7 @@ class ModelStore2 {
 // produces a single BlobStore file holding everything required for inference.
 void WriteSingleFile(const ModelConfig& config, const GemmaTokenizer& tokenizer,
                      const std::vector<uint32_t>& serialized_mat_ptrs,
-                     BlobWriter2& writer, hwy::ThreadPool& pool,
+                     BlobWriter& writer, hwy::ThreadPool& pool,
                      const Path& path);

 }  // namespace gcpp
@@ -84,8 +84,8 @@ void LayerWeightsPtrs<NuqStream>::Reshape() {
 }

 // Aborts on error.
-static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
-                      const std::vector<BlobRange2>& ranges,
+static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader& reader,
+                      const std::vector<BlobRange>& ranges,
                       MatOwners& mat_owners, const MatPadding padding,
                       hwy::ThreadPool& pool) {
   HWY_ASSERT(mats.size() == ranges.size());

@@ -121,7 +121,7 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
     const size_t mem_stride_bytes = mats[i]->Stride() * mats[i]->ElementBytes();
     uint8_t* row = mats[i]->RowT<uint8_t>(0);
     for (size_t r = 0; r < mats[i]->Rows(); ++r) {
-      reader.Enqueue(BlobRange2{.offset = offset,
+      reader.Enqueue(BlobRange{.offset = offset,
                                 .bytes = file_bytes_per_row,
                                 .key_idx = ranges[i].key_idx},
                      row);

@@ -134,11 +134,11 @@ static void MapOrRead(const std::vector<MatPtr*>& mats, BlobReader2& reader,
   reader.ReadAll(pool);
 }

-void WeightsOwner::ReadOrAllocate(const ModelStore2& model, BlobReader2& reader,
+void WeightsOwner::ReadOrAllocate(const ModelStore& model, BlobReader& reader,
                                   hwy::ThreadPool& pool) {
   // List of tensors to read/map, and where from.
   std::vector<MatPtr*> mats;
-  std::vector<BlobRange2> ranges;
+  std::vector<BlobRange> ranges;

   // Padding is inserted when reading row by row, except for NUQ tensors.
   const MatPadding padding = MatPadding::kOdd;

@@ -244,7 +244,7 @@ void WeightsOwner::Reshape(hwy::ThreadPool& pool) {
 }

 std::vector<uint32_t> WeightsOwner::AddTensorDataToWriter(
-    BlobWriter2& writer) const {
+    BlobWriter& writer) const {
   std::vector<uint32_t> serialized_mat_ptrs;
   CallT([&](const auto& weights) {
     weights->ForEachTensor(nullptr, nullptr, [&](const TensorArgs& t) {
@@ -25,7 +25,7 @@
 #include <string>
 #include <vector>

-#include "compression/blob_store.h"  // BlobWriter2
+#include "compression/blob_store.h"  // BlobWriter
 #include "compression/shared.h"  // IsF32
 #include "gemma/configs.h"  // ModelConfig
 #include "gemma/model_store.h"  // ModelStore

@@ -519,7 +519,7 @@ class WeightsOwner {

   // Reads tensor data from `BlobStore`, or for tensors marked `kOnlyAllocate`,
   // allocates memory and reshapes. Aborts on error.
-  void ReadOrAllocate(const ModelStore2& model, BlobReader2& reader,
+  void ReadOrAllocate(const ModelStore& model, BlobReader& reader,
                       hwy::ThreadPool& pool);

   // Calls `func(std::unique_ptr<WeightsPtrs<T>>&, args)`. `func` typically

@@ -541,7 +541,7 @@ class WeightsOwner {
   // For writers:

   // Adds one blob for each tensor's data and returns all serialized MatPtr.
-  std::vector<uint32_t> AddTensorDataToWriter(BlobWriter2& writer) const;
+  std::vector<uint32_t> AddTensorDataToWriter(BlobWriter& writer) const;

   // For backprop/:
@@ -79,7 +79,7 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
 // M = A rows, K = A cols, N = C cols.
 template <typename TA, typename TB = TA, typename TC = float>
 void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
-  const Allocator2& allocator = env.ctx.allocator;
+  const Allocator& allocator = env.ctx.allocator;
   hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
   if (env.print_config || env.print_measurement) {
     fprintf(stderr, "\n");

@@ -160,7 +160,7 @@ void BenchAllMatMul() {
     return;
   }

-  ThreadingContext2& ctx = ThreadingContext2::Get();
+  ThreadingContext& ctx = ThreadingContext::Get();
   fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(),
           ctx.pools.PinString());
@@ -999,7 +999,7 @@ struct TestShortDotsT {
     const size_t N = hn::Lanes(d);
     const hn::ScalableTag<float> df;  // for CallDot

-    const Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
+    const Allocator& allocator = gcpp::ThreadingContext::Get().allocator;
     CompressWorkingSet work;
     std::mt19937 rng;
     rng.seed(12345);

@@ -1099,14 +1099,14 @@ void TestAllDot() {
   constexpr size_t kMaxWorkers = 15;

   // Reset with cap on workers because we only support `kMaxWorkers`.
-  ThreadingContext2::ThreadHostileInvalidate();
+  ThreadingContext::ThreadHostileInvalidate();
   ThreadingArgs threading_args;
   threading_args.max_packages = 1;
   threading_args.max_clusters = 1;
   threading_args.max_lps = kMaxWorkers - 1;
-  ThreadingContext2::SetArgs(threading_args);
-  ThreadingContext2& ctx = ThreadingContext2::Get();
-  const Allocator2& allocator = ctx.allocator;
+  ThreadingContext::SetArgs(threading_args);
+  ThreadingContext& ctx = ThreadingContext::Get();
+  const Allocator& allocator = ctx.allocator;

   {  // ensure no profiler zones are active
     const hn::ScalableTag<float> df;
@@ -909,7 +909,7 @@ class MMPerPackage {
   static constexpr size_t B_stride_max_ =
       MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
   static constexpr size_t B_storage_max_ =
-      kNR * B_stride_max_ + Allocator2::MaxQuantum<BF16>();
+      kNR * B_stride_max_ + Allocator::MaxQuantum<BF16>();

   // Granularity of `ForNP`. B rows produce C columns, so we
   // want a multiple of the line size to prevent false sharing.

@@ -1175,7 +1175,7 @@ class MMPerPackage {
   // Autotuning wrapper for `DoDecompressA`.
   template <typename TA>
   HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const {
-    const Allocator2& allocator = args_.env->ctx.allocator;
+    const Allocator& allocator = args_.env->ctx.allocator;
     MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
     // If already BF16, maybe return a view:
     if constexpr (hwy::IsSame<TA, BF16>()) {

@@ -1316,7 +1316,7 @@ template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               const RowPtr<TC>& C) {
-  const Allocator2& allocator = env.ctx.allocator;
+  const Allocator& allocator = env.ctx.allocator;
   const size_t M = A.Extents().rows;
   const size_t K = A.Extents().cols;
   const size_t N = B.Extents().rows;
@@ -60,7 +60,7 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
 // and holds most of their arguments in member variables.
 class GenerateCandidates {
  public:
-  GenerateCandidates(const Allocator2& allocator, size_t M, size_t K, size_t N,
+  GenerateCandidates(const Allocator& allocator, size_t M, size_t K, size_t N,
                      size_t sizeof_TC, size_t max_mr, size_t nr,
                      const IndexRangePartition& ranges_np, bool print_config)
       : allocator_(allocator),

@@ -352,7 +352,7 @@ class GenerateCandidates {
     }
   }

-  const Allocator2& allocator_;
+  const Allocator& allocator_;
   const size_t M_;
   const size_t K_;
   const size_t N_;

@@ -372,7 +372,7 @@ class GenerateCandidates {
 }  // namespace

 // Facade to avoid exposing `GenerateCandidates` in the header.
-std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
+std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
                                    size_t K, size_t N, size_t sizeof_TC,
                                    size_t max_mr, size_t nr,
                                    const IndexRangePartition& ranges_np,

@@ -384,7 +384,7 @@ std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
 // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
 // memory accesses or false sharing, unless there are insufficient per-package
 // rows for that.
-static size_t NPMultiple(const Allocator2& allocator, size_t N,
+static size_t NPMultiple(const Allocator& allocator, size_t N,
                          size_t sizeof_TC, size_t nr, size_t num_packages) {
   size_t np_multiple = allocator.QuantumBytes() / sizeof_TC;
   // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For

@@ -417,7 +417,7 @@ IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
       NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
 }

-MatMulEnv::MatMulEnv(ThreadingContext2& ctx)
+MatMulEnv::MatMulEnv(ThreadingContext& ctx)
     : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
   char cpu100[100];
   have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
ops/matmul.h +20 -20
@ -50,7 +50,7 @@ class MMParallel {
|
|||
static constexpr size_t kMaxPackages = 4;
|
||||
|
||||
// `ctx` must outlive this object.
|
||||
MMParallel(ThreadingContext2& ctx) : ctx_(ctx) {
|
||||
MMParallel(ThreadingContext& ctx) : ctx_(ctx) {
|
||||
HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
|
||||
}
|
||||
|
||||
|
|
@ -164,11 +164,11 @@ class MMParallel {
|
|||
}
|
||||
|
||||
private:
|
||||
ThreadingContext2& ctx_;
|
||||
ThreadingContext& ctx_;
|
||||
};
|
||||
|
||||
template <typename TC> // BF16/float for C, double for partial
|
||||
void BindC(const Allocator2& allocator, size_t M, const RowPtr<TC>& C,
|
||||
void BindC(const Allocator& allocator, size_t M, const RowPtr<TC>& C,
|
||||
MMParallel& parallel) {
|
||||
if (!allocator.ShouldBind()) return;
|
||||
|
||||
|
|
@ -207,7 +207,7 @@ class MMStorage {
|
|||
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
|
||||
static constexpr size_t kMaxKC = 8 * 1024;
|
||||
|
||||
MMStorage(const Allocator2& allocator, MMParallel& parallel)
|
||||
MMStorage(const Allocator& allocator, MMParallel& parallel)
|
||||
// Per-worker copies of `partial` would be wasteful. We instead allocate
|
||||
// one instance of the maximum matrix extents because threads write at
|
||||
// false-sharing-free granularity.
|
||||
|
|
@ -236,7 +236,7 @@ class MMStorage {
|
|||
|
||||
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is
|
||||
// non-const, because `RowPtr` requires a non-const pointer.
|
||||
RowPtrBF A(const Allocator2& allocator, size_t pkg_idx,
|
||||
RowPtrBF A(const Allocator& allocator, size_t pkg_idx,
|
||||
const Extents2D& extents) {
|
||||
HWY_DASSERT(extents.rows <= kMaxM);
|
||||
HWY_DASSERT(extents.cols <= kMaxK);
|
||||
|
|
@ -430,7 +430,7 @@ class MMConfig {
|
|||
static_assert(sizeof(MMConfig) == 32); // for faster indexing
|
||||
#pragma pack(pop)
|
||||
|
||||
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
|
||||
std::vector<MMConfig> MMCandidates(const Allocator& allocator, size_t M,
|
||||
size_t K, size_t N, size_t sizeof_TC,
|
||||
size_t max_mr, size_t nr,
|
||||
const IndexRangePartition& ranges_np,
|
||||
|
|
@ -561,7 +561,7 @@ class MMKeys {
|
|||
}
|
||||
|
||||
// Must only be called if not already present in `Keys()`.
|
||||
void Append(Key key, const Allocator2& allocator) {
|
||||
void Append(Key key, const Allocator& allocator) {
|
||||
// Dynamic allocation because the test checks many more dimensions than
|
||||
// would be reasonable to pre-allocate. DIY for alignment and padding.
|
||||
if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
|
||||
|
|
@@ -608,9 +608,9 @@ struct MMPerKey {
 // Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
 // `MatMulEnv`.
 struct MatMulEnv {
-  explicit MatMulEnv(ThreadingContext2& ctx);
+  explicit MatMulEnv(ThreadingContext& ctx);
 
-  ThreadingContext2& ctx;
+  ThreadingContext& ctx;
   bool have_timer_stop = false;
 
   // Whether `MMCandidates()` should print the set of parameters.
@@ -753,7 +753,7 @@ ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
 }
 
 template <typename TB>
-void BindB(const Allocator2& allocator, size_t N, size_t sizeof_TC,
+void BindB(const Allocator& allocator, size_t N, size_t sizeof_TC,
            const ConstMat<TB>& B, MMParallel& parallel) {
   if (!allocator.ShouldBind()) return;
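A hedged usage sketch of the binding helpers renamed here. `B`, `C`, and the extents are placeholders; only the signatures visible in this hunk are relied upon.

    // Both calls are no-ops unless allocator.ShouldBind(), per the guards above.
    MatMulEnv env(ThreadingContext::Get());
    const Allocator& allocator = env.ctx.allocator;
    BindB(allocator, /*N=*/4096, sizeof(float), B, env.parallel);
    BindC(allocator, /*M=*/512, C, env.parallel);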
@@ -86,7 +86,7 @@ float MaxAbs(const RowVectorBatch<float>& a) {
 template <typename TA, typename TB, typename TC>
 void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
                  const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
-  const Allocator2& allocator = ThreadingContext2::Get().allocator;
+  const Allocator& allocator = ThreadingContext::Get().allocator;
   const hn::ScalableTag<float> df;
   const size_t cols = A.extents.cols;
   const size_t B_rows = B.extents.rows;
@@ -210,7 +210,7 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
 template <typename TA, typename TB = TA, typename TC = float>
 void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                 MatMulEnv& env, int line) {
-  const Allocator2& allocator = env.ctx.allocator;
+  const Allocator& allocator = env.ctx.allocator;
   hwy::ThreadPool& pool = env.ctx.pools.Pool();
   fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
           rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
@@ -259,12 +259,12 @@ void TestTiny() {
   if (HWY_TARGET != first_target) return;
 
   for (size_t max_packages : {1, 2}) {
-    ThreadingContext2::ThreadHostileInvalidate();
+    ThreadingContext::ThreadHostileInvalidate();
     ThreadingArgs threading_args;
     threading_args.bind = Tristate::kTrue;
     threading_args.max_packages = max_packages;
-    ThreadingContext2::SetArgs(threading_args);
-    MatMulEnv env(ThreadingContext2::Get());
+    ThreadingContext::SetArgs(threading_args);
+    MatMulEnv env(ThreadingContext::Get());
     NestedPools& pools = env.ctx.pools;
 
 #if GEMMA_DISABLE_TOPOLOGY
@@ -296,11 +296,11 @@ void TestAllMatMul() {
     return;
   }
 
-  ThreadingContext2::ThreadHostileInvalidate();
+  ThreadingContext::ThreadHostileInvalidate();
   ThreadingArgs threading_args;
   threading_args.bind = Tristate::kTrue;
-  ThreadingContext2::SetArgs(threading_args);
-  MatMulEnv env(ThreadingContext2::Get());
+  ThreadingContext::SetArgs(threading_args);
+  MatMulEnv env(ThreadingContext::Get());
   NestedPools& pools = env.ctx.pools;
   pools.MaybeStartSpinning(threading_args.spin);
@@ -808,7 +808,7 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
 // Each output row is the average of a 4x4 block of input rows
 template <typename T>
 RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
-  const Allocator2& allocator = ThreadingContext2::Get().allocator;
+  const Allocator& allocator = ThreadingContext::Get().allocator;
   const Extents2D extents = input.Extents();
   // Input validation
   HWY_DASSERT(extents.rows == 4096);  // 64 * 64 = 4096 input rows
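A scalar sketch of the stated contract (4096 = 64x64 input rows; each output row averages a 4x4 grid block, so 256 outputs). The row-to-grid mapping `y * 64 + x` is an assumption, and the real function is vectorized.

    template <typename T>
    RowVectorBatch<T> AvgPool4x4Sketch(const Allocator& allocator,
                                       RowVectorBatch<T>& in) {
      const size_t cols = in.Extents().cols;
      RowVectorBatch<T> out(allocator, Extents2D(16 * 16, cols));
      RowPtr<T> in_rows = RowPtrFromBatch(allocator, in);
      RowPtr<T> out_rows = RowPtrFromBatch(allocator, out);
      for (size_t oy = 0; oy < 16; ++oy) {
        for (size_t ox = 0; ox < 16; ++ox) {
          T* out_row = out_rows.Row(oy * 16 + ox);
          for (size_t c = 0; c < cols; ++c) out_row[c] = T{0};
          // Sum the 16 input rows of the corresponding 4x4 grid block.
          for (size_t dy = 0; dy < 4; ++dy) {
            for (size_t dx = 0; dx < 4; ++dx) {
              const T* in_row = in_rows.Row((oy * 4 + dy) * 64 + (ox * 4 + dx));
              for (size_t c = 0; c < cols; ++c) out_row[c] += in_row[c];
            }
          }
          for (size_t c = 0; c < cols; ++c) out_row[c] /= T{16};
        }
      }
      return out;
    }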
@@ -27,7 +27,7 @@
 namespace gcpp {
 
 static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
-    const Allocator2& allocator, size_t qkv_dim, bool half_rope,
+    const Allocator& allocator, size_t qkv_dim, bool half_rope,
     double base_frequency = 10000.0) {
   const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
   RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2));
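The hunk shows the allocation but not the fill loop. A standard RoPE timescale computation consistent with the visible `rope_dim` logic would be the following; this is an assumption, since the loop body lies outside the hunk.

    // Requires <cmath>. inv_timescale[i] = base_frequency^(-2i / rope_dim).
    float* out = inv_timescale.All();
    for (size_t i = 0; i < rope_dim / 2; ++i) {
      out[i] = static_cast<float>(
          1.0 / std::pow(base_frequency, 2.0 * i / rope_dim));
    }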
@@ -386,7 +386,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
 }
 
 void TestRopeAndMulBy() {
-  const Allocator2& allocator = ThreadingContext2::Get().allocator;
+  const Allocator& allocator = ThreadingContext::Get().allocator;
 
   ModelConfig config(Model::GEMMA2_9B, Type::kSFP,
                      ChooseWrapping(Model::GEMMA2_9B));
@@ -47,7 +47,7 @@ class PaliGemmaTest : public ::testing::Test {
 
 void PaliGemmaTest::InitVit(const std::string& path) {
   ASSERT_NE(s_env->GetGemma(), nullptr);
-  const Allocator2& allocator = s_env->Env().ctx.allocator;
+  const Allocator& allocator = s_env->Env().ctx.allocator;
   Gemma& gemma = *(s_env->GetGemma());
   image_tokens_ = ImageTokens(
       allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,
@@ -168,7 +168,7 @@ class GemmaModel {
   void SetImage(const py::array_t<float, py::array::c_style |
                                              py::array::forcecast>& image) {
     const gcpp::Gemma& gemma = *gemma_.GetGemma();
-    const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
+    const gcpp::Allocator& allocator = gemma_.Env().ctx.allocator;
     if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
         gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
       throw std::invalid_argument("Not a PaliGemma model.");
@@ -130,7 +130,7 @@ size_t DetectTotalMiB(size_t page_bytes) {
 
 }  // namespace
 
-Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
+Allocator::Allocator(const BoundedTopology& topology, bool enable_bind) {
   line_bytes_ = DetectLineBytes();
   vector_bytes_ = hwy::VectorBytes();
   step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
@@ -180,7 +180,7 @@ Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
   quantum_step_mask_ = quantum_bytes_ / step_bytes_ - 1;
 }
 
-size_t Allocator2::FreeMiB() const {
+size_t Allocator::FreeMiB() const {
 #if HWY_OS_LINUX
   const long ret = sysconf(_SC_AVPHYS_PAGES);  // NOLINT(runtime/int)
   HWY_ASSERT(ret != -1);
@@ -201,7 +201,7 @@ size_t Allocator2::FreeMiB() const {
 #endif
 }
 
-AlignedPtr2<uint8_t[]> Allocator2::AllocBytes(size_t bytes) const {
+AlignedPtr2<uint8_t[]> Allocator::AllocBytes(size_t bytes) const {
   // If we are not binding, the Highway allocator is cheaper than `mmap`, and
   // defends against 2K aliasing.
   if (!should_bind_) {
@@ -296,7 +296,7 @@ size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
   return num_busy;
 }
 
-bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
+bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) const {
   HWY_DASSERT(should_bind_);
   constexpr size_t kMaxNodes = 1024;  // valid for x86/x64, and "enough"
@@ -353,7 +353,7 @@ bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
 }
 
 #else
-bool Allocator2::BindMemory(void*, size_t, size_t) const { return false; }
+bool Allocator::BindMemory(void*, size_t, size_t) const { return false; }
 #endif  // GEMMA_BIND && HWY_OS_LINUX
 
 }  // namespace gcpp
@@ -78,14 +78,14 @@ template <typename T>
 using AlignedClassPtr2 = std::unique_ptr<T, DeleterDtor2>;
 
 // Both allocation, binding, and row accessors depend on the sizes of memory
-// pages and cache lines. To avoid having to pass `Allocator2&` everywhere, we
+// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
 // wrap this in a singleton. A monostate requires explicit initialization,
 // which we prefer to avoid because there are many main() functions.
-class Allocator2 {
+class Allocator {
  public:
   // Must be called at least once before any other function. Not thread-safe,
   // hence only call this from the main thread.
-  Allocator2(const BoundedTopology& topology, bool enable_bind);
+  Allocator(const BoundedTopology& topology, bool enable_bind);
 
   // Bytes per cache line, or a reasonable guess if unknown. Used to choose
   // ranges such that there will be no false sharing.
@@ -102,7 +102,7 @@ static size_t RoundUpToOddLines(size_t num, size_t line_bytes,
   return padded_num;
 }
 
-static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
+static size_t Stride(const Allocator& allocator, const MatPtr& mat,
                      MatPadding padding) {
   switch (padding) {
     case MatPadding::kPacked:
@@ -119,7 +119,7 @@ static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
 
 void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
   if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
-  const Allocator2& allocator = ThreadingContext2::Get().allocator;
+  const Allocator& allocator = ThreadingContext::Get().allocator;
   const size_t stride = Stride(allocator, mat, padding);
   const size_t num = mat.Rows() * stride;
   // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
util/mat.h
@@ -282,11 +282,11 @@ void ZeroInit(MatPtr& mat);
 void RandInit(MatPtr& mat, float stddev, std::mt19937& gen);
 
 // Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
-// `Allocator2::ShouldBind()`, `Allocator2::QuantumBytes()` is typically 4KiB.
+// `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is typically 4KiB.
 // To avoid remote accesses, we would thus pad each row to that, which results
 // in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that
 // by pulling rows forward by a cyclic offset, which is still a multiple of the
-// cache line size. This requires an additional `Allocator2::QuantumBytes()` of
+// cache line size. This requires an additional `Allocator::QuantumBytes()` of
 // padding after also rounding up to that, which considerably increases size for
 // tall and skinny tensors.
 static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
@@ -295,7 +295,7 @@ static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
 // Constexpr version (upper bound) for allocating storage in MatMul.
 template <typename T>
 constexpr size_t MaxStrideForCyclicOffsets(size_t cols) {
-  constexpr size_t quantum = Allocator2::MaxQuantum<T>();
+  constexpr size_t quantum = Allocator::MaxQuantum<T>();
   return hwy::RoundUpTo(cols, quantum) + quantum;
 }
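To make the visible formula concrete, a small self-contained arithmetic check; the 4 KiB quantum for float is the typical value the comment above cites, not something this hunk guarantees.

    #include <cstddef>

    // Same rounding as hwy::RoundUpTo, restated to keep the example standalone.
    constexpr size_t RoundUpTo(size_t n, size_t multiple) {
      return (n + multiple - 1) / multiple * multiple;
    }

    constexpr size_t kQuantum = 4096 / sizeof(float);  // 1024 elements
    // 1000 cols round up to 1024, plus one extra quantum for cyclic offsets.
    static_assert(RoundUpTo(1000, kQuantum) + kQuantum == 2048, "");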
@@ -387,7 +387,7 @@ MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
 template <typename T>
 class RowPtr {
  public:
-  RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols,
+  RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols,
          size_t stride)
       : row0_(row0),
         stride_(stride),
@@ -414,7 +414,7 @@ class RowPtr {
     }
   }
 
-  RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols)
+  RowPtr(const Allocator& allocator, T* HWY_RESTRICT row0, size_t cols)
       : RowPtr(allocator, row0, cols, cols) {}
 
   T* HWY_RESTRICT Row(size_t r) const {
@@ -480,7 +480,7 @@ class RowVectorBatch {
   // we default to tightly packed rows (`stride = cols`).
   // WARNING: not all call sites support `stride` != cols.
   // TODO: once they do, remove stride and behave like AllocateAlignedRows here.
-  RowVectorBatch(const Allocator2& allocator, Extents2D extents,
+  RowVectorBatch(const Allocator& allocator, Extents2D extents,
                  size_t stride = 0)
       : extents_(extents) {
     if (stride == 0) {
@@ -529,14 +529,14 @@ class RowVectorBatch {
 };
 
 template <typename T>
-RowPtr<T> RowPtrFromBatch(const Allocator2& allocator,
+RowPtr<T> RowPtrFromBatch(const Allocator& allocator,
                           RowVectorBatch<T>& row_vectors) {
   return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
                    row_vectors.Stride());
 }
 
 template <typename T>
-RowVectorBatch<T> AllocateAlignedRows(const Allocator2& allocator,
+RowVectorBatch<T> AllocateAlignedRows(const Allocator& allocator,
                                       Extents2D extents) {
   return RowVectorBatch<T>(
       allocator, extents,
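A short usage sketch of the helpers above; only names shown in this file are used, and the extents and value are illustrative.

    const Allocator& allocator = ThreadingContext::Get().allocator;
    RowVectorBatch<float> batch(allocator, Extents2D(4, 256));  // 4 rows x 256 cols
    RowPtr<float> rows = RowPtrFromBatch(allocator, batch);
    rows.Row(2)[0] = 1.0f;  // stride-aware access to row 2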
@@ -109,7 +109,7 @@ static Pinning& GetPinning() {
   return pinning;
 }
 
-static PoolPtr MakePool(const Allocator2& allocator, size_t num_workers,
+static PoolPtr MakePool(const Allocator& allocator, size_t num_workers,
                         std::optional<size_t> node = std::nullopt) {
   // `ThreadPool` expects the number of threads to create, which is one less
   // than the number of workers, but avoid underflow if zero.
@@ -136,7 +136,7 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) {
 }
 
 NestedPools::NestedPools(const BoundedTopology& topology,
-                         const Allocator2& allocator, size_t max_threads,
+                         const Allocator& allocator, size_t max_threads,
                          Tristate pin) {
   GetPinning().SetPolicy(pin);
   packages_.resize(topology.NumPackages());
|
|||
}
|
||||
|
||||
NestedPools::Package::Package(const BoundedTopology& topology,
|
||||
const Allocator2& allocator, size_t pkg_idx,
|
||||
const Allocator& allocator, size_t pkg_idx,
|
||||
size_t max_workers_per_package) {
|
||||
// Pre-allocate because elements are set concurrently.
|
||||
clusters_.resize(topology.NumClusters(pkg_idx));
|
||||
|
|
|
|||
|
|
@@ -74,7 +74,7 @@ class NestedPools {
   // would cause huge slowdowns when spinning, the `BoundedSlice` arguments
   // only impose upper bounds on the number of detected packages and clusters
   // rather than defining the actual number of threads.
-  NestedPools(const BoundedTopology& topology, const Allocator2& allocator,
+  NestedPools(const BoundedTopology& topology, const Allocator& allocator,
               size_t max_threads = 0, Tristate pin = Tristate::kDefault);
 
   bool AllPinned() const { return all_pinned_; }
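A hedged sketch of how the hierarchy is consumed; `Pool()` and `NumPackages()` appear elsewhere in this commit, and the `Run` signature is Highway's.

    NestedPools& pools = ThreadingContext::Get().pools;
    // Here we simply fan tasks out over the inner (all-cluster) pool.
    pools.Pool().Run(0, pools.NumPackages() * 4,
                     [](uint64_t task, size_t thread) {
                       // per-task work, executed by some pool worker
                     });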
@@ -148,7 +148,7 @@ class NestedPools {
   class Package {
    public:
     Package() = default;  // for vector
-    Package(const BoundedTopology& topology, const Allocator2& allocator,
+    Package(const BoundedTopology& topology, const Allocator& allocator,
             size_t pkg_idx, size_t max_workers_per_package);
 
     size_t NumClusters() const { return clusters_.size(); }
@@ -26,37 +26,37 @@ namespace gcpp {
 static ThreadingArgs s_args;
 // Cannot use magic static because that does not support `Invalidate`, hence
 // allocate manually.
-static std::unique_ptr<ThreadingContext2> s_ctx;
+static std::unique_ptr<ThreadingContext> s_ctx;
 static std::mutex s_ctx_mutex;
 
-/*static*/ void ThreadingContext2::SetArgs(const ThreadingArgs& args) {
+/*static*/ void ThreadingContext::SetArgs(const ThreadingArgs& args) {
   s_ctx_mutex.lock();
   HWY_ASSERT(!s_ctx);  // Ensure not already initialized, else this is too late.
   s_args = args;
   s_ctx_mutex.unlock();
 }
 
-/*static*/ bool ThreadingContext2::IsInitialized() {
+/*static*/ bool ThreadingContext::IsInitialized() {
   s_ctx_mutex.lock();
   const bool initialized = !!s_ctx;
   s_ctx_mutex.unlock();
   return initialized;
 }
 
-/*static*/ ThreadingContext2& ThreadingContext2::Get() {
+/*static*/ ThreadingContext& ThreadingContext::Get() {
   PROFILER_FUNC;
   // We do not bother with double-checked locking because it requires an
   // atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
   // callers can cache the result and call less often.
   s_ctx_mutex.lock();
   if (HWY_UNLIKELY(!s_ctx)) {
-    s_ctx = std::make_unique<ThreadingContext2>(PrivateToken());
+    s_ctx = std::make_unique<ThreadingContext>(PrivateToken());
   }
   s_ctx_mutex.unlock();
   return *s_ctx;
 }
 
-/*static*/ void ThreadingContext2::ThreadHostileInvalidate() {
+/*static*/ void ThreadingContext::ThreadHostileInvalidate() {
   // Deliberately avoid taking the lock so that tsan can warn if this is
   // called concurrently with other calls to `Get`.
   s_ctx.reset();
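The intended call order, reconstructed from the tests in this commit (a sketch, not new API):

    ThreadingArgs args;
    args.max_packages = 1;
    if (!ThreadingContext::IsInitialized()) {
      ThreadingContext::SetArgs(args);  // would assert after the first Get()
    }
    ThreadingContext& ctx = ThreadingContext::Get();  // lazily constructs
    // Tests only: drop the singleton so the next Get() rebuilds with new args.
    ThreadingContext::ThreadHostileInvalidate();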
@@ -64,7 +64,7 @@ static std::mutex s_ctx_mutex;
 
 // WARNING: called with `s_ctx_mutex` held. Calling `SetArgs` or `Get` would
 // deadlock.
-ThreadingContext2::ThreadingContext2(ThreadingContext2::PrivateToken)
+ThreadingContext::ThreadingContext(ThreadingContext::PrivateToken)
     : topology(BoundedSlice(s_args.skip_packages, s_args.max_packages),
                BoundedSlice(s_args.skip_clusters, s_args.max_clusters),
                BoundedSlice(s_args.skip_lps, s_args.max_lps)),
@@ -87,7 +87,7 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
 
 // Lazily-initialized singleton with support for passing in arguments from
 // `ThreadingArgs` and re-initializing with different arguments.
-class ThreadingContext2 {
+class ThreadingContext {
   struct PrivateToken {};  // avoids constructing directly
 
  public:
@@ -112,7 +112,7 @@ class ThreadingContext2 {
   // hence we prefer not to pull `std::shared_ptr` into the interface.
   //
   // To reduce overhead, callers should cache the result and call less often.
-  static ThreadingContext2& Get();
+  static ThreadingContext& Get();
 
   // Invalidates the singleton before or after a call to `Get`. This allows
   // changing the arguments between tests. Callers must again call `Get`
@@ -121,10 +121,10 @@ class ThreadingContext2 {
   // Also useful to suppress memory leak warnings in tests.
   static void ThreadHostileInvalidate();
 
-  explicit ThreadingContext2(PrivateToken);  // only called via `Get`.
+  explicit ThreadingContext(PrivateToken);  // only called via `Get`.
 
   BoundedTopology topology;
-  Allocator2 allocator;
+  Allocator allocator;
   NestedPools pools;
 };
@@ -383,7 +383,7 @@ TEST(ThreadingTest, BenchJoin) {
     }
   };
 
-  NestedPools& pools = ThreadingContext2::Get().pools;
+  NestedPools& pools = ThreadingContext::Get().pools;
   // Use last package because the main thread has been pinned to it.
   const size_t pkg_idx = pools.NumPackages() - 1;