// Mirror of https://github.com/google/gemma.cpp.git (110 lines, 3.8 KiB, C++).
// Copyright 2025 Google LLC
|
|
// SPDX-License-Identifier: Apache-2.0
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// https://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
#include "util/threading_context.h"
|
|
|
|
#include <stddef.h>
|
|
#include <stdint.h>
|
|
|
|
#include <vector>
|
|
|
|
#include "hwy/aligned_allocator.h"
|
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
|
#include "hwy/profiler.h"
|
|
#include "hwy/tests/test_util.h" // RandomState
|
|
|
|
namespace gcpp {
|
|
|
|
// Invokes `pool.Run` with varying task counts until auto-tuning completes, or
|
|
// an upper bound just in case.
|
|
static void TunePool(hwy::PoolWaitMode wait_mode, hwy::ThreadPool& pool) {
|
|
pool.SetWaitMode(wait_mode);
|
|
|
|
// TODO(janwas): re-enable after investigating potential deadlock.
|
|
#if 0
|
|
const size_t num_workers = pool.NumWorkers();
|
|
// pool.Run would just be a serial loop without auto-tuning, so skip.
|
|
if (num_workers == 1) return;
|
|
|
|
// Random shuffle of task counts to defeat branch prediction.
|
|
const size_t num_tasks[4] = {HWY_MAX(1, num_workers / 2), num_workers * 1,
|
|
num_workers * 5, num_workers * 20};
|
|
|
|
// Count tasks executed to ensure workers aren't optimized out. One per
|
|
// cache line to avoid false sharing.
|
|
const size_t kSizePerLine = HWY_ALIGNMENT / sizeof(size_t);
|
|
|
|
std::vector<size_t> counters(num_workers * kSizePerLine);
|
|
size_t prev_total = 0; // avoids having to reset counters.
|
|
|
|
hwy::RandomState rng;
|
|
for (size_t rep = 0; rep < 500; ++rep) {
|
|
if (HWY_UNLIKELY(pool.AutoTuneComplete())) {
|
|
break;
|
|
}
|
|
|
|
const uint64_t r = hwy::Random64(&rng);
|
|
const size_t begin = r >> 2;
|
|
const size_t end = begin + num_tasks[r & 3];
|
|
|
|
pool.Run(begin, end, [&](uint64_t task, size_t thread) {
|
|
HWY_ASSERT(begin <= task && task < end);
|
|
HWY_ASSERT(thread < num_workers);
|
|
counters[thread * kSizePerLine]++;
|
|
});
|
|
|
|
// Reduce count and ensure it matches the expected number of tasks.
|
|
size_t total = 0;
|
|
for (size_t i = 0; i < num_workers; ++i) {
|
|
total += counters[i * kSizePerLine];
|
|
}
|
|
const size_t expected = end - begin;
|
|
HWY_ASSERT(total == prev_total + expected);
|
|
prev_total += expected;
|
|
}
|
|
#endif
|
|
}
|
|
|
|
// Applies `wait_mode` (via TunePool) to every pool in the hierarchy. It starts
// with the package-level pool. Then, for each package, it handles that
// package's pool of clusters, followed by each of the per-cluster pools.
static void TunePools(hwy::PoolWaitMode wait_mode, NestedPools& pools) {
  TunePool(wait_mode, pools.AllPackages());

  for (size_t package = 0; package < pools.NumPackages(); ++package) {
    hwy::ThreadPool& cluster_pool = pools.AllClusters(package);
    TunePool(wait_mode, cluster_pool);

    // Each worker of `cluster_pool` handles one per-cluster pool, so they are
    // tuned concurrently. This matters because Turin CPUs have 16 clusters,
    // and in real usage they often all run at the same time.
    const auto tune_one_cluster = [&](uint64_t cluster, size_t /*thread*/) {
      TunePool(wait_mode, pools.Cluster(package, cluster));
    };
    cluster_pool.Run(0, cluster_pool.NumWorkers(), tune_one_cluster);
  }
}
|
|
|
|
// Builds, in initializer-list order:
// - the topology, restricted by the skip/max package, cluster and LP slices
//   in `args`;
// - cache info;
// - the allocator, with binding enabled unless `args.bind` is kFalse;
// - the nested thread pools.
// It then applies each wait mode to all pools via TunePools.
ThreadingContext::ThreadingContext(const ThreadingArgs& args)
    : profiler(hwy::Profiler::Get()),
      topology(BoundedSlice(args.skip_packages, args.max_packages),
               BoundedSlice(args.skip_clusters, args.max_clusters),
               BoundedSlice(args.skip_lps, args.max_lps)),
      cache_info(topology),
      allocator(topology, cache_info, args.bind != Tristate::kFalse),
      pools(topology, allocator, args.max_threads, args.pin) {
  PROFILER_ZONE("Startup.ThreadingContext autotune");

  // The order matters: the last mode applied is the one the pools keep. It
  // must be kBlock because that is the default.
  for (const hwy::PoolWaitMode wait_mode :
       {hwy::PoolWaitMode::kSpin, hwy::PoolWaitMode::kBlock}) {
    TunePools(wait_mode, pools);
  }
}
|
|
|
|
} // namespace gcpp
|