From d831ddce5b6b957fb0f8aa8e0b9fefa755f8d046 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Wed, 30 Jul 2025 04:29:27 -0700
Subject: [PATCH] Fix file mapping: was letting the smart pointer go out of scope

Also save+print the IO mode used.

PiperOrigin-RevId: 788848165
---
Reduced sketches of the lifetime bug, and of why closing the file after
mapping is safe, follow the diff.

 evals/benchmark_helper.cc | 11 +++--
 evals/benchmark_helper.h  |  1 +
 gemma/gemma.cc            |  4 +-
 gemma/gemma.h             |  2 +
 gemma/run.cc              |  3 +-
 gemma/weights.cc          | 90 ++++++++++++++++++++-------------------
 gemma/weights.h           | 35 +++++++++++++--
 io/blob_store.h           |  7 ++-
 8 files changed, 95 insertions(+), 58 deletions(-)

diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc
index b4803d0..3b999b4 100644
--- a/evals/benchmark_helper.cc
+++ b/evals/benchmark_helper.cc
@@ -56,7 +56,8 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const ThreadingArgs& threading,
     kv_caches_.push_back(KVCache(config, inference, ctx_.allocator));
 
   if (inference.verbosity >= 2) {
-    ShowConfig(loader, threading, inference, config, ctx_);
+    ShowConfig(loader, threading, inference, config, gemma_.WeightReadMode(),
+               ctx_);
   }
 
   InitGenerator(inference, gen_);
@@ -229,13 +230,15 @@ static constexpr const char* CompiledConfig() {
 
 void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
                 const InferenceArgs& inference, const ModelConfig& config,
+                const WeightsPtrs::Mode weight_read_mode,
                 const ThreadingContext& ctx) {
   threading.Print(inference.verbosity);
   loader.Print(inference.verbosity);
   inference.Print(inference.verbosity);
-  fprintf(stderr, "Model : %s, to_bf16 %d, mmap %d\n",
-          config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
-          static_cast<int>(loader.map));
+  fprintf(
+      stderr, "Model : %s, to_bf16 %d, mmap %d => %s\n",
+      config.Specifier().c_str(), static_cast<int>(loader.to_bf16),
+      static_cast<int>(loader.map), WeightsPtrs::ToString(weight_read_mode));
 
   if (inference.verbosity >= 2) {
     time_t now = time(nullptr);
diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h
index a8f0dc8..8f4d96f 100644
--- a/evals/benchmark_helper.h
+++ b/evals/benchmark_helper.h
@@ -125,6 +125,7 @@ void LogSpeedStats(double time_start, size_t total_tokens);
 
 void ShowConfig(const LoaderArgs& loader, const ThreadingArgs& threading,
                 const InferenceArgs& inference, const ModelConfig& config,
+                WeightsPtrs::Mode weight_read_mode,
                 const ThreadingContext& ctx);
 void ShowHelp(const LoaderArgs& loader, const ThreadingArgs& threading,
               const InferenceArgs& inference);
diff --git a/gemma/gemma.cc b/gemma/gemma.cc
index 66a8433..19d9926 100644
--- a/gemma/gemma.cc
+++ b/gemma/gemma.cc
@@ -609,7 +609,9 @@ Gemma::Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
       weights_(model_.Config()),
       chat_template_(model_.Tokenizer(), model_.Config().model),
       inference_(inference) {
-  weights_.ReadFromBlobs(model_, reader_, loader, inference, mat_owners_, ctx);
+  weight_read_mode_ = weights_.ReadFromBlobs(model_, reader_, loader, inference,
+                                             mat_owners_, ctx);
+  // Read everything into memory, or `weights_.mapped_` keeps the mapping alive.
   reader_.CloseFile();
 }
 
diff --git a/gemma/gemma.h b/gemma/gemma.h
index b9f4127..5ebd70d 100644
--- a/gemma/gemma.h
+++ b/gemma/gemma.h
@@ -239,6 +239,7 @@ class Gemma {
   const ModelConfig& Config() const { return model_.Config(); }
   const GemmaTokenizer& Tokenizer() const { return model_.Tokenizer(); }
   const WeightsPtrs& Weights() const { return weights_; }
+  WeightsPtrs::Mode WeightReadMode() const { return weight_read_mode_; }
   const GemmaChatTemplate& ChatTemplate() const { return chat_template_; }
   const InferenceArgs& Inference() const { return inference_; }
 
@@ -271,6 +272,7 @@ class Gemma {
   ModelStore model_;
   std::vector<MatOwner> mat_owners_;
   WeightsPtrs weights_;
+  WeightsPtrs::Mode weight_read_mode_;
   GemmaChatTemplate chat_template_;
   InferenceArgs inference_;
 };
diff --git a/gemma/run.cc b/gemma/run.cc
index 997d06e..7cbc4de 100644
--- a/gemma/run.cc
+++ b/gemma/run.cc
@@ -285,7 +285,8 @@ void Run(const LoaderArgs& loader, const ThreadingArgs& threading,
   if (inference.IsInteractive()) {
     std::cout << "\033[2J\033[1;1H"  // clear screen
               << kAsciiArtBanner << "\n\n";
-    ShowConfig(loader, threading, inference, gemma.Config(), ctx);
+    ShowConfig(loader, threading, inference, gemma.Config(),
+               gemma.WeightReadMode(), ctx);
     std::cout << "\n" << instructions << "\n";
   }
 }
diff --git a/gemma/weights.cc b/gemma/weights.cc
index b205bd1..3418acf 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -223,7 +223,7 @@ void WeightsPtrs::CopyFrom(const WeightsPtrs& other) {
 }
 
 // For reshaping file tensors to the shape expected by the code. This would
-// ideally already happen in the importer. Called by WeightsOwner::Fixup.
+// ideally already happen in the importer. Called by `ReadFromBlobs`.
 void WeightsPtrs::Fixup(std::vector<MatOwner>& mat_owners,
                         ThreadingContext& ctx) {
   // TODO: use 1D parallel-for helper function
@@ -251,21 +251,11 @@ std::vector<uint32_t> WeightsPtrs::AddTensorDataToWriter(
   return serialized_mat_ptrs;
 }
 
-enum class Mode {
-  // Parallel I/O, decompress to BF16. Best for large batch sizes.
-  kReadBF16,
-  // Parallel I/O, insert row-wise padding. Safe default.
-  kRead,
-  // Best for large weights relative to available memory, especially for
-  // frequent invocations of small batches and short sequences. Adds noise to
-  // performance measurements due to I/O variability.
-  kMap
-};
-
 // Decides whether to read or map based on heuristics and user override.
-static Mode ChooseMode(uint64_t file_bytes, const LoaderArgs& loader,
-                       const InferenceArgs& inference,
-                       const Allocator& allocator) {
+static WeightsPtrs::Mode ChooseMode(uint64_t file_bytes,
+                                    const LoaderArgs& loader,
+                                    const InferenceArgs& inference,
+                                    const Allocator& allocator) {
   Tristate to_bf16 = loader.to_bf16;
   Tristate map = loader.map;
 
@@ -283,8 +273,8 @@ static WeightsPtrs::Mode ChooseMode(uint64_t file_bytes,
   if (to_bf16 == Tristate::kTrue && map == Tristate::kTrue) {
     HWY_WARN("Cannot have to_bf16 && map, to_bf16 takes precedence.");
   }
-  if (to_bf16 == Tristate::kTrue) return Mode::kReadBF16;
-  if (map == Tristate::kTrue) return Mode::kMap;
+  if (to_bf16 == Tristate::kTrue) return WeightsPtrs::Mode::kReadBF16;
+  if (map == Tristate::kTrue) return WeightsPtrs::Mode::kMap;
 
   if (to_bf16 == Tristate::kDefault) {
     // Heuristic: sub-bf16 compression is not helpful if compute-bound.
@@ -307,8 +297,9 @@
   }
 
   // If the `map` heuristic triggers, use that for safety.
-  if (map == Tristate::kTrue) return Mode::kMap;
-  return (to_bf16 == Tristate::kTrue) ? Mode::kReadBF16 : Mode::kRead;
+  if (map == Tristate::kTrue) return WeightsPtrs::Mode::kMap;
+  return (to_bf16 == Tristate::kTrue) ? WeightsPtrs::Mode::kReadBF16
+                                      : WeightsPtrs::Mode::kRead;
 }
 
 struct TensorToRead {
@@ -324,7 +315,8 @@
 
 // Allocates multiple in parallel and binds to NUMA nodes.
 static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
-                               const Mode mode, std::vector<MatOwner>& owners,
+                               const WeightsPtrs::Mode mode,
+                               std::vector<MatOwner>& owners,
                                ThreadingContext& ctx) {
   const size_t start = owners.size();
   owners.resize(start + tensors.size());
@@ -342,7 +334,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
     if (tensor.prev_type == Type::kF32 || mat.Rows() < 1024) {
       tensor.keep_type = true;
      tensor.padding = MatPadding::kPacked;  // single I/O for simplicity
-    } else if (mode == Mode::kReadBF16) {
+    } else if (mode == WeightsPtrs::Mode::kReadBF16) {
       mat.SetType(Type::kBF16);
     }
 
@@ -354,7 +346,7 @@ static void AllocateAndBindAll(std::vector<TensorToRead>& tensors,
 
 // Mode == kMap
 static void MapAll(const std::vector<TensorToRead>& tensors,
-                   const MapPtr& mapped) {
+                   const MapPtr& mapped, uint64_t file_bytes) {
   PROFILER_ZONE("Startup.Weights.Map");
   for (size_t i = 0; i < tensors.size(); ++i) {
     // SetPtr does not change the stride, but it is expected to be packed
@@ -362,10 +354,12 @@ static void MapAll(const std::vector<TensorToRead>& tensors,
     const size_t mat_bytes = tensors[i].mat->PackedBytes();
     // Ensure blob size matches that computed from metadata.
     HWY_ASSERT_M(mat_bytes == tensors[i].range.bytes, tensors[i].mat->Name());
+    // Ensure the blob lies within the file mapping.
+    const uint64_t offset = tensors[i].range.offset;
+    HWY_ASSERT_M(offset + mat_bytes <= file_bytes, tensors[i].mat->Name());
 
-    tensors[i].mat->SetPtr(
-        const_cast<uint8_t*>(mapped.get() + tensors[i].range.offset),
-        tensors[i].mat->Stride());
+    tensors[i].mat->SetPtr(const_cast<uint8_t*>(mapped.get() + offset),
+                           tensors[i].mat->Stride());
   }
 }
 
@@ -484,40 +478,49 @@ static void ReadBatches(const BlobReader& reader,
   });
 }
 
-// Aborts on error.
-static void MapOrReadAll(std::vector<TensorToRead>& tensors, BlobReader& reader,
-                         Mode mode, std::vector<MatOwner>& mat_owners,
-                         ThreadingContext& ctx) {
-  if (mode == Mode::kMap) {
-    MapPtr mapped = reader.file().Map();
-    if (mapped) return MapAll(tensors, mapped);
+// Aborts on error. Updates `mode` to the actual mode used. Returns mapped
+// memory or nullptr if `kMap` was not used.
+static MapPtr MapOrReadAll(std::vector<TensorToRead>& tensors,
+                           BlobReader& reader, WeightsPtrs::Mode* mode,
+                           std::vector<MatOwner>& mat_owners,
+                           ThreadingContext& ctx) {
+  if (*mode == WeightsPtrs::Mode::kMap) {
+    if (MapPtr mapped = reader.Map()) {
+      MapAll(tensors, mapped, reader.file().FileSize());
+      return mapped;
+    }
     HWY_WARN("Failed to map file (%zu KiB), reading instead.",
              static_cast<size_t>(reader.file_bytes() >> 10));
     // If we wanted to map but failed, memory is probably not plentiful, so
     // fall through to kRead because kReadBF16 requires more memory.
-    mode = Mode::kRead;
+    *mode = WeightsPtrs::Mode::kRead;
   }
 
   {
     PROFILER_ZONE("Startup.Weights.Allocate");
     // NOTE: this changes the stride of `mats`!
-    AllocateAndBindAll(tensors, mode, mat_owners, ctx);
+    AllocateAndBindAll(tensors, *mode, mat_owners, ctx);
   }
 
   hwy::ThreadPool& pool = ctx.pools.Pool();
-  if (mode == Mode::kReadBF16) return ReadAllToBF16(tensors, reader, pool);
+  if (*mode == WeightsPtrs::Mode::kReadBF16) {
+    ReadAllToBF16(tensors, reader, pool);
+    return MapPtr();
+  }
 
   const std::vector<IOBatch> batches =
       MakeBatches(tensors, reader.file_bytes());
   ReadBatches(reader, batches, pool);
+  return MapPtr();
 }
 
-void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
-                                const LoaderArgs& loader,
-                                const InferenceArgs& inference,
-                                std::vector<MatOwner>& mat_owners,
-                                ThreadingContext& ctx) {
+WeightsPtrs::Mode WeightsPtrs::ReadFromBlobs(const ModelStore& model,
+                                             BlobReader& reader,
+                                             const LoaderArgs& loader,
+                                             const InferenceArgs& inference,
+                                             std::vector<MatOwner>& mat_owners,
+                                             ThreadingContext& ctx) {
   // List of tensors to read/map, and where from.
   std::vector<TensorToRead> tensors;
 
@@ -536,15 +539,14 @@ void WeightsPtrs::ReadFromBlobs(const ModelStore& model, BlobReader& reader,
     HWY_ABORT("Tensor %s is required but not found in file.", t.mat.Name());
   });
 
-  const Mode mode =
-      ChooseMode(reader.file_bytes(), loader, inference, ctx.allocator);
-
-  MapOrReadAll(tensors, reader, mode, mat_owners, ctx);
+  Mode mode = ChooseMode(reader.file_bytes(), loader, inference, ctx.allocator);
+  mapped_ = MapOrReadAll(tensors, reader, &mode, mat_owners, ctx);
 
   {
     PROFILER_ZONE("Startup.Fixup");
     Fixup(mat_owners, ctx);
   }
+  return mode;
 }
 
 }  // namespace gcpp
diff --git a/gemma/weights.h b/gemma/weights.h
index d9978ff..de3652a 100644
--- a/gemma/weights.h
+++ b/gemma/weights.h
@@ -90,7 +90,7 @@ class MatFinder {
 };
 
 // Per-layer weight metadata and pointers. The tensor data is owned by
-// `WeightsOwner`.
+// `MatOwner`.
 struct LayerWeightsPtrs {
   // Initializes tensor metadata without allocating.
   // NOTE: do not store layer_idx, TransformerLayer and Attention may use
@@ -314,7 +314,7 @@ struct LayerWeightsPtrs {
 };
 
 // Holds layer-independent weight metadata and pointers plus per-layer
-// `LayerWeightsPtrs`. The tensor data is owned by `WeightsOwner`.
+// `LayerWeightsPtrs`. The tensor data is owned by `MatOwner`.
 struct WeightsPtrs {
   explicit WeightsPtrs(const ModelConfig& config)
       : config_(config),
@@ -423,9 +423,34 @@ struct WeightsPtrs {
   // Copies only the allocated tensors in `*this` from tensors in `other`.
   void CopyFrom(const WeightsPtrs& other);
 
+  enum class Mode {
+    // Parallel I/O, decompress to BF16. Best for large batch sizes.
+    kReadBF16,
+    // Parallel I/O, insert row-wise padding. Safe default.
+    kRead,
+    // Best for large weights relative to available memory, especially for
+    // frequent invocations of small batches and short sequences. Adds noise to
+    // performance measurements due to I/O variability.
+    kMap
+  };
+
+  static const char* ToString(Mode mode) {
+    switch (mode) {
+      case Mode::kReadBF16:
+        return "ReadBF16";
+      case Mode::kRead:
+        return "Read";
+      case Mode::kMap:
+        return "Map";
+      default:
+        HWY_DASSERT(false);
+        return "?";
+    }
+  }
+
   // Reads tensor data from `BlobStore` or aborts on error. `map` is a user
-  // override for whether to map blobs or read them.
-  void ReadFromBlobs(const ModelStore& model, BlobReader& reader,
+  // override for whether to map blobs or read them. Returns the mode used.
+  Mode ReadFromBlobs(const ModelStore& model, BlobReader& reader,
                      const LoaderArgs& loader, const InferenceArgs& inference,
                      std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
 
@@ -436,6 +461,8 @@ struct WeightsPtrs {
   // For reshaping file tensors to the shape expected by the code. This would
   // ideally already happen in the importer. Called by ReadFromBlobs.
   void Fixup(std::vector<MatOwner>& mat_owners, ThreadingContext& ctx);
+
+  MapPtr mapped_;
 };  // `WeightsPtrs`
 
 #undef TENSOR_ARGS
diff --git a/io/blob_store.h b/io/blob_store.h
index f77059a..aa28210 100644
--- a/io/blob_store.h
+++ b/io/blob_store.h
@@ -47,7 +47,7 @@ struct BlobRange {
 // Reads `BlobStore` header, converts keys to strings and creates a hash map for
 // faster lookups.
 // TODO(janwas): rename to BlobFinder or similar.
-// Thread-safe: it is safe to concurrently call all methods.
+// Thread-safe: it is safe to concurrently call all methods except `CloseFile`.
 class BlobReader {
  public:
  // Acquires ownership of `file` (which must be non-null) and reads its header.
@@ -56,11 +56,10 @@ class BlobReader {
 
   const Path& blob_path() const { return blob_path_; }
 
-  // Non-const version required for File::Map().
-  File& file() { return *file_; }
   const File& file() const { return *file_; }
   uint64_t file_bytes() const { return file_bytes_; }
-
+  MapPtr Map() { return file_->Map(); }
+  // OK to call if Map() was called; the smart pointer keeps the mapping alive.
   void CloseFile() { file_.reset(); }
 
   const std::vector<std::string>& Keys() const { return keys_; }
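
The lifetime bug this patch fixes, reduced to a standalone sketch. Before the
change, `MapOrReadAll` held the `MapPtr` in a local scope, so the mapping was
released when the function returned while the tensor pointers still referenced
it; returning the owner and storing it (now `WeightsPtrs::mapped_`) ties the
mapping's lifetime to the weights. The types and function names below are
hypothetical stand-ins for illustration, not the gemma.cpp definitions.

// Minimal sketch of the dangling-mapping bug and its fix. Hypothetical names.
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

using MapPtr = std::unique_ptr<const uint8_t[]>;  // stands in for a file mapping

struct TensorView {
  const uint8_t* data = nullptr;
  size_t bytes = 0;
};

// BUG: `mapped` is destroyed when this function returns, so every view
// dangles -- the smart pointer has gone out of scope.
void LoadTensorsBuggy(std::vector<TensorView>& views, MapPtr mapped) {
  for (TensorView& v : views) v.data = mapped.get();  // points into mapping
}  // ~MapPtr runs here and releases the memory.

// FIX: hand the owner back to the caller, which stores it (e.g. in a
// `mapped_` member) for as long as the views are in use.
MapPtr LoadTensorsFixed(std::vector<TensorView>& views, MapPtr mapped) {
  for (TensorView& v : views) v.data = mapped.get();
  return mapped;  // caller keeps the mapping alive alongside the views
}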
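
Why `reader_.CloseFile()` right after `ReadFromBlobs` is safe even in `kMap`
mode: on POSIX systems, a mapping created by mmap is independent of the file
descriptor, so closing the file does not unmap it; only munmap (here, the
`MapPtr` deleter) releases it. A minimal demonstration in plain POSIX, not the
gemma.cpp `File` abstraction; the filename is made up.

// Demonstrates that an mmap'ed region remains valid after close(fd).
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>
#include <cstdio>

int main() {
  const int fd = open("weights.sbs", O_RDONLY);  // hypothetical file
  if (fd < 0) return 1;
  struct stat st;
  if (fstat(fd, &st) != 0) { close(fd); return 1; }
  void* p = mmap(nullptr, static_cast<size_t>(st.st_size), PROT_READ,
                 MAP_SHARED, fd, 0);
  close(fd);  // per POSIX, the mapping is unaffected by closing the fd
  if (p == MAP_FAILED) return 1;
  std::printf("first byte: %d\n", static_cast<const unsigned char*>(p)[0]);
  munmap(p, static_cast<size_t>(st.st_size));  // only this releases the mapping
  return 0;
}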