From 87a658b1c66a6979f474d457a6ea11aa4e4dc377 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Wed, 16 Apr 2025 10:48:56 -0700
Subject: [PATCH] Minor cleanup, on-demand NUQ buffer allocation

threading_context: add profiler
compress-inl: add constexpr, on-demand alloc NUQ buffer
gemma_py: model->gemma
Move ScaleWeights to compress.cc
Move PromptWrapping to configs.h

PiperOrigin-RevId: 748347896
---
 .github/workflows/build.yml |  2 +-
 BUILD.bazel                 |  4 ++++
 compression/BUILD.bazel     |  5 ++--
 compression/compress-inl.h  | 22 ++++++++++-------
 compression/compress.cc     | 28 +++++++++++++++++++++-
 compression/compress.h      | 20 ++++++++--------
 compression/shared.h        | 41 +-------------------------------
 gemma/configs.h             | 14 +++++++++++
 ops/dot-inl.h               |  3 ---
 paligemma/BUILD.bazel       |  1 +
 paligemma/paligemma_test.cc | 19 ++++++++-------
 python/BUILD.bazel          |  2 +-
 python/configs.cc           |  1 +
 python/gemma_py.cc          | 23 +++++++++---------
 util/mat.cc                 | 47 +++++++++++++++++++++++++++++++++++++
 util/mat.h                  | 46 ++++++++++++++++++++++++++++++------
 util/threading_context.cc   |  4 ++++
 17 files changed, 188 insertions(+), 94 deletions(-)

diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
index b796814..7d4496a 100644
--- a/.github/workflows/build.yml
+++ b/.github/workflows/build.yml
@@ -82,7 +82,7 @@ jobs:
           subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/gemma", "/kaggle/working"])
           subprocess.run(["chmod", "700", "/kaggle/working/gemma"])
           subprocess.run(["cp", "/kaggle/input/gemma-build-artifacts/_deps/sentencepiece-build/src/libsentencepiece.so.0", "/kaggle/working"])
-          output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--compressed_weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
+          output = subprocess.run(["/kaggle/working/gemma", "--tokenizer", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/tokenizer.spm", "--weights", "/kaggle/input/gemma/gemmacpp/2b-it-sfp/4/2b-it-sfp.sbs", "--model", "2b-it", "--verbosity", "0", "--max_generated_tokens", "128"], stdout=subprocess.PIPE, input='Write an email to the moon.', encoding='ascii').stdout
          assert("write an email to the moon." not in output.lower()); assert("moon" in output.lower())
          EOF
diff --git a/BUILD.bazel b/BUILD.bazel
index ad37b4c..2a8f4ab 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -92,6 +92,8 @@ cc_library(
         ":basics",
         ":threading",
         ":topology",
+        "@highway//:hwy",
+        "@highway//:profiler",
    ],
)

@@ -180,6 +182,7 @@ cc_library(
        "//compression:shared",
        "@highway//:hwy",
        "@highway//:profiler",
+        "@highway//:thread_pool",
    ],
)

@@ -664,6 +667,7 @@ cc_test(
        ":mat",
        ":prompt",
        ":sampler",
+        ":threading_context",
        ":weights",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "@highway//:thread_pool",
diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel
index e58b61c..8fb2864 100644
--- a/compression/BUILD.bazel
+++ b/compression/BUILD.bazel
@@ -70,6 +70,7 @@ cc_library(
    hdrs = ["blob_store.h"],
    deps = [
        ":io",
+        "//:basics",
        "//:threading_context",
        "@highway//:hwy",
        "@highway//:thread_pool",
@@ -130,7 +131,6 @@ cc_library(
    textual_hdrs = ["sfp-inl.h"],
    deps = [
        ":shared",
-        "//:basics",
        "@highway//:hwy",
    ],
)
@@ -195,7 +195,6 @@ cc_test(
    deps = [
        ":distortion",
        ":nuq",
-        ":sfp",
        "@googletest//:gtest_main",  # buildcleaner: keep
        "//:test_util",
        "@highway//:hwy",
@@ -225,6 +224,7 @@ cc_library(
        "//:mat",
        "@highway//:hwy",
        "@highway//:nanobenchmark",
+        "@highway//:profiler",
        "@highway//:stats",
        "@highway//:thread_pool",
    ],
@@ -259,6 +259,7 @@ cc_library(
    deps = [
        ":nuq",
        ":sfp",
+        ":shared",
        "@highway//:hwy",
        "@highway//:stats",
        "@highway//:thread_pool",
diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index 0c0cdef..d4849dc 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -21,8 +21,7 @@
 #include <stddef.h>
 #include <stdint.h>

-#include <math.h>  // lroundf, only if COMPRESS_STATS
-#include
+#include <memory>
 #include

 #include "compression/blob_store.h"
@@ -35,6 +34,10 @@
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"

+#if COMPRESS_STATS
+#include <math.h>  // lroundf
+#endif
+
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_INL_H_

 // Include guard for (potentially) SIMD code.
@@ -388,7 +391,7 @@ struct CompressTraits<SfpStream> {
                             const size_t packed_ofs) {
    SfpCodec::Enc(df, raw, num, packed.ptr + packed_ofs);

-    if (COMPRESS_STATS) {
+    if constexpr (COMPRESS_STATS) {
      const hn::Repartition<BF16, DF> dbf;
      auto distorted =
          hwy::AllocateAligned<BF16>(hwy::RoundUpTo(num, hn::Lanes(dbf)));
@@ -432,9 +435,10 @@ struct CompressTraits<NuqStream> {
                               size_t num, CompressPerThread& tls,
                               const PackedSpan<Packed>& packed,
                               const size_t packed_ofs) {
-    NuqCodec::Enc(df, raw, num, tls.buf, packed, packed_ofs);
+    if (!tls.buf) tls.buf = std::make_unique<NuqStream::ClusterBuf>();
+    NuqCodec::Enc(df, raw, num, *tls.buf, packed, packed_ofs);

-    if (COMPRESS_STATS) {
+    if constexpr (COMPRESS_STATS) {
      for (size_t i = 0; i < num; ++i) {
        tls.stats.NotifyIn(static_cast<int>(lroundf(raw[i] * 100.0f + 500.0f)));
      }
@@ -478,7 +482,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
                           const size_t packed_ofs, hwy::ThreadPool& pool) {
  packed.BoundsCheck(packed_ofs, num);
  work.tls.resize(pool.NumWorkers());
-  if (COMPRESS_STATS) {
+  if constexpr (COMPRESS_STATS) {
    for (auto& tls : work.tls) {
      tls.stats.Reset();
    }
  }
@@ -487,7 +491,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
  const bool want_bench = COMPRESS_STATS || !kIsTest;
  const double t0 = want_bench ? hwy::platform::Now() : 0.0;

-  using Traits = CompressTraits<Packed>;
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
  constexpr size_t kBatch = 8192;
  const size_t num_batches = hwy::DivCeil(num, kBatch);
  pool.Run(0, num_batches,
@@ -508,7 +512,7 @@ HWY_NOINLINE void Compress(const float* HWY_RESTRICT raw, size_t num,
    fprintf(stderr, "Compress %.1f MB/s\n", mbps);
  }

-  if (COMPRESS_STATS) {
+  if constexpr (COMPRESS_STATS) {
    for (size_t i = 1; i < work.tls.size(); ++i) {
      work.tls[0].stats.Assimilate(work.tls[i].stats);
    }
@@ -534,7 +538,7 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
               const size_t packed_ofs) {
  static_assert(hwy::IsSameEither());
  packed.BoundsCheck(packed_ofs, 2 * hn::Lanes(df));
-  using Traits = CompressTraits<Packed>;
+  using Traits = CompressTraits<hwy::RemoveConst<Packed>>;
  Traits::Store2(df, raw0, raw1, packed, packed_ofs);
}
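The NUQ hunk above replaces an eagerly embedded `ClusterBuf` in every `CompressPerThread` slot with a pointer that is filled on first use, so worker threads that never encode NUQ tensors no longer pay for the buffer. A minimal stand-alone sketch of the pattern, where `Buf` is an invented stand-in for `NuqStream::ClusterBuf`:

#include <memory>

struct Buf { float centers[256]; };  // stand-in for NuqStream::ClusterBuf

struct PerThread {
  std::unique_ptr<Buf> buf;  // stays empty until the first NUQ encode
};

void EncodeNuq(PerThread& tls) {
  // Allocate lazily: a thread that never reaches this point holds only an
  // empty unique_ptr instead of the full cluster buffer.
  if (!tls.buf) tls.buf = std::make_unique<Buf>();
  // ... use *tls.buf ...
}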
diff --git a/compression/compress.cc b/compression/compress.cc
index 1818b8f..6ef8990 100644
--- a/compression/compress.cc
+++ b/compression/compress.cc
@@ -15,8 +15,34 @@

 #include "compression/compress.h"

+#include <stddef.h>
+#include <stdint.h>
+
+#include "util/mat.h"
+#include "hwy/base.h"
+#include "hwy/profiler.h"
+
 namespace gcpp {

-// TODO: move ScaleWeights here.
+float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
+  PROFILER_FUNC;
+
+  float maxabs = 0.0;
+  for (size_t i = 0; i < num; ++i) {
+    maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
+  }
+  if (maxabs <= SfpStream::kMax) {
+    return 1.0f;
+  }
+  const float scale = maxabs / SfpStream::kMax;
+  const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
+  for (size_t i = 0; i < num; ++i) {
+    // Clamp because kMax may still be exceeded.
+    const float magn =
+        HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
+    raw[i] = hwy::ScalarCopySign(magn, raw[i]);
+  }
+  return scale;
+}

 }  // namespace gcpp
diff --git a/compression/compress.h b/compression/compress.h
index 8844601..2a5df9d 100644
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -17,26 +17,19 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_

-#include "hwy/base.h"
 #define COMPRESS_STATS 0

 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
-#include
-#include
-#include
-#include
-#include
+#include <memory>
 #include <vector>

-// IWYU pragma: begin_exports
 #include "compression/blob_store.h"
 #include "compression/fields.h"
 #include "compression/io.h"
-#include "compression/shared.h"
-#include "gemma/tensor_index.h"
+#include "compression/shared.h"  // NuqStream::ClusterBuf
 #include "util/basics.h"
 // IWYU pragma: end_exports
 #include "gemma/configs.h"
@@ -174,7 +167,8 @@ struct CompressStats {
 #endif  // COMPRESS_STATS

 struct CompressPerThread {
-  NuqStream::ClusterBuf buf;
+  // Allocated the first time NUQ is used.
+  std::unique_ptr<NuqStream::ClusterBuf> buf;
  CompressStats stats;
};

@@ -375,5 +369,11 @@ class ReadFromBlobStore {
  std::vector<hwy::uint128_t> file_keys_;
};

+// Returns 1.0f if all magnitudes are <= `SfpStream::kMax`, otherwise scales
+// them such that the largest magnitude is `SfpStream::kMax`, and returns the
+// multiplier with which to restore the original values. This is only necessary
+// before compressing to `SfpStream` and `NuqStream`.
+float ScaleWeights(float* HWY_RESTRICT raw, size_t num);
+
}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_COMPRESS_H_
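For intuition about the `ScaleWeights` contract documented above, here is a scalar sketch plus a worked example. It mirrors the moved function but is not the gemma.cpp code itself, and the `1.875f` value assumed for `SfpStream::kMax` should be treated as an assumption:

// Scalar sketch of the ScaleWeights contract, with an invented kSfpMax.
#include <cmath>
#include <cstddef>
#include <cstdio>

constexpr float kSfpMax = 1.875f;  // assumed value of SfpStream::kMax

float ScaleWeightsSketch(float* raw, size_t num) {
  float maxabs = 0.0f;
  for (size_t i = 0; i < num; ++i) maxabs = std::fmax(maxabs, std::fabs(raw[i]));
  if (maxabs <= kSfpMax) return 1.0f;  // already in range: nothing to undo
  const float scale = maxabs / kSfpMax;
  for (size_t i = 0; i < num; ++i) {
    // Clamp as in the original, because rounding may still exceed kSfpMax.
    raw[i] = std::copysign(std::fmin(kSfpMax, std::fabs(raw[i] / scale)), raw[i]);
  }
  return scale;  // the caller stores this, e.g. via MatPtr::SetScale
}

int main() {
  float w[3] = {0.5f, -3.75f, 1.0f};
  const float scale = ScaleWeightsSketch(w, 3);
  printf("scale=%g w=[%g %g %g]\n", scale, w[0], w[1], w[2]);
  // Prints scale=2 w=[0.25 -1.875 0.5]; multiplying by 2 restores the input.
}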
diff --git a/compression/shared.h b/compression/shared.h
index 8b6fb82..27e998d 100644
--- a/compression/shared.h
+++ b/compression/shared.h
@@ -13,8 +13,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-// Definitions shared between the public compress-inl.h interface and the
-// sfp-inl.h and nuq-inl.h implementation details.
+// Types shared between tensor definitions and `compress-inl.h`.

 #ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
@@ -63,30 +62,6 @@ struct SfpStream {
};
 #pragma pack(pop)

-// Returns 1.0f if all magnitudes are <= SfpStream::kMax, otherwise scales them
-// such that the largest magnitude is SfpStream::kMax, and returns the
-// multiplier with which to restore the original values. This is only necessary
-// before compressing to SfpStream.
-// TODO: vectorize
-static inline float ScaleWeights(float* HWY_RESTRICT raw, size_t num) {
-  float maxabs = 0.0;
-  for (size_t i = 0; i < num; ++i) {
-    maxabs = HWY_MAX(maxabs, hwy::ScalarAbs(raw[i]));
-  }
-  if (maxabs <= SfpStream::kMax) {
-    return 1.0f;
-  }
-  const float scale = maxabs / SfpStream::kMax;
-  const float inv_scale = static_cast<float>(1.0 / static_cast<double>(scale));
-  for (size_t i = 0; i < num; ++i) {
-    // Clamp because kMax may still be exceeded.
-    const float magn =
-        HWY_MIN(SfpStream::kMax, hwy::ScalarAbs(raw[i] * inv_scale));
-    raw[i] = hwy::ScalarCopySign(magn, raw[i]);
-  }
-  return scale;
-}
-
 // Non-uniform quantization: a compressed representation of f32 inputs that
 // supports seeking at a granularity of 1 (for `DecompressAndZeroPad`) or
 // two vectors (for `Decompress2`), and decoding to bf16/f32.
@@ -185,20 +160,6 @@ constexpr bool IsNuqStream() {
  return hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>();
}

-// Instruction-tuned models require extra 'turn structure' tokens in prompts.
-enum class PromptWrapping {
-  GEMMA_IT,
-  GEMMA_PT,
-  GEMMA_VLM,
-  PALIGEMMA,
-  kSentinel  // must be last
-};
-
-inline bool EnumValid(PromptWrapping type) {
-  return static_cast<int>(type) >= 0 &&
-         static_cast<int>(type) < static_cast<int>(PromptWrapping::kSentinel);
-}
-
 // Tensor types for loading weights. Note that not all types are supported as
 // weights for a model, but can be used for other purposes, such as types for
 // `WeightsPtrs`. When adding a new type that is supported, also
diff --git a/gemma/configs.h b/gemma/configs.h
index 837e067..77d063a 100644
--- a/gemma/configs.h
+++ b/gemma/configs.h
@@ -49,6 +49,20 @@ static constexpr size_t kMaxConv1DWidth = 4;

 using EmbedderInputT = BF16;

+// Instruction-tuned models require extra 'turn structure' tokens in prompts.
+enum class PromptWrapping {
+  GEMMA_IT,
+  GEMMA_PT,
+  GEMMA_VLM,
+  PALIGEMMA,
+  kSentinel  // must be last
+};
+
+static inline bool EnumValid(PromptWrapping wrapping) {
+  return static_cast<size_t>(wrapping) <
+         static_cast<size_t>(PromptWrapping::kSentinel);
+}
+
 enum class LayerAttentionType {
  kGemma,
  kGriffinRecurrentBlock,
diff --git a/ops/dot-inl.h b/ops/dot-inl.h
index 08a5ca8..36bec6a 100644
--- a/ops/dot-inl.h
+++ b/ops/dot-inl.h
@@ -15,9 +15,6 @@

 #include <stddef.h>

-#include "compression/compress.h"
-#include "util/mat.h"
-#include "hwy/base.h"
 #include "hwy/profiler.h"

 // Include guard for (potentially) SIMD code.
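Note that the relocated `EnumValid` drops the `>= 0` arm of the old check. Assuming the cast target is an unsigned type such as `size_t` (as reconstructed above), negative inputs wrap to large values, so the single upper-bound comparison still rejects them. A self-contained sketch of why validation matters for enum fields read from disk:

// Sketch: validating a deserialized PromptWrapping. The enum and predicate
// mirror configs.h above; from_disk and main() are invented for illustration.
#include <cstddef>

enum class PromptWrapping { GEMMA_IT, GEMMA_PT, GEMMA_VLM, PALIGEMMA, kSentinel };

static inline bool EnumValid(PromptWrapping wrapping) {
  // Unsigned cast folds negative values into large ones, so one comparison
  // covers both "negative" and "past kSentinel".
  return static_cast<size_t>(wrapping) <
         static_cast<size_t>(PromptWrapping::kSentinel);
}

int main() {
  const int from_disk = 7;  // a config field may hold any integer
  const auto wrapping = static_cast<PromptWrapping>(from_disk);
  return EnumValid(wrapping) ? 0 : 1;  // returns 1: 7 is out of range
}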
diff --git a/paligemma/BUILD.bazel b/paligemma/BUILD.bazel
index 069fd6b..36e59d3 100644
--- a/paligemma/BUILD.bazel
+++ b/paligemma/BUILD.bazel
@@ -40,6 +40,7 @@ cc_test(
    ],
    deps = [
        "@googletest//:gtest_main",  # buildcleaner: keep
+        "//:allocator",
        "//:benchmark_helper",
        "//:common",
        "//:gemma_lib",
diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc
index 398b067..2453822 100644
--- a/paligemma/paligemma_test.cc
+++ b/paligemma/paligemma_test.cc
@@ -20,7 +20,9 @@
 #include "compression/shared.h"
 #include "evals/benchmark_helper.h"
 #include "gemma/common.h"
+#include "gemma/configs.h"
 #include "gemma/gemma.h"
+#include "util/allocator.h"
 #include "hwy/base.h"
 #include "hwy/tests/hwy_gtest.h"

@@ -50,17 +52,18 @@ class PaliGemmaTest : public ::testing::Test {

 void PaliGemmaTest::InitVit(const std::string& path) {
  ASSERT_NE(s_env->GetGemma(), nullptr);
-  Gemma& model = *(s_env->GetGemma());
-  image_tokens_ =
-      ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
-                            model.GetModelConfig().model_dim));
+  const Allocator2& allocator = s_env->Env().ctx.allocator;
+  Gemma& gemma = *(s_env->GetGemma());
+  image_tokens_ = ImageTokens(
+      allocator, Extents2D(gemma.GetModelConfig().vit_config.seq_len,
+                           gemma.GetModelConfig().model_dim));
  Image image;
-  HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA);
+  HWY_ASSERT(gemma.GetModelConfig().wrapping == PromptWrapping::PALIGEMMA);
  HWY_ASSERT(image.ReadPPM(path));
-  const size_t image_size = model.GetModelConfig().vit_config.image_size;
+  const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
  image.Resize(image_size, image_size);
  RuntimeConfig runtime_config = {.gen = &s_env->MutableGen(), .verbosity = 0};
-  model.GenerateImageTokens(runtime_config, image, image_tokens_);
+  gemma.GenerateImageTokens(runtime_config, image, image_tokens_);
}

std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
@@ -124,7 +127,7 @@ TEST_F(PaliGemmaTest, General) {
  };
  const char* (*qa)[2];
  size_t num;
-  switch (s_env->GetGemma()->Info().model) {
+  switch (s_env->GetGemma()->GetModelConfig().model) {
    case Model::PALIGEMMA_224:
      qa = kQA_3B_mix_224;
      num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);
diff --git a/python/BUILD.bazel b/python/BUILD.bazel
index 2a7220a..9ae2c31 100644
--- a/python/BUILD.bazel
+++ b/python/BUILD.bazel
@@ -21,10 +21,10 @@ pybind_extension(
    name = "gemma",
    srcs = ["gemma_py.cc"],
    deps = [
-        "//:allocator",
        "//:benchmark_helper",
        "//:gemma_args",
        "//:gemma_lib",
+        "//:threading_context",
        "//compression:shared",
        "@highway//:hwy",
    ],
diff --git a/python/configs.cc b/python/configs.cc
index 53ba5c4..b24d5cd 100644
--- a/python/configs.cc
+++ b/python/configs.cc
@@ -43,6 +43,7 @@ PYBIND11_MODULE(configs, py_module) {
  enum_<PromptWrapping>(py_module, "PromptWrapping")
      .value("GEMMA_IT", PromptWrapping::GEMMA_IT)
      .value("GEMMA_PT", PromptWrapping::GEMMA_PT)
+      .value("GEMMA_VLM", PromptWrapping::GEMMA_VLM)
      .value("PALIGEMMA", PromptWrapping::PALIGEMMA);

  enum_<Type>(py_module, "Type")
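The test change reflects an API change: `ImageTokens` now takes the `Allocator2` obtained from the threading context instead of allocating through a global. A generic sketch of that dependency-injection shape, using stand-in types rather than the gemma.cpp ones:

// Sketch: threading an allocator through matrix construction. Allocator and
// Mat are invented stand-ins for Allocator2 and ImageTokens.
#include <cstddef>
#include <memory>

struct Extents2D { size_t rows, cols; };

class Allocator {
 public:
  // gemma.cpp's Allocator2 additionally chooses alignment/NUMA placement.
  std::unique_ptr<float[]> AllocFloats(size_t num) const {
    return std::unique_ptr<float[]>(new float[num]());
  }
};

class Mat {
 public:
  // The allocator is an explicit dependency, as in ImageTokens(allocator, extents).
  Mat(const Allocator& allocator, Extents2D extents)
      : storage_(allocator.AllocFloats(extents.rows * extents.cols)) {}

 private:
  std::unique_ptr<float[]> storage_;
};

int main() {
  Allocator allocator;  // in gemma.cpp: obtained from the ThreadingContext
  Mat image_tokens(allocator, {256, 2048});  // invented dimensions
}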
diff --git a/python/gemma_py.cc b/python/gemma_py.cc
index 0791188..90861d9 100644
--- a/python/gemma_py.cc
+++ b/python/gemma_py.cc
@@ -22,18 +22,16 @@

 #include
 #include
-#include
 #include
 #include
 #include
 #include
 #include

-#include "compression/shared.h"
 #include "evals/benchmark_helper.h"
 #include "gemma/gemma.h"
 #include "gemma/gemma_args.h"
-#include "util/allocator.h"
+#include "util/threading_context.h"
 #include "hwy/base.h"

 namespace py = pybind11;
@@ -169,9 +167,10 @@ class GemmaModel {
  // Generate* will use this image. Throws an error for other models.
  void SetImage(const py::array_t<float>& image) {
+    gcpp::Gemma& gemma = *(gemma_.GetGemma());
    const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
-    gcpp::Gemma& model = *(gemma_.GetGemma());
-    if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) {
+    if (gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::PALIGEMMA &&
+        gemma.GetModelConfig().wrapping != gcpp::PromptWrapping::GEMMA_VLM) {
      throw std::invalid_argument("Not a PaliGemma model.");
    }
    py::buffer_info buffer = image.request();
@@ -183,14 +182,14 @@ class GemmaModel {
    float* ptr = static_cast<float*>(buffer.ptr);
    gcpp::Image c_image;
    c_image.Set(height, width, ptr);
-    const size_t image_size = model.GetModelConfig().vit_config.image_size;
+    const size_t image_size = gemma.GetModelConfig().vit_config.image_size;
    c_image.Resize(image_size, image_size);
    image_tokens_ = gcpp::ImageTokens(
-        allocator, gcpp::Extents2D(model.GetModelConfig().vit_config.seq_len,
-                                   model.GetModelConfig().model_dim));
+        allocator, gcpp::Extents2D(gemma.GetModelConfig().vit_config.seq_len,
+                                   gemma.GetModelConfig().model_dim));
    gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
                                          .verbosity = 0};
-    model.GenerateImageTokens(runtime_config, c_image, image_tokens_);
+    gemma.GenerateImageTokens(runtime_config, c_image, image_tokens_);
  }

  // Generates a response to the given prompt, using the last set image.
@@ -267,12 +266,12 @@ PYBIND11_MODULE(gemma, mod) {
              throw std::invalid_argument(err);
            }
            loader.weight_type_str = weight_type;
+            gcpp::ThreadingArgs threading;
+            threading.max_lps = max_threads;
            gcpp::InferenceArgs inference;
            inference.max_generated_tokens = 512;
-            gcpp::ThreadingArgs app;
-            app.max_threads = max_threads;
            auto gemma =
-                std::make_unique<GemmaModel>(loader, inference, app);
+                std::make_unique<GemmaModel>(loader, inference, threading);
            if (!gemma->ModelIsLoaded()) {
              throw std::invalid_argument("Could not load model.");
            }
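`SetImage` relies on pybind11's buffer protocol: `request()` yields a `buffer_info` whose `ptr`, `ndim` and `size` describe the NumPy array. A condensed sketch of that validation pattern, independent of the Gemma types; like the binding above, it assumes a dense row-major float array:

// Sketch of the pybind11 buffer handling in SetImage; SumPixels is an
// invented function for use inside a pybind11 extension module.
#include <pybind11/numpy.h>
#include <stdexcept>

namespace py = pybind11;

float SumPixels(const py::array_t<float>& image) {
  py::buffer_info buffer = image.request();
  if (buffer.ndim != 3) throw std::invalid_argument("Expected HWC image.");
  // ptr is only valid while `image` is alive; gemma.cpp copies via Image::Set.
  const float* ptr = static_cast<const float*>(buffer.ptr);
  float sum = 0.0f;
  for (py::ssize_t i = 0; i < buffer.size; ++i) sum += ptr[i];
  return sum;
}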
diff --git a/util/mat.cc b/util/mat.cc
index 677e928..3ce57f3 100644
--- a/util/mat.cc
+++ b/util/mat.cc
@@ -18,8 +18,12 @@
 #include <stddef.h>
 #include <stdint.h>

+#include <random>
+#include <vector>
+
 #include "util/threading_context.h"
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/per_target.h"  // VectorBytes
 #include "hwy/profiler.h"

@@ -27,8 +31,11 @@ namespace gcpp {

 void CopyMat(const MatPtr& from, MatPtr& to) {
  PROFILER_FUNC;
+  HWY_ASSERT_M(from.HasPtr() && to.HasPtr(), to.Name());
  HWY_ASSERT(to.Rows() == from.Rows() && to.Cols() == from.Cols());
  HWY_ASSERT(to.GetType() == from.GetType());
+  to.SetScale(from.Scale());
+
  if (to.IsPacked() && from.IsPacked()) {
    HWY_ASSERT(to.PackedBytes() == from.PackedBytes());
    hwy::CopyBytes(from.Packed(), to.Packed(), to.PackedBytes());
@@ -45,6 +52,8 @@ void CopyMat(const MatPtr& from, MatPtr& to) {
 void ZeroInit(MatPtr& mat) {
  PROFILER_FUNC;
  HWY_ASSERT_M(mat.HasPtr(), mat.Name());
+  mat.SetScale(1.0f);
+
  if (mat.IsPacked()) {
    hwy::ZeroBytes(mat.Packed(), mat.PackedBytes());
    return;
@@ -55,6 +64,31 @@ void ZeroInit(MatPtr& mat) {
  }
}

+void RandInit(MatPtr& mat, float stddev, std::mt19937& gen) {
+  PROFILER_FUNC;
+  HWY_ASSERT_M(mat.HasPtr(), mat.Name());
+  // Only generates float/double for use by backprop/.
+  HWY_ASSERT(mat.GetType() == Type::kF32 || mat.GetType() == Type::kF64);
+  mat.SetScale(1.0f);
+
+  std::normal_distribution<float> dist(0.0, stddev);
+  if (mat.GetType() == Type::kF32) {
+    for (size_t r = 0; r < mat.Rows(); ++r) {
+      float* HWY_RESTRICT row = mat.RowT<float>(r);
+      for (size_t c = 0; c < mat.Cols(); ++c) {
+        row[c] = dist(gen);
+      }
+    }
+  } else {
+    for (size_t r = 0; r < mat.Rows(); ++r) {
+      double* HWY_RESTRICT row = mat.RowT<double>(r);
+      for (size_t c = 0; c < mat.Cols(); ++c) {
+        row[c] = dist(gen);
+      }
+    }
+  }
+}
+
 // Returns `num` rounded up to an odd number of cache lines. This would also
 // prevent 4K aliasing and is coprime with the cache associativity, which
 // might reduce conflict misses, but we instead use `StrideForCyclicOffsets`.
@@ -84,6 +118,7 @@ static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
}

void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
+  if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
  const Allocator2& allocator = ThreadingContext2::Get().allocator;
  const size_t stride = Stride(allocator, mat, padding);
  const size_t num = mat.Rows() * stride;
@@ -97,4 +132,16 @@ void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
  storage_ = allocator.AllocBytes(padded_bytes);
  mat.SetPtr(storage_.get(), stride);
}
+
+void MatOwners::AllocateFor(const std::vector<MatPtr*>& mats,
+                            MatPadding padding, hwy::ThreadPool& pool) {
+  const size_t start = owners_.size();
+  owners_.resize(start + mats.size());
+
+  // Allocate in parallel because faulting in large tensors is slow.
+  pool.Run(0, mats.size(), [&](uint64_t task, size_t /*thread*/) {
+    owners_[start + task].AllocateFor(*mats[task], padding);
+  });
+}
+
}  // namespace gcpp
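`MatOwners::AllocateFor` parallelizes allocation because the first touch of each page triggers a fault, and faulting in many GiB from a single thread is slow. The same effect with plain `std::thread` instead of the Highway pool; sizes and thread layout are invented:

// Sketch: parallel first-touch of several large buffers, the effect the
// pool.Run() above achieves for model tensors.
#include <cstddef>
#include <memory>
#include <thread>
#include <vector>

int main() {
  constexpr size_t kNumTensors = 8;
  constexpr size_t kBytes = 64 << 20;  // 64 MiB each
  std::vector<std::unique_ptr<unsigned char[]>> storage(kNumTensors);

  std::vector<std::thread> threads;
  for (size_t i = 0; i < kNumTensors; ++i) {
    threads.emplace_back([&storage, i] {
      // Value-initialization zeroes every byte, so the kernel faults in the
      // pages on this thread, concurrently with the other tensors.
      storage[i] = std::make_unique<unsigned char[]>(kBytes);
    });
  }
  for (std::thread& t : threads) t.join();
}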
diff --git a/util/mat.h b/util/mat.h
index cbe37a3..d1c7c9d 100644
--- a/util/mat.h
+++ b/util/mat.h
@@ -22,6 +22,7 @@
 #include <stddef.h>
 #include <stdint.h>

+#include <vector>

 // IWYU pragma: begin_exports
 #include "compression/fields.h"
@@ -31,6 +32,7 @@
 #include "util/basics.h"  // Extents2D
 // IWYU pragma: end_exports
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"

 namespace gcpp {

@@ -71,7 +73,8 @@ class MatPtr : public IFields {

  bool HasPtr() const { return ptr_ != nullptr; }

-  bool IsPacked() const { return stride_ == cols_; }
+  // A single row counts as packed because there is no padding between rows.
+  bool IsPacked() const { return (stride_ == cols_) || (rows_ == 1); }

  const void* Packed() const {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
@@ -132,11 +135,10 @@ class MatPtr : public IFields {
  float Scale() const { return scale_; }
  void SetScale(float scale) { scale_ = scale; }

-  // Name is a terse identifier. `MakeKey` in `blob_store.cc` requires that it
-  // be <= 16 bytes including prefixes/suffixes. The initial name set by the
-  // ctor is for the tensor, but `ForEachTensor` in `weights.h` adds a per-layer
-  // suffix, and when loading, we call `SetName` with that.
+  // A terse identifier unique across all tensors of the model.
  const char* Name() const override { return name_.c_str(); }
+  // `MakeKey` in `blob_store.cc` requires that this be <= 16 bytes, including
+  // the `LayerSuffix` for per-layer tensors.
  void SetName(const char* name) {
    name_ = name;
    HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name);
@@ -194,11 +196,13 @@ class MatPtr : public IFields {
  uint32_t stride_;
};

-// Non-type erased version of `MatPtr`. Use this when operating on the values.
+// Non-type erased version of `MatPtr`. Although `MatPtr` also provides
+// type-aware accessors (`RowT`), this class is more convenient when accessing
+// elements, and ensures the template argument and `Type` are consistent.
template <typename MatT>
class MatPtrT : public MatPtr {
 public:
-  // Runtime-specified shape.
+  // Called by `MatStorageT`.
  MatPtrT(const char* name, Extents2D extents)
      : MatPtr(name, TypeEnum<MatT>(), extents) {}

  // Take shape from `TensorInfo` to avoid duplicating it in the caller.
@@ -247,6 +251,15 @@ class MatPtrT : public MatPtr {
    HWY_ASSERT(IsPacked());
    return MakeSpan(Row(0), num_elements_);
  }
+
+  // For when a span of a single row is required. This also works if padded,
+  // but does not support `GetType() == kNUQ`, because that requires the use of
+  // offsets instead of a row pointer. Used by `gemma-inl.h` to decompress
+  // embeddings.
+  PackedSpan<const MatT> RowSpan(size_t row) const {
+    HWY_DASSERT(GetType() != Type::kNUQ);
+    return MakeConstSpan(Row(row), Cols());
+  }
};

// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT`, plus the
@@ -340,6 +353,25 @@ class MatOwner {
  AlignedPtr2<uint8_t[]> storage_;
};

+// Multiple `MatOwner`, with support for parallel allocation.
+class MatOwners {
+ public:
+  // Ignores `padding` for NUQ tensors, which are always packed.
+  void AllocateFor(MatPtr& mat, MatPadding padding) {
+    if (mat.GetType() == Type::kNUQ) padding = MatPadding::kPacked;
+    owners_.push_back(MatOwner());
+    owners_.back().AllocateFor(mat, padding);
+  }
+
+  // Allocates multiple in parallel. Ignores `padding` for NUQ tensors,
+  // which are always packed.
+  void AllocateFor(const std::vector<MatPtr*>& mats, MatPadding padding,
+                   hwy::ThreadPool& pool);
+
+ private:
+  std::vector<MatOwner> owners_;
+};
+
// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
// tests to allocate and access tensors of a known type. By contrast, the
// heterogeneous model weights are owned by vectors of `MatOwner`.
diff --git a/util/threading_context.cc b/util/threading_context.cc
index c15e194..2636f46 100644
--- a/util/threading_context.cc
+++ b/util/threading_context.cc
@@ -18,6 +18,9 @@
 #include <memory>
 #include <mutex>  // NOLINT

+#include "hwy/base.h"  // HWY_ASSERT, HWY_UNLIKELY
+#include "hwy/profiler.h"
+
 namespace gcpp {

 static ThreadingArgs s_args;
@@ -41,6 +44,7 @@
}

/*static*/ ThreadingContext2& ThreadingContext2::Get() {
+  PROFILER_FUNC;
  // We do not bother with double-checked locking because it requires an
  // atomic pointer, but we prefer to use unique_ptr for simplicity. Also,
  // callers can cache the result and call less often.
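For reference, the locking strategy described in that comment, as a stand-alone sketch with invented names: every call takes the mutex, which the newly added PROFILER_FUNC makes measurable, in exchange for avoiding the atomic-pointer machinery of double-checked locking.

// Sketch of a mutex-guarded lazy singleton in the style of
// ThreadingContext2::Get(); Context and GetContext are stand-ins.
#include <memory>
#include <mutex>

struct Context { int value = 42; };

Context& GetContext() {
  static std::mutex m;
  static std::unique_ptr<Context> ctx;
  // Every call pays for the lock; callers are expected to cache the
  // returned reference and call rarely, as the comment above suggests.
  std::lock_guard<std::mutex> lock(m);
  if (!ctx) ctx = std::make_unique<Context>();
  return *ctx;
}

int main() { return GetContext().value == 42 ? 0 : 1; }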