mirror of https://github.com/google/gemma.cpp.git
Faster startup under TSan: use hierarchical parallelism for BF16 conversion.
Also re-enable profiler zones.
PiperOrigin-RevId: 804273899
This commit is contained in:
parent cbe24eac51
commit 6e52a835c6
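Background for the first hunk below: under heavily instrumented builds such as ThreadSanitizer, a single tensor conversion is slow enough that flat parallelism (one worker per tensor) leaves much of the pool waiting on a few long-running conversions, whereas hierarchical parallelism adds a second level so the work inside each task can also be split. A minimal standalone sketch of that difference, using only the C++ standard library; FlatFor and HierarchicalFor are illustrative names that spawn a thread per task, not the gemma.cpp ParallelFor/ThreadingContext machinery:

// Flat vs. hierarchical parallel-for, reduced to its essence (illustrative
// only; a real pool reuses worker threads instead of spawning them per call).
#include <cstdint>
#include <cstdio>
#include <functional>
#include <thread>
#include <vector>

// Flat: one level of parallelism, one worker handles one whole task.
static void FlatFor(uint64_t num_tasks,
                    const std::function<void(uint64_t)>& func) {
  std::vector<std::thread> threads;
  threads.reserve(num_tasks);
  for (uint64_t task = 0; task < num_tasks; ++task) {
    threads.emplace_back([&func, task] { func(task); });
  }
  for (std::thread& t : threads) t.join();
}

// Hierarchical: tasks are distributed at the outer level, and the work inside
// each task is split again across `inner_workers` at the inner level.
static void HierarchicalFor(uint64_t num_tasks, uint64_t inner_workers,
                            const std::function<void(uint64_t, uint64_t)>& func) {
  FlatFor(num_tasks, [&](uint64_t task) {
    FlatFor(inner_workers, [&](uint64_t chunk) { func(task, chunk); });
  });
}

int main() {
  // Flat: with very slow per-tensor work (e.g. under TSan), each tensor is
  // still converted by exactly one worker.
  FlatFor(4, [](uint64_t tensor) {
    std::printf("flat: tensor %llu\n", static_cast<unsigned long long>(tensor));
  });
  // Hierarchical: each tensor's conversion is additionally split into chunks,
  // so all workers stay busy even when there are few (or slow) tensors.
  HierarchicalFor(4, 2, [](uint64_t tensor, uint64_t chunk) {
    std::printf("hierarchical: tensor %llu chunk %llu\n",
                static_cast<unsigned long long>(tensor),
                static_cast<unsigned long long>(chunk));
  });
  return 0;
}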
@@ -383,39 +383,45 @@ static void ReadAllToBF16(const std::vector<TensorToRead>& tensors,
                           const BlobReader& reader, ThreadingContext& ctx) {
   static const auto zone =
       ctx.profiler.AddZone("Startup.Weights.ReadAllToBF16");
-  ctx.pools.Pool().Run(0, tensors.size(), [&](uint64_t task, size_t thread) {
-    PROFILER_ZONE3(ctx.profiler, thread, zone);
-    const TensorToRead& tensor = tensors[task];
-    MatPtr& mat = *tensor.mat;
+  // Especially TSAN is slow enough to warrant hierarchical parallelism.
+  const ParallelismStrategy strategy = HWY_IS_DEBUG_BUILD
+                                           ? ParallelismStrategy::kHierarchical
+                                           : ParallelismStrategy::kFlat;
+  ParallelFor(strategy, tensors.size(), ctx, /*cluster_idx=*/0,
+              [&](uint64_t task, size_t thread) {
+                PROFILER_ZONE3(ctx.profiler, thread, zone);
+                const TensorToRead& tensor = tensors[task];
+                MatPtr& mat = *tensor.mat;
 
-    if (tensor.keep_type) {
-      HWY_ASSERT(reader.file().Read(tensor.range.offset, tensor.range.bytes,
-                                    mat.Packed()));
-      return;
-    }
+                if (tensor.keep_type) {
+                  HWY_ASSERT(reader.file().Read(
+                      tensor.range.offset, tensor.range.bytes, mat.Packed()));
+                  return;
+                }
 
-    // Read to a temporary buffer.
-    const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
-        hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
-    HWY_ASSERT(
-        reader.file().Read(tensor.range.offset, tensor.range.bytes, buf.get()));
+                // Read to a temporary buffer.
+                const hwy::AlignedFreeUniquePtr<uint8_t[]> buf =
+                    hwy::AllocateAligned<uint8_t>(tensor.range.bytes);
+                HWY_ASSERT(reader.file().Read(tensor.range.offset,
+                                              tensor.range.bytes, buf.get()));
 
-    if constexpr (GEMMA_ENABLE_NUQ) {
-      if (tensor.prev_type == Type::kNUQ) {
-        return DecompressToBF16<NuqStream>(*tensor.mat, buf);
-      }
-    }
-    switch (tensor.prev_type) {
-      case Type::kF32:
-        return DecompressToBF16<float>(*tensor.mat, buf);
-      case Type::kBF16:
-        return DecompressToBF16<BF16>(*tensor.mat, buf);
-      case Type::kSFP:
-        return DecompressToBF16<SfpStream>(*tensor.mat, buf);
-      default:
-        HWY_ABORT("Unsupported type %s", TypeName(tensor.prev_type));
-    }
-  });
+                if constexpr (GEMMA_ENABLE_NUQ) {
+                  if (tensor.prev_type == Type::kNUQ) {
+                    return DecompressToBF16<NuqStream>(*tensor.mat, buf);
+                  }
+                }
+                switch (tensor.prev_type) {
+                  case Type::kF32:
+                    return DecompressToBF16<float>(*tensor.mat, buf);
+                  case Type::kBF16:
+                    return DecompressToBF16<BF16>(*tensor.mat, buf);
+                  case Type::kSFP:
+                    return DecompressToBF16<SfpStream>(*tensor.mat, buf);
+                  default:
+                    HWY_ABORT("Unsupported type %s",
+                              TypeName(tensor.prev_type));
+                }
+              });
 }
 
 // Mode == kRead:
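The hunk above converts every tensor to BF16, dispatching on its stored type via DecompressToBF16. As a reminder of the target format: BF16 is simply the upper 16 bits of an IEEE-754 binary32 (sign, all 8 exponent bits, top 7 mantissa bits), so a plain float-to-BF16 conversion fits in a few lines. A standalone sketch with round-to-nearest-even; F32ToBF16 and ConvertToBF16 are illustrative names, not the DecompressToBF16 implementation, and NaN inputs are not special-cased:

// Standalone float -> BF16 conversion with round-to-nearest-even.
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

static uint16_t F32ToBF16(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  // Round to nearest, ties to even: bias by 0x7FFF plus the lowest kept bit,
  // then truncate to the upper 16 bits. (NaN is not special-cased here.)
  bits += 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>(bits >> 16);
}

static std::vector<uint16_t> ConvertToBF16(const float* in, size_t num) {
  std::vector<uint16_t> out(num);
  for (size_t i = 0; i < num; ++i) out[i] = F32ToBF16(in[i]);
  return out;
}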
@@ -770,11 +770,9 @@ class MMState {
   HWY_NOINLINE void DispatchParallelism(const StridedViewBF A,
                                         const MatPtrT<TB>& B,
                                         RowPtrs<TC> C_rows) const {
-    /* Disabled due to unknown thread-safety issue:
     static const auto zone =
         args_.env->ctx.profiler.AddZone("MM.DispatchParallelism");
     PROFILER_ZONE3(args_.env->ctx.profiler, MMImpl::Worker(args_), zone);
-    */
 
     MMImpl::DispatchParallelism(
         args_.options.parallelism,
@@ -1053,6 +1051,11 @@ template <typename TA, typename TB, typename TC>
 HWY_NOINLINE MMPerKey* MatMul(const MatPtrT<TA>& A, const MatPtrT<TB>& B,
                               const float* HWY_RESTRICT add, MatMulEnv& env,
                               MatPtrT<TC>& C, MMOptions options = MMOptions()) {
+  static const auto zone = env.ctx.profiler.AddZone("MM.MatMul");
+  PROFILER_ZONE3(env.ctx.profiler,
+                 options.cluster_idx * env.ctx.pools.MaxWorkersPerCluster(),
+                 zone);
+
   const Allocator& allocator = env.ctx.allocator;
   HWY_DASSERT(options.cluster_idx < env.row_ptrs.size());
   MatMulEnv::PerCluster& per_cluster = env.per_cluster[options.cluster_idx];
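The two hunks above follow the same profiling pattern: AddZone registers a named zone once (held in a function-local static), and PROFILER_ZONE3 then opens a scope that charges the elapsed time of the enclosing block to that zone for a given worker. A minimal standalone illustration of the idea; Profiler and ZoneScope here are hypothetical stand-ins for the real profiler API, and per-worker accounting is omitted:

// Illustrative zone-based profiler (hypothetical; not the gemma.cpp profiler).
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

class Profiler {
 public:
  // Registers a named zone once and returns its index.
  size_t AddZone(std::string name) {
    names_.push_back(std::move(name));
    totals_ns_.push_back(0);
    return names_.size() - 1;
  }
  void Accumulate(size_t zone, uint64_t ns) { totals_ns_[zone] += ns; }
  void Print() const {
    for (size_t i = 0; i < names_.size(); ++i) {
      std::printf("%s: %llu ns\n", names_[i].c_str(),
                  static_cast<unsigned long long>(totals_ns_[i]));
    }
  }

 private:
  std::vector<std::string> names_;
  std::vector<uint64_t> totals_ns_;
};

// RAII scope: measures from construction to destruction and credits the zone.
class ZoneScope {
 public:
  ZoneScope(Profiler& profiler, size_t zone)
      : profiler_(profiler), zone_(zone),
        start_(std::chrono::steady_clock::now()) {}
  ~ZoneScope() {
    const auto elapsed = std::chrono::steady_clock::now() - start_;
    profiler_.Accumulate(
        zone_, static_cast<uint64_t>(
                   std::chrono::duration_cast<std::chrono::nanoseconds>(elapsed)
                       .count()));
  }

 private:
  Profiler& profiler_;
  size_t zone_;
  std::chrono::steady_clock::time_point start_;
};

int main() {
  Profiler profiler;
  // The diff keeps the zone index in a function-local static so registration
  // happens only on first use; a plain local suffices for this demo.
  const size_t zone = profiler.AddZone("MM.MatMul");
  {
    ZoneScope scope(profiler, zone);  // Times this block.
    volatile double sink = 0.0;
    for (int i = 0; i < 1000000; ++i) sink = sink + i * 0.5;
  }
  profiler.Print();
  return 0;
}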