Commit Graph

169 Commits

Author SHA1 Message Date
Krzysztof Rymski 44dfd69b9b Internal changes
PiperOrigin-RevId: 844759322
2025-12-15 07:14:37 -08:00
Jan Wassenberg 0c64987a96 Abort if args are unrecognized, refactor argument passing
This catches typos/incorrect usage.
Refactor: group Loader/Threading/Inference into GemmaArgs.
All *Args ctors now have an extra ConsumedArgs& argument.
PiperOrigin-RevId: 844690553
2025-12-15 03:18:45 -08:00
Jan Wassenberg 73c3627b67 Add tensor stats and output
tensor_info: add missing header
io: fix mode
weights.h: add layer_idx to LayerWeightsPtrs
PiperOrigin-RevId: 843531051
2025-12-11 22:52:46 -08:00
Martin Stolle bfc0dfcfca Enable flags= parsing
PiperOrigin-RevId: 843103750
2025-12-11 01:17:59 -08:00
Krzysztof Rymski 64178ace38 Internal changes
PiperOrigin-RevId: 842727112
2025-12-10 07:55:17 -08:00
Jan Wassenberg 5a6895c609 Avoid warning when OS affinity limits us to the second socket
Also simplify NumSMT, detect from .smt field directly

PiperOrigin-RevId: 841749486
2025-12-08 07:10:43 -08:00
Jan Wassenberg 1564dd3111 Fix empty enabled_lps in topology detection
Also expand the debug output.

PiperOrigin-RevId: 838832605
2025-12-01 10:23:47 -08:00
Jan Wassenberg 3c9e6cf113 Expand debug output for topology
PiperOrigin-RevId: 837738553
2025-11-28 00:19:33 -08:00
Jan Wassenberg ccb49bc82f Add ToFloatSlow, move RandomFloat to test_util
PiperOrigin-RevId: 837412290
2025-11-27 00:14:51 -08:00
Jan Wassenberg 091b4567c9 Minor: ParallelismStrategy->Parallelism
PiperOrigin-RevId: 828936578
2025-11-06 06:56:10 -08:00
Jan Wassenberg a344a70c59 Change (old) attention behavior to disallow wraparound, enforced via assertion.
Shared kU64PerLine constant

PiperOrigin-RevId: 828072451
2025-11-04 11:52:40 -08:00
Jan Wassenberg 3cc0139ebb Fix excessive KC/MC from prior change
This could lead to stack overflow in B_storage.

Also do not require specific type for query_norm_scale,
update batch sizes for attention tensors,
more verbose Mat shape/type checks.

PiperOrigin-RevId: 824987689
2025-10-28 05:33:01 -07:00
Biruk Mammo 5a05857deb [Gemma.cpp] Allows non-owned arguments for attention methods.
* Adds and uses a new `AttentionActivationPtrs` that holds non-owning `MatPtrs`. Acts as a view into `AttentionActivations`.
* Updates `QBatch` to hold  non-owning `MatPtr`s to the kv caches.
* Enables the `MatPtrT` default constructor for simpler initializations.
* Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids the need to find and construct several empty weight tensors just to construct one `query_norm_scale` tensor.

PiperOrigin-RevId: 824584177
2025-10-27 10:43:25 -07:00
Jan Wassenberg 86200ce224 1.01x speedup: improved autotune
Group M=4..7 into same config. Add configs for power of two sizes.
Allow odd mc to enable a single range for odd M.

io.cc: warning fix(cast).
IsBlock -> !IsOneMC
benchmark_helper: best for verbosity 3, all configs for 4
ops_test: remove unused includes
PiperOrigin-RevId: 824475104
2025-10-27 05:35:31 -07:00
Jan Wassenberg a48e614f64 1.02x speedup: improve load balance and simplify parallelFor
Remove ParallelizeOne/TwoRange, use ParallelForAcross/WithinCluster instead.

PiperOrigin-RevId: 823388890
2025-10-24 00:19:09 -07:00
Jan Wassenberg 3ed403e287 Major cleanup of profiler zones, add Caller annotation for all pool.Run
Pass ThreadingContext instead of Pools/Profiler individually, for access to Zones
Add GCPP_ZONE helper
Add Caller argument to pool.Run to enable new stats
Remove most direct dependencies on ThreadPool, prefer ParallelFor

PiperOrigin-RevId: 822934530
2025-10-23 01:54:24 -07:00
Jan Wassenberg acede9d682 Warning fix (unused var), Windows build fix (missing member variable)
PiperOrigin-RevId: 822172982
2025-10-21 10:17:34 -07:00
Jan Wassenberg f59eb2ed72 Remove multi-package support from topology
Also no longer assume equal-sized clusters

