mirror of https://github.com/google/gemma.cpp.git
Expand ThreadingContext comments
PiperOrigin-RevId: 800479954
This commit is contained in:
parent
6128e758ff
commit
98ddc166db
|
|
@ -227,12 +227,12 @@ struct TimingInfo {
|
|||
};
|
||||
|
||||
// After construction, all methods are const and thread-compatible if using
|
||||
// separate ThreadingContext for each thread.
|
||||
// separate `ThreadingContext` and `MatMulEnv` for each concurrent `Generate`.
|
||||
class Gemma {
|
||||
public:
|
||||
// Reads weights/config/tokenizer from the `BlobStore` at `loader.weights`.
|
||||
// `ctx` is only used to read tensors, but it is typically also referenced
|
||||
// by the `MatMulEnv` passed to the Generate* methods.
|
||||
// `ctx` is only used to read tensors and not stored. Calls to `Generate*`
|
||||
// may reference the same, or other `ThreadingContext` via `MatMulEnv`.
|
||||
Gemma(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
ThreadingContext& ctx);
|
||||
~Gemma();
|
||||
|
|
@ -248,6 +248,8 @@ class Gemma {
|
|||
|
||||
// `pos` is the position in the KV cache. Users are responsible for
|
||||
// incrementing it in the `*StreamFunc`, or setting to zero for single-turn.
|
||||
// All `Generate*` may be called concurrently if `env` and the
|
||||
// `ThreadingContext` it references are both distinct.
|
||||
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
|
||||
size_t pos, KVCache& kv_cache, MatMulEnv& env,
|
||||
TimingInfo& timing_info) const {
|
||||
|
|
|
|||
|
|
@ -68,7 +68,9 @@ class NestedPools {
|
|||
|
||||
// `max_threads` is the maximum number of threads to divide among all
|
||||
// clusters. This is more intuitive than a per-cluster limit for users who
|
||||
// may not be aware of the CPU topology. 0 means no limit.
|
||||
// may not be aware of the CPU topology. This should be zero (meaning no
|
||||
// further limits) if the caller has already set limits via `skip_*` or
|
||||
// `max_*` args passed to `ThreadingContext`.
|
||||
//
|
||||
// To ensure we do not create more threads than there are HW cores, which
|
||||
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments
|
||||
|
|
|
|||
|
|
@ -28,6 +28,7 @@
|
|||
#include "util/basics.h" // Tristate, kMaxPackages
|
||||
#include "util/threading.h"
|
||||
#include "util/topology.h"
|
||||
#include "hwy/profiler.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -55,9 +56,10 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
|
|||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
// These can be used to partition CPU sockets/packages and their
|
||||
// These can be used to partition CPU packages/sockets and their
|
||||
// clusters/CCXs across several program instances. The default is to use
|
||||
// all available resources.
|
||||
// all available resources on one package. Note that `kMaxPackages` is an
|
||||
// upper bound on `max_packages`.
|
||||
visitor(skip_packages, "skip_packages", size_t{0},
|
||||
"Index of the first socket to use; default 0 = unlimited.", 2);
|
||||
visitor(max_packages, "max_packages", size_t{1},
|
||||
|
|
@ -67,15 +69,18 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
|
|||
"Index of the first CCX to use; default 0 = unlimited.", 2);
|
||||
visitor(max_clusters, "max_clusters", size_t{0},
|
||||
"Max CCXs to use; default 0 = unlimited.", 2);
|
||||
// These are only used when CPU topology is unknown.
|
||||
// "Logical processors" (LPs). These are used when CPU topology is unknown.
|
||||
visitor(skip_lps, "skip_lps", size_t{0},
|
||||
"Index of the first LP to use; default 0 = unlimited.", 2);
|
||||
visitor(max_lps, "max_lps", size_t{0},
|
||||
"Max LPs to use; default 0 = unlimited.", 2);
|
||||
|
||||
// The exact meaning is more subtle: see the comment at NestedPools ctor.
|
||||
// DEPRECATED: superseded by the above fields. If nonzero, `NestedPools`
|
||||
// will attempt to create this many threads distributed over the detected
|
||||
// topology.
|
||||
visitor(max_threads, "num_threads", size_t{0},
|
||||
"Max threads to use; default 0 = unlimited.", 2);
|
||||
|
||||
visitor(pin, "pin", Tristate::kDefault,
|
||||
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
visitor(spin, "spin", Tristate::kDefault,
|
||||
|
|
@ -86,13 +91,28 @@ class ThreadingArgs : public ArgsBase<ThreadingArgs> {
|
|||
}
|
||||
};
|
||||
|
||||
// Owns threads corresponding to a subset of the system's resources. Because
|
||||
// this is passed to `Gemma::Generate` (via `MatMulEnv`) rather than defined as
|
||||
// a singleton, we can have multiple concurrent `Generate` calls within the
|
||||
// same process, each with their own `ThreadingContext`. Because each context
|
||||
// may pin its threads, it is important that they use distinct packages,
|
||||
// clusters, or LPs. For example, to use two packages, the first `args` can have
|
||||
// `skip_packages` = 0 and the second `skip_packages` = 1.
|
||||
struct ThreadingContext {
|
||||
// Expected to be called early in the program, before threading starts.
|
||||
explicit ThreadingContext(const ThreadingArgs& args);
|
||||
|
||||
// Singleton; pass around a reference to reduce overhead.
|
||||
hwy::Profiler& profiler;
|
||||
|
||||
// Detects topology, subject to limits imposed by user-specified `args`.
|
||||
// For example, if `args.max_packages` is 1, then `topology.NumPackages()`
|
||||
// will be 1 regardless of the actual system topology.
|
||||
BoundedTopology topology;
|
||||
|
||||
// Ctor depends on `topology` for deciding whether to enable NUMA.
|
||||
Allocator allocator;
|
||||
|
||||
// Per-package/cluster/within cluster pools of threads, matching `topology`.
|
||||
NestedPools pools;
|
||||
};
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue