Faster startup on TSAN: use hierarchical parallelism for BF16 conversion

Also re-enable profiler zones

PiperOrigin-RevId: 804273899
Jan Wassenberg 2025-09-07 22:50:01 -07:00 committed by Copybara-Service
parent cbe24eac51
commit 6e52a835c6
2 changed files with 41 additions and 32 deletions
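
Background for the title: a "flat" parallel-for hands every task to a single thread pool, whereas a "hierarchical" one first splits the task range across CPU clusters and then across each cluster's workers, which pays off when each task is slow, as under TSAN (see the in-diff comment below). A minimal, self-contained sketch of the two-level idea; this is not gemma.cpp's actual ParallelFor, and all names here are invented:

// Toy two-level parallel-for: outer level = clusters, inner level = workers.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <thread>
#include <vector>

template <class Func>
void HierarchicalFor(uint64_t num_tasks, size_t clusters,
                     size_t workers_per_cluster, const Func& func) {
  const uint64_t chunk = (num_tasks + clusters - 1) / clusters;
  std::vector<std::thread> outer;
  for (size_t c = 0; c < clusters; ++c) {
    // Each cluster receives a contiguous chunk of the task range.
    const uint64_t begin = std::min<uint64_t>(c * chunk, num_tasks);
    const uint64_t end = std::min<uint64_t>(begin + chunk, num_tasks);
    outer.emplace_back([=] {
      // Inner level: the cluster's workers share that cluster's chunk.
      std::vector<std::thread> inner;
      for (size_t w = 0; w < workers_per_cluster; ++w) {
        inner.emplace_back([=] {
          for (uint64_t i = begin + w; i < end; i += workers_per_cluster) {
            func(i, c * workers_per_cluster + w);  // task, global worker index
          }
        });
      }
      for (auto& t : inner) t.join();
    });
  }
  for (auto& t : outer) t.join();
}

int main() {
  HierarchicalFor(/*num_tasks=*/10, /*clusters=*/2, /*workers_per_cluster=*/2,
                  [](uint64_t task, size_t worker) {
                    std::printf("task %2llu -> worker %zu\n",
                                static_cast<unsigned long long>(task), worker);
                  });
  return 0;
}

The flat strategy corresponds to a single level of splitting; the diff below selects kHierarchical in debug/sanitizer builds (HWY_IS_DEBUG_BUILD) because each conversion task runs much slower there.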


@@ -383,22 +383,27 @@ static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                             const BlobReader& reader, ThreadingContext& ctx) {
   static const auto zone =
       ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
-  ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
+  // Especially TSAN is slow enough to warrant hierarchical parallelism.
+  const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
+                                           ? ParallelismStrategy::kHierarchical
+                                           : ParallelismStrategy::kFlat;
+  ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
+              [&](uint64_t task, size_t thread) {
     PROFILER_ZONE3(ctx.profiler, thread, zone);
     const TensorToRead& tensor = tensors[task];
     MatPtr& mat = *tensor.mat;
     if (tensor.keep_type) {
-      HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
-                                    mat.Packed()));
+      HWY_ASSERT(reader.file().Read(
+          tensor.range.offset, tensor.range.bytes, mat.Packed()));
       return;
     }
     // Read to a temporary buffer.
     const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
         hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
-    HWY_ASSERT(
-        reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
+    HWY_ASSERT(reader.file().Read(tensor.range.offset,
+                                  tensor.range.bytes, buf.get()));
     if constexpr (GEMMA_ENABLE_NUQ) {
       if (tensor.prev_type == Type::kNUQ) {
@@ -413,7 +418,8 @@ static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
       case Type::kSFP:
         return DecompressToBF16<SfpStream>(*tensor.mat, buf);
       default:
-        HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
+        HWY_ABORT("Unsupported type %s",
+                  TypeName(tensor.prev_type));
     }
   });
 }
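
For background on what the BF16 conversion above amounts to: bfloat16 keeps the top 16 bits of an IEEE-754 float (1 sign, 8 exponent, 7 mantissa bits), so scalar conversion is a shift plus rounding. A hedged scalar sketch with round-to-nearest-even; it is not the repository's vectorized DecompressToBF16, and NaN handling is omitted:

// Scalar float -> bfloat16 with round-to-nearest-even (NaN not handled).
#include <cstdint>
#include <cstdio>
#include <cstring>

uint16_t FloatToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Bias by half of the dropped range, plus the lowest kept bit for
  // round-to-nearest-even, then keep the upper 16 bits.
  const uint32_t rounding = 0x7FFF + ((bits >> 16) & 1);
  return static_cast<uint16_t>((bits + rounding) >> 16);
}

int main() {
  // 1.0f is 0x3F800000, so its bfloat16 encoding is 0x3F80.
  std::printf("bf16(1.0f) = 0x%04X\n",
              static_cast<unsigned>(FloatToBF16(1.0f)));
  return 0;
}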


@@ -770,11 +770,9 @@ class MMState {
   HWY_NOINLINE void DispatchParallelism(const StridedViewBF A,
                                         const MatPtrT<TB>& B,
                                         RowPtrs<TC> C_rows) const {
-    /* Disabled due to unknown thread-safety issue:
     static const auto zone =
         args_.env->ctx.profiler.AddZone("MM.DispatchParallelism");
     PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone);
-    */
     MMImpl::DispatchParallelism(
         args_.options.parallelism,
@@ -1053,6 +1051,11 @@ template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT<TC>& C, MMOptions options = MMOptions()) {
+  static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
+  PROFILER_ZONE3(env.ctx.profiler,
+                 options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(),
+                 zone);
   const Allocator& allocator = env.ctx.allocator;
   HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
   MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx];
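
Both hunks above use the same zone pattern: profiler.AddZone registers a named zone once via a function-local static, and PROFILER_ZONE3 then opens a scoped timer that charges the enclosing scope to that zone for a given worker index (in MatMul, options.cluster_idx * MaxWorkersPerCluster(), presumably the first worker of the chosen cluster). A toy stand-in showing the shape of the pattern; the real profiler lives in highway and the types below are invented:

#include <chrono>
#include <cstdio>

// Toy profiler: AddZone hands out a handle for a named zone.
struct Profiler {
  int AddZone(const char* name) {
    std::printf("registered zone: %s\n", name);
    return 0;  // toy zone handle
  }
};

// Plays the role of PROFILER_ZONE3: a scoped timer attributed to a worker.
class ScopedZone {
 public:
  ScopedZone(Profiler&, size_t worker, int /*zone*/)
      : worker_(worker), start_(std::chrono::steady_clock::now()) {}
  ~ScopedZone() {
    const long long ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                             std::chrono::steady_clock::now() - start_)
                             .count();
    std::printf("worker %zu spent %lld ns in zone\n", worker_, ns);
  }

 private:
  size_t worker_;
  std::chrono::steady_clock::time_point start_;
};

Profiler g_profiler;

void SomeHotFunction(size_t worker) {
  // Registered exactly once, even when called from many threads: C++11
  // guarantees thread-safe initialization of function-local statics.
  static const int zone = g_profiler.AddZone("Example.SomeHotFunction");
  ScopedZone scoped(g_profiler, worker, zone);
  // ... work whose wall time is attributed to the zone ...
}

int main() {
  SomeHotFunction(/*worker=*/0);
  return 0;
}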