PiperOrigin-RevId: 820164125
2025-10-16 04:00:35 -07:00
Phil Culliton 503aaddd65 Add 8-bit integer quantization (I8Stream) to Gemma.cpp.
PiperOrigin-RevId: 819787856
2025-10-15 09:25:20 -07:00
Ray Smith e3e8511e79 Initialization of profiler zones.
PiperOrigin-RevId: 819662587
2025-10-15 03:05:58 -07:00
Ray Smith fb6fa793f4 Added a global (to gemma) zones list to enable most call sites to PROFILER_ZONE3 to avoid the sychronization required for the static const initialization of the zone handle.
Improved flash_attention to enable profiling using the new zones.

PiperOrigin-RevId: 819235421
2025-10-14 08:30:58 -07:00
Jan Wassenberg 035273c184 tune pool kSpin mode in threading_context
Previously, this happened concurrently with the matmul autotune, which could lead to incorrect outcomes.

threading: de-singleton Pinning (no longer stores affinity); pass PoolWorkerMapping; fix Pool dtor order
Also enable SPR target (Zen4 is AMD-only),
update Highway version for renamed Thread()->GlobalIdx().
PiperOrigin-RevId: 816223017
2025-10-07 08:36:26 -07:00
Jan Wassenberg f3bc1c17da 1.03x speedup: fused FFN
matmul-inl: support CView=StridedView or RowPtrs; rename to C_MC_NC
matmul.cc: Allow 1 more rep for MC/NC to allow half-sized tiles, which helps.
PiperOrigin-RevId: 807291701
2025-09-15 10:26:37 -07:00
Jan Wassenberg ba6131311a Fix gemma_batch_bench for flash attention
q_T rows do not change.
Also repeat prefill to reflect perf after autotuning.

PiperOrigin-RevId: 805319377
2025-09-10 05:32:34 -07:00
Jan Wassenberg 9457258330 Refactor MatMul to accept views in the kernel functions
Make arg order consistent.
Move StridedView into mat.h.
Add view support to RowPtrs.

PiperOrigin-RevId: 805197381
2025-09-09 22:09:47 -07:00
Jan Wassenberg 24b1760f03 Refactor: move Worker to ThreadingContext, factor out MMDecompress
PiperOrigin-RevId: 804909921
2025-09-09 07:56:12 -07:00
Jan Wassenberg 461a9c7d1b Matmul refactoring towards fusion
MMLoops: move dispatch code out, use overloads
split build target into matmul_env (for MatMulEnv/MMOptions)
weights: no longer call BindB
Fix potential out of bounds in gemma_batch_bench
PiperOrigin-RevId: 804895985
2025-09-09 07:13:38 -07:00
Jan Wassenberg a5ab99e4ba Memory use reduction: smaller/single MMStorage
PiperOrigin-RevId: 804865029
2025-09-09 05:32:46 -07:00
Jan Wassenberg 06e5da1e22 Cleanup: split CacheInfo from Allocator, MatMul helper functions
Lift DecompressA out of main autotuner to prevent interference
Also use kMaxNR / kNR constants instead of extra args
Fix: only require vector alignment, not cache alignment
PiperOrigin-RevId: 804333769
2025-09-08 02:23:58 -07:00
Jan Wassenberg 56186193c1 Replace mt19937 with new generator to enable parallel sampling
Split it into immutable AesCtrEngine and RngStream
Also add RowSpan and Logits span

PiperOrigin-RevId: 803336423
2025-09-04 23:49:10 -07:00
Jan Wassenberg afd82376a5 Add AES-CTR RNG for parallel sampling (not yet used)
PiperOrigin-RevId: 802991142
2025-09-04 05:58:42 -07:00
Jan Wassenberg 4be4799727 Remove kMaxPackages and per-package-related code
matmul: remove kMaxClusters, dynamic allocation
PiperOrigin-RevId: 802950348
2025-09-04 03:33:12 -07:00
Jan Wassenberg 7263ab8445 MatMul simplification, threading strategy improvements
remove MatMul f32 special case (smaller code),
types: Add u32/u64 for use by Activations
move renamed ParallelismStrategy to threading_context so can pass ctx
ensure worker index is unique across clusters
matmul.h: const member functions for renamed policy classes (easier to call)
PiperOrigin-RevId: 802848086
2025-09-03 21:45:07 -07:00
Jan Wassenberg b7b3d353db Simplify MatMul: remove F32 special case (build time)
Also move kMaxM into separate kMaxBatchSize

