From 6e52a835c67f0592204c3dd2ece665534c236d17 Mon Sep 17 00:00:00 2001
From: Jan Wassenberg
Date: Sun, 7 Sep 2025 22:50:01 -0700
Subject: [PATCH] Faster startup on tsan: use hierarchical parallelism for BF16 conversion

Also re-enable profiler zones

PiperOrigin-RevId: 804273899
---
 gemma/weights.cc | 66 ++++++++++++++++++++++++++----------------
 ops/matmul-inl.h |  7 +++--
 2 files changed, 41 insertions(+), 32 deletions(-)

diff --git a/gemma/weights.cc b/gemma/weights.cc
index 8191bd9..b71e6b7 100644
--- a/gemma/weights.cc
+++ b/gemma/weights.cc
@@ -383,39 +383,45 @@ static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                           const BlobReader& reader, ThreadingContext& ctx) {
   static const auto zone =
       ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
-  ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
-    PROFILER_ZONE3(ctx.profiler, thread, zone);
-    const TensorToRead& tensor = tensors[task];
-    MatPtr& mat = *tensor.mat;
+  // Especially TSAN is slow enough to warrant hierarchical parallelism.
+  const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
+                                           ? ParallelismStrategy::kHierarchical
+                                           : ParallelismStrategy::kFlat;
+  ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
+              [&](uint64_t task, size_t thread) {
+                PROFILER_ZONE3(ctx.profiler, thread, zone);
+                const TensorToRead& tensor = tensors[task];
+                MatPtr& mat = *tensor.mat;

-    if (tensor.keep_type) {
-      HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
-                                    mat.Packed()));
-      return;
-    }
+                if (tensor.keep_type) {
+                  HWY_ASSERT(reader.file().Read(
+                      tensor.range.offset, tensor.range.bytes, mat.Packed()));
+                  return;
+                }

-    // Read to a temporary buffer.
-    const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
-        hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
-    HWY_ASSERT(
-        reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
+                // Read to a temporary buffer.
+                const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
+                    hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
+                HWY_ASSERT(reader.file().Read(tensor.range.offset,
+                                              tensor.range.bytes, buf.get()));

-    if constexpr (GEMMA_ENABLE_NUQ) {
-      if (tensor.prev_type == Type::kNUQ) {
-        return DecompressToBF16<NuqStream>(*tensor.mat, buf);
-      }
-    }
-    switch (tensor.prev_type) {
-      case Type::kF32:
-        return DecompressToBF16<float>(*tensor.mat, buf);
-      case Type::kBF16:
-        return DecompressToBF16<BF16>(*tensor.mat, buf);
-      case Type::kSFP:
-        return DecompressToBF16<SfpStream>(*tensor.mat, buf);
-      default:
-        HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
-    }
-  });
+                if constexpr (GEMMA_ENABLE_NUQ) {
+                  if (tensor.prev_type == Type::kNUQ) {
+                    return DecompressToBF16<NuqStream>(*tensor.mat, buf);
+                  }
+                }
+                switch (tensor.prev_type) {
+                  case Type::kF32:
+                    return DecompressToBF16<float>(*tensor.mat, buf);
+                  case Type::kBF16:
+                    return DecompressToBF16<BF16>(*tensor.mat, buf);
+                  case Type::kSFP:
+                    return DecompressToBF16<SfpStream>(*tensor.mat, buf);
+                  default:
+                    HWY_ABORT("Unsupported type %s",
+                              TypeName(tensor.prev_type));
+                }
+              });
 }

 // Mode == kRead:
diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h
index 8b9c011..737feb6 100644
--- a/ops/matmul-inl.h
+++ b/ops/matmul-inl.h
@@ -770,11 +770,9 @@ class MMState {
   HWY_NOINLINE void DispatchParallelism(const StridedViewBF A,
                                         const MatPtrT& B,
                                         RowPtrs C_rows) const {
-    /* Disabled due to unknown thread-safety issue:
     static const auto zone =
         args_.env->ctx.profiler.AddZone("MM.DispatchParallelism");
     PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone);
-    */

     MMImpl::DispatchParallelism(
         args_.options.parallelism,
@@ -1053,6 +1051,11 @@ template
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT& A, const MatPtrT& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT& C, MMOptions options = MMOptions()) {
+  static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
+  PROFILER_ZONE3(env.ctx.profiler,
+                 options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(),
+                 zone);
+
   const Allocator& allocator = env.ctx.allocator;
   HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
   MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx];
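Note for readers outside the gemma.cpp tree: "hierarchical" parallelism here means the
tensor list is first partitioned across clusters and then fanned out again within each
cluster, instead of handing every task to one flat pool. The following stand-alone
sketch only illustrates the shape of that two-level split using the standard library;
FlatFor/HierarchicalFor and their signatures are hypothetical and are not the
ParallelFor/ParallelismStrategy/ThreadingContext API used in the patch.

// Hypothetical illustration only; not the gemma.cpp API.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <thread>
#include <vector>

// Flat: runs func(task) for every task in [0, num_tasks) on num_workers threads,
// with static strided partitioning of the task range.
void FlatFor(size_t num_tasks, size_t num_workers,
             const std::function<void(uint64_t)>& func) {
  std::vector<std::thread> workers;
  for (size_t w = 0; w < num_workers; ++w) {
    workers.emplace_back([&, w] {
      for (uint64_t task = w; task < num_tasks; task += num_workers) func(task);
    });
  }
  for (std::thread& t : workers) t.join();
}

// Hierarchical: one contiguous task range per cluster, processed by that
// cluster's own workers via a nested flat loop.
void HierarchicalFor(size_t num_tasks, size_t num_clusters,
                     size_t workers_per_cluster,
                     const std::function<void(uint64_t)>& func) {
  std::vector<std::thread> clusters;
  for (size_t c = 0; c < num_clusters; ++c) {
    const uint64_t begin = num_tasks * c / num_clusters;
    const uint64_t end = num_tasks * (c + 1) / num_clusters;
    clusters.emplace_back([&, begin, end] {
      FlatFor(end - begin, workers_per_cluster,
              [&](uint64_t i) { func(begin + i); });
    });
  }
  for (std::thread& t : clusters) t.join();
}

In the actual patch, ParallelFor presumably dispatches onto the existing ThreadingContext
pools rather than spawning threads, and kFlat/kHierarchical select between the two shapes
above; per the added comment, the hierarchical shape is chosen in debug/TSAN builds where
per-task overhead is high.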