diff --git a/gemma/gemma.h b/gemma/gemma.h index 55be003..0f9aae2 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -227,12 +227,12 @@ struct TimingInfo { }; // After construction, all methods are const and thread-compatible if using -// separate ThreadingContext for each thread. +// separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`. class Gemma { public: // Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`. - // `ctx` is only used to read tensors, but it is typically also referenced - // by the `MatMulEnv` passed to the Generate* methods. + // `ctx` is only used to read tensors and not stored. Calls to `Generate*` + // may reference the same or another `ThreadingContext` via `MatMulEnv`. Gemma(const LoaderArgs& loader, const InferenceArgs& inference, ThreadingContext& ctx); ~Gemma(); @@ -248,6 +248,8 @@ class Gemma { // `pos` is the position in the KV cache. Users are responsible for // incrementing it in the `*StreamFunc`, or setting to zero for single-turn. + // All `Generate*` may be called concurrently if each call uses a distinct + // `env` and a distinct `ThreadingContext`. void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, KVCache& kv_cache, MatMulEnv& env, TimingInfo& timing_info) const { diff --git a/util/threading.h b/util/threading.h index 6c2e187..0a57ddb 100644 --- a/util/threading.h +++ b/util/threading.h @@ -68,7 +68,9 @@ class NestedPools { // `max_threads` is the maximum number of threads to divide among all // clusters. This is more intuitive than a per-cluster limit for users who - // may not be aware of the CPU topology. 0 means no limit. + // may not be aware of the CPU topology. This should be zero (meaning no + // further limits) if the caller has already set limits via `skip_*` or + // `max_*` args passed to `ThreadingContext`. 
// // To ensure we do not create more threads than there are HW cores, which // would cause huge slowdowns when spinning, the `BoundedSlice` arguments diff --git a/util/threading_context.h b/util/threading_context.h index 08387d0..d4fdc17 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -28,6 +28,7 @@ #include "util/basics.h" // Tristate, kMaxPackages #include "util/threading.h" #include "util/topology.h" +#include "hwy/profiler.h" // IWYU pragma: end_exports namespace gcpp { @@ -55,9 +56,10 @@ class ThreadingArgs : public ArgsBase { template void ForEach(const Visitor& visitor) { - // These can be used to partition CPU sockets/packages and their + // These can be used to partition CPU packages/sockets and their // clusters/CCXs across several program instances. The default is to use - // all available resources. + // all available resources on one package. Note that `kMaxPackages` is an + // upper bound on `max_packages`. visitor(skip_packages, "skip_packages", size_t{0}, "Index of the first socket to use; default 0 = unlimited.", 2); visitor(max_packages, "max_packages", size_t{1}, @@ -67,15 +69,18 @@ class ThreadingArgs : public ArgsBase { "Index of the first CCX to use; default 0 = unlimited.", 2); visitor(max_clusters, "max_clusters", size_t{0}, "Max CCXs to use; default 0 = unlimited.", 2); - // These are only used when CPU topology is unknown. + // "Logical processors" (LPs). These are used when CPU topology is unknown. visitor(skip_lps, "skip_lps", size_t{0}, "Index of the first LP to use; default 0 = unlimited.", 2); visitor(max_lps, "max_lps", size_t{0}, "Max LPs to use; default 0 = unlimited.", 2); - // The exact meaning is more subtle: see the comment at NestedPools ctor. + // DEPRECATED: superseded by the above fields. If nonzero, `NestedPools` + // will attempt to create this many threads distributed over the detected + // topology. 
visitor(max_threads, "num_threads", size_t{0}, "Max threads to use; default 0 = unlimited.", 2); + visitor(pin, "pin", Tristate::kDefault, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); visitor(spin, "spin", Tristate::kDefault, @@ -86,13 +91,28 @@ class ThreadingArgs : public ArgsBase { } }; +// Owns threads corresponding to a subset of the system's resources. Because +// this is passed to `Gemma::Generate` (via `MatMulEnv`) rather than defined as +// a singleton, we can have multiple concurrent `Generate` calls within the +// same process, each with their own `ThreadingContext`. Because each context +// may pin its threads, it is important that they use distinct packages, +// clusters, or LPs. For example, to use two packages, the first `args` can have +// `skip_packages` = 0 and the second `skip_packages` = 1. struct ThreadingContext { - // Expected to be called early in the program, before threading starts. explicit ThreadingContext(const ThreadingArgs& args); + // Singleton; pass around a reference to reduce overhead. hwy::Profiler& profiler; + + // Detects topology, subject to limits imposed by user-specified `args`. + // For example, if `args.max_packages` is 1, then `topology.NumPackages()` + // will be 1 regardless of the actual system topology. BoundedTopology topology; + + // Ctor depends on `topology` for deciding whether to enable NUMA. Allocator allocator; + + // Per-package/cluster/within cluster pools of threads, matching `topology`. NestedPools pools; };