PiperOrigin-RevId: 802086590
2025-09-02 04:29:21 -07:00
Jan Wassenberg 1e3c853e80 Add ParallelFor wrapper function and one new mode
Move ParallelismType from matmul.h to threading.h
Replace SmallParallelFor with ParallelFor and the new mode

PiperOrigin-RevId: 802038452
2025-09-02 01:40:09 -07:00
Jan Wassenberg 98ddc166db Expand ThreadingContext comments
PiperOrigin-RevId: 800479954
2025-08-28 08:32:10 -07:00
Jan Wassenberg faa4102992 (Resubmit) Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions.
Used ThreadingContext arg instead of NestedPools.
Use new PROFILER_ZONE3.

PiperOrigin-RevId: 794461159
2025-08-13 01:38:24 -07:00
The gemma.cpp Authors a2d9133f7d Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions.
Used ThreadingContext arg instead of NestedPools.
Use new PROFILER_ZONE3.

PiperOrigin-RevId: 793865287
2025-08-11 17:51:38 -07:00
Jan Wassenberg 4cbf63e6f0 Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions.
Used ThreadingContext arg instead of NestedPools.
Use new PROFILER_ZONE3.

PiperOrigin-RevId: 793821255
2025-08-11 15:34:52 -07:00
Jan Wassenberg 701841897b Default to disabling per-socket parallelization
weights: default to Read for small-batch (only look at qbatch, not the larger prefill tbatch)
PiperOrigin-RevId: 790787643
2025-08-04 09:49:14 -07:00
Jan Wassenberg 799c264df3 Pre-tune thread pool before matmul
Also improve profiler annotations - remove near-zero ones and add more for startup

PiperOrigin-RevId: 789352414
2025-07-31 08:45:26 -07:00
Jan Wassenberg d1638587f0 1.14x batch decode speedup: parallelize RMSNorm ops
Activations was over-parallelized, use single pool instead.
Also improve profiler zone annotations,
pass through worker args (for tracking concurrency), now non-optional.

PiperOrigin-RevId: 788790976
2025-07-30 00:55:45 -07:00
Jan Wassenberg e76e29ce11 De-singleton ThreadingContext so callers can pass in their own
weights.cc: fix BindB argument for bf16 tensors
threading_test: enable autotune
PiperOrigin-RevId: 785763618
2025-07-22 02:08:46 -07:00
Jan Wassenberg fea9a07d9b Avoid affinity related warnings on Apple. Refs #625
PiperOrigin-RevId: 778895832
2025-07-03 08:22:31 -07:00
Jan Wassenberg 0f70f285e0 1.1x prefill and decode speedup (attention/activations)
Optimizations
- Better load-balancing in attention threading
(Previously, clusters were limited by #heads)
- Add MulByConstTo to avoid zero-init
- Parallel activations

Cleanup
- Prepare for RowPtr in A or B
- Pass through thread_id to ops
- Avoid warning in bench_matmul

PiperOrigin-RevId: 773723423
2025-06-20 08:59:53 -07:00
Jan Wassenberg f2adbfbcab Batch inference fixes: set pos during prefill, fix assert
PiperOrigin-RevId: 772458760
2025-06-17 07:09:44 -07:00
Jan Wassenberg bd98b43cea Rename RowPtr->StridedView, CRows->RowPtrs
PiperOrigin-RevId: 770046362
2025-06-11 02:30:53 -07:00
Jan Wassenberg 3a266c662c Split gemma-inl into separate source files
weights, mat: zero-initialize padding, required since the MatMul "avoid B decompress" optimization.

PiperOrigin-RevId: 767562313
2025-06-05 05:36:44 -07:00
Jan Wassenberg 9efdcfd45c 1.07x batch decode speedup: more BF16 weights and activations
BF16 att_sums and ffw_out
Support BF16 B views without decompression
Support arbitrary types in MulByConstAndAdd, AddFrom

Also update profiler annotations in ops-inl.h

PiperOrigin-RevId: 766995010
2025-06-03 23:30:18 -07:00
Jan Wassenberg 794a21a4e6 Major refactor to de-templatize gemma-inl and weights
This replaces per-weight instantiations of all code with only per-MatMul/norm.
Reduces binary size by 133KiB.

WeightsOwner is no longer required for type erasing, hence it is replaced with ModelWeightsPtrs.
Also remove unused EmbedToken, replaced with EmbedMMToken.

PiperOrigin-RevId: 766497657
2025-06-02 23:01:35 -07:00