mirror of https://github.com/google/gemma.cpp.git
Faster startup on tsan: use hierarchical parallelism for BF16 conversion
Also re-enable profiler zones
PiperOrigin-RevId: 804273899
This commit is contained in:
parent cbe24eac51
commit 6e52a835c6
@@ -383,39 +383,45 @@ static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                           const BlobReader& reader, ThreadingContext& ctx) {
   static const auto zone =
       ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
-  ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
-    PROFILER_ZONE3(ctx.profiler, thread, zone);
-    const TensorToRead& tensor = tensors[task];
-    MatPtr& mat = *tensor.mat;
-
-    if (tensor.keep_type) {
-      HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
-                                    mat.Packed()));
-      return;
-    }
-
-    // Read to a temporary buffer.
-    const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
-        hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
-    HWY_ASSERT(
-        reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
-
-    if constexpr (GEMMA_ENABLE_NUQ) {
-      if (tensor.prev_type == Type::kNUQ) {
-        return DecompressToBF16<NuqStream>(*tensor.mat, buf);
-      }
-    }
-    switch (tensor.prev_type) {
-      case Type::kF32:
-        return DecompressToBF16<float>(*tensor.mat, buf);
-      case Type::kBF16:
-        return DecompressToBF16<BF16>(*tensor.mat, buf);
-      case Type::kSFP:
-        return DecompressToBF16<SfpStream>(*tensor.mat, buf);
-      default:
-        HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
-    }
-  });
+  // Especially TSAN is slow enough to warrant hierarchical parallelism.
+  const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
+                                           ? ParallelismStrategy::kHierarchical
+                                           : ParallelismStrategy::kFlat;
+  ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
+              [&](uint64_t task, size_t thread) {
+                PROFILER_ZONE3(ctx.profiler, thread, zone);
+                const TensorToRead& tensor = tensors[task];
+                MatPtr& mat = *tensor.mat;
+
+                if (tensor.keep_type) {
+                  HWY_ASSERT(reader.file().Read(
+                      tensor.range.offset, tensor.range.bytes, mat.Packed()));
+                  return;
+                }
+
+                // Read to a temporary buffer.
+                const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
+                    hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
+                HWY_ASSERT(reader.file().Read(tensor.range.offset,
+                                              tensor.range.bytes, buf.get()));
+
+                if constexpr (GEMMA_ENABLE_NUQ) {
+                  if (tensor.prev_type == Type::kNUQ) {
+                    return DecompressToBF16<NuqStream>(*tensor.mat, buf);
+                  }
+                }
+                switch (tensor.prev_type) {
+                  case Type::kF32:
+                    return DecompressToBF16<float>(*tensor.mat, buf);
+                  case Type::kBF16:
+                    return DecompressToBF16<BF16>(*tensor.mat, buf);
+                  case Type::kSFP:
+                    return DecompressToBF16<SfpStream>(*tensor.mat, buf);
+                  default:
+                    HWY_ABORT("Unsupported type %s",
+                              TypeName(tensor.prev_type));
+                }
+              });
 }
 
 // Mode == kRead:
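The first hunk is the headline change: the flat ctx.pools.Pool().Run over tensors becomes a ParallelFor whose strategy is hierarchical in debug/TSAN builds, so each tensor's BF16 conversion can itself fan out across workers instead of being serialized on one slow, instrumented thread. Below is a minimal, self-contained C++ sketch of that idea; it deliberately uses plain std::thread and invented names (Strategy, ConvertTensor, ConvertChunk) rather than gemma.cpp's ParallelFor, ParallelismStrategy, or ThreadingContext, so treat it as an illustration of flat vs. hierarchical parallelism, not the library's API.

// Hypothetical illustration only (std::thread instead of gemma.cpp's pools):
// "flat" = one worker per tensor; "hierarchical" = each tensor's conversion
// is split again across inner threads.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <functional>
#include <thread>
#include <vector>

enum class Strategy { kFlat, kHierarchical };

// Stand-in for per-tensor work: truncate float to BF16 (keep the upper 16 bits).
void ConvertChunk(const std::vector<float>& in, std::vector<uint16_t>& out,
                  size_t begin, size_t end) {
  for (size_t i = begin; i < end; ++i) {
    uint32_t bits;
    std::memcpy(&bits, &in[i], sizeof(bits));
    out[i] = static_cast<uint16_t>(bits >> 16);
  }
}

// One outer task (one tensor). Under kHierarchical it parallelizes again.
void ConvertTensor(const std::vector<float>& in, std::vector<uint16_t>& out,
                   Strategy strategy, size_t inner_threads) {
  if (strategy == Strategy::kFlat || inner_threads <= 1) {
    ConvertChunk(in, out, 0, in.size());
    return;
  }
  std::vector<std::thread> inner;
  const size_t step = (in.size() + inner_threads - 1) / inner_threads;
  for (size_t t = 0; t < inner_threads; ++t) {
    const size_t begin = t * step;
    const size_t end = std::min(in.size(), begin + step);
    if (begin >= end) break;
    inner.emplace_back(ConvertChunk, std::cref(in), std::ref(out), begin, end);
  }
  for (std::thread& th : inner) th.join();
}

int main() {
  const size_t kTensors = 4, kElems = 1 << 20;
  std::vector<std::vector<float>> in(kTensors, std::vector<float>(kElems, 1.5f));
  std::vector<std::vector<uint16_t>> out(kTensors,
                                         std::vector<uint16_t>(kElems));
  // Outer level: one thread per tensor; each may fan out again (hierarchical).
  std::vector<std::thread> outer;
  for (size_t i = 0; i < kTensors; ++i) {
    outer.emplace_back(ConvertTensor, std::cref(in[i]), std::ref(out[i]),
                       Strategy::kHierarchical, /*inner_threads=*/2);
  }
  for (std::thread& th : outer) th.join();
  std::printf("out[0][0] = 0x%04x\n", static_cast<unsigned>(out[0][0]));  // 1.5f -> 0x3FC0
  return 0;
}

Under normal builds the flat outer loop over tensors already keeps the machine busy; the hierarchical split only pays off when each task is unusually slow, which is exactly the TSAN case the commit message calls out.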
@@ -770,11 +770,9 @@ class MMState {
   HWY_NOINLINE void DispatchParallelism(const StridedViewBF A,
                                         const MatPtrT<TB>& B,
                                         RowPtrs<TC> C_rows) const {
-    /* Disabled due to unknown thread-safety issue:
     static const auto zone =
         args_.env->ctx.profiler.AddZone("MM.DispatchParallelism");
     PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone);
-    */
 
     MMImpl::DispatchParallelism(
         args_.options.parallelism,
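The second hunk removes the "/* Disabled due to unknown thread-safety issue: ... */" wrapper, re-enabling the DispatchParallelism zone, and the third hunk (below) adds an analogous zone to MatMul itself. Both follow the same pattern: register the zone once through a function-local static, then open a scoped zone on every call so elapsed time is attributed to it. The following toy sketch mirrors that shape only; Profiler, ScopedZone, and MatMulLike are invented stand-ins, not hwy::Profiler or the real PROFILER_ZONE3 macro.

// Hypothetical sketch of the "static zone + scoped zone" pattern.
#include <chrono>
#include <cstddef>
#include <cstdio>
#include <string>
#include <vector>

class Profiler {
 public:
  std::size_t AddZone(std::string name) {
    names_.push_back(std::move(name));
    totals_ns_.push_back(0);
    return names_.size() - 1;
  }
  void Add(std::size_t zone, long long ns) { totals_ns_[zone] += ns; }
  void Print() const {
    for (std::size_t i = 0; i < names_.size(); ++i) {
      std::printf("%-24s %lld ns\n", names_[i].c_str(), totals_ns_[i]);
    }
  }

 private:
  std::vector<std::string> names_;
  std::vector<long long> totals_ns_;
};

// RAII guard: charges the elapsed time of the enclosing scope to one zone.
class ScopedZone {
 public:
  ScopedZone(Profiler& profiler, std::size_t zone)
      : profiler_(profiler), zone_(zone),
        start_(std::chrono::steady_clock::now()) {}
  ~ScopedZone() {
    const long long ns = std::chrono::duration_cast<std::chrono::nanoseconds>(
                             std::chrono::steady_clock::now() - start_)
                             .count();
    profiler_.Add(zone_, ns);
  }

 private:
  Profiler& profiler_;
  std::size_t zone_;
  std::chrono::steady_clock::time_point start_;
};

// Mirrors the shape of the instrumented functions: register once, time every call.
void MatMulLike(Profiler& profiler) {
  static const std::size_t zone = profiler.AddZone("MM.MatMul");  // once per process
  ScopedZone scoped(profiler, zone);  // analogous to PROFILER_ZONE3(...)
  volatile double sink = 0.0;
  for (int i = 0; i < 1000000; ++i) sink = sink + i * 0.5;  // placeholder work
  (void)sink;
}

int main() {
  Profiler profiler;
  for (int i = 0; i < 3; ++i) MatMulLike(profiler);
  profiler.Print();
  return 0;
}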
@@ -1053,6 +1051,11 @@ template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT<TC>& C, MMOptions options = MMOptions()) {
+  static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
+  PROFILER_ZONE3(env.ctx.profiler,
+                 options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(),
+                 zone);
+
   const Allocator& allocator = env.ctx.allocator;
   HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
   MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx];
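The worker argument of the new MatMul zone, options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(), presumably maps the issuing cluster to its first global worker index. A worked example with made-up numbers (the real values come from ThreadingContext at runtime):

// Hypothetical numbers for illustration only.
#include <cstddef>

constexpr std::size_t kMaxWorkersPerCluster = 8;  // stand-in for ctx.pools.MaxWorkersPerCluster()
constexpr std::size_t kClusterIdx = 2;            // stand-in for options.cluster_idx
constexpr std::size_t kWorker = kClusterIdx * kMaxWorkersPerCluster;
static_assert(kWorker == 16, "cluster 2 charges its zone to global worker 16");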