Expand ThreadingContext comments

PiperOrigin-RevId: 800479954
This commit is contained in:
Jan Wassenberg 2025-08-28 08:31:25 -07:00 committed by Copybara-Service
parent 6128e758ff
commit 98ddc166db
3 changed files with 33 additions and 9 deletions

View File

@@ -227,12 +227,12 @@ struct TimingInfo {
}; };
// After construction, all methods are const and thread-compatible if using // After construction, all methods are const and thread-compatible if using
// separate ThreadingContext for each thread. // separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
class Gemma { class Gemma {
public: public:
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`. // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
// `ctx` is only used to read tensors, but it is typically also referenced // `ctx` is only used to read tensors and not stored. Calls to `Generate*`
// by the `MatMulEnv` passed to the Generate* methods. // may reference the same, or other `ThreadingContext` via `MatMulEnv`.
Gemma(const LoaderArgs& loader, const InferenceArgs& inference, Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
ThreadingContext& ctx); ThreadingContext& ctx);
~Gemma(); ~Gemma();
@@ -248,6 +248,8 @@ class Gemma {
// `pos` is the position in the KV cache. Users are responsible for // `pos` is the position in the KV cache. Users are responsible for
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn. // incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
// All `Generate*` may be called concurrently if `env` and the
// `ThreadingContext` it references are both distinct.
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t pos, KVCache& kv_cache, MatMulEnv& env, size_t pos, KVCache& kv_cache, MatMulEnv& env,
TimingInfo& timing_info) const { TimingInfo& timing_info) const {

View File

@@ -68,7 +68,9 @@ class NestedPools {
// `max_threads` is the maximum number of threads to divide among all // `max_threads` is the maximum number of threads to divide among all
// clusters. This is more intuitive than a per-cluster limit for users who // clusters. This is more intuitive than a per-cluster limit for users who
// may not be aware of the CPU topology. 0 means no limit. // may not be aware of the CPU topology. This should be zero (meaning no
// further limits) if the caller has already set limits via `skip_*` or
// `max_*` args passed to `ThreadingContext`.
// //
// To ensure we do not create more threads than there are HW cores, which // To ensure we do not create more threads than there are HW cores, which
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments // would cause huge slowdowns when spinning, the `BoundedSlice` arguments

View File

@@ -28,6 +28,7 @@
#include "util/basics.h" // Tristate, kMaxPackages #include "util/basics.h" // Tristate, kMaxPackages
#include "util/threading.h" #include "util/threading.h"
#include "util/topology.h" #include "util/topology.h"
#include "hwy/profiler.h"
// IWYU pragma: end_exports // IWYU pragma: end_exports
namespace gcpp { namespace gcpp {
@@ -55,9 +56,10 @@ class ThreadingArgs : public ArgsBase&lt;ThreadingArgs&gt; {
template <class Visitor> template <class Visitor>
void ForEach(const Visitor& visitor) { void ForEach(const Visitor& visitor) {
// These can be used to partition CPU sockets/packages and their // These can be used to partition CPU packages/sockets and their
// clusters/CCXs across several program instances. The default is to use // clusters/CCXs across several program instances. The default is to use
// all available resources. // all available resources on one package. Note that `kMaxPackages` is an
// upper bound on `max_packages`.
visitor(skip_packages, "skip_packages", size_t{0}, visitor(skip_packages, "skip_packages", size_t{0},
"Index of the first socket to use; default 0 = unlimited.", 2); "Index of the first socket to use; default 0 = unlimited.", 2);
visitor(max_packages, "max_packages", size_t{1}, visitor(max_packages, "max_packages", size_t{1},
@@ -67,15 +69,18 @@ class ThreadingArgs : public ArgsBase&lt;ThreadingArgs&gt; {
"Index of the first CCX to use; default 0 = unlimited.", 2); "Index of the first CCX to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0}, visitor(max_clusters, "max_clusters", size_t{0},
"Max CCXs to use; default 0 = unlimited.", 2); "Max CCXs to use; default 0 = unlimited.", 2);
// These are only used when CPU topology is unknown. // "Logical processors" (LPs). These are used when CPU topology is unknown.
visitor(skip_lps, "skip_lps", size_t{0}, visitor(skip_lps, "skip_lps", size_t{0},
"Index of the first LP to use; default 0 = unlimited.", 2); "Index of the first LP to use; default 0 = unlimited.", 2);
visitor(max_lps, "max_lps", size_t{0}, visitor(max_lps, "max_lps", size_t{0},
"Max LPs to use; default 0 = unlimited.", 2); "Max LPs to use; default 0 = unlimited.", 2);
// The exact meaning is more subtle: see the comment at NestedPools ctor. // DEPRECATED: superseded by the above fields. If nonzero, `NestedPools`
// will attempt to create this many threads distributed over the detected
// topology.
visitor(max_threads, "num_threads", size_t{0}, visitor(max_threads, "num_threads", size_t{0},
"Max threads to use; default 0 = unlimited.", 2); "Max threads to use; default 0 = unlimited.", 2);
visitor(pin, "pin", Tristate::kDefault, visitor(pin, "pin", Tristate::kDefault,
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(spin, "spin", Tristate::kDefault, visitor(spin, "spin", Tristate::kDefault,
@@ -86,13 +91,28 @@ class ThreadingArgs : public ArgsBase&lt;ThreadingArgs&gt; {
} }
}; };
// Owns threads corresponding to a subset of the system's resources. Because
// this is passed to `Gemma::Generate` (via `MatMulEnv`) rather than defined as
// a singleton, we can have multiple concurrent `Generate` calls within the
// same process, each with their own `ThreadingContext`. Because each context
// may pin its threads, it is important that they use distinct packages,
// clusters, or LPs. For example, to use two packages, the first `args` can have
// `skip_packages` = 0 and the second `skip_packages` = 1.
struct ThreadingContext { struct ThreadingContext {
// Expected to be called early in the program, before threading starts.
explicit ThreadingContext(const ThreadingArgs& args); explicit ThreadingContext(const ThreadingArgs& args);
// Singleton; pass around a reference to reduce overhead.
hwy::Profiler& profiler; hwy::Profiler& profiler;
// Detects topology, subject to limits imposed by user-specified `args`.
// For example, if `args.max_packages` is 1, then `topology.NumPackages()`
// will be 1 regardless of the actual system topology.
BoundedTopology topology; BoundedTopology topology;
// Ctor depends on `topology` for deciding whether to enable NUMA.
Allocator allocator; Allocator allocator;
// Per-package/cluster/within cluster pools of threads, matching `topology`.
NestedPools pools; NestedPools pools;
}; };