Commit Graph

49 Commits

Author SHA1 Message Date
Jan Wassenberg 7263ab8445 MatMul simplification, threading strategy improvements
remove MatMul f32 special case (smaller code),
types: Add u32/u64 for use by Activations
move renamed ParallelismStrategy to threading_context so can pass ctx
ensure worker index is unique across clusters
matmul.h: const member functions for renamed policy classes (easier to call)
PiperOrigin-RevId: 802848086
2025-09-03 21:45:07 -07:00
Marie White 74ffe079c4 Create separate MMStorage objects per cluster.
PiperOrigin-RevId: 802588625
2025-09-03 09:35:48 -07:00
Jan Wassenberg b7b3d353db Simplify MatMul: remove F32 special case (build time)
Also move kMaxM into separate kMaxBatchSize

PiperOrigin-RevId: 802086590
2025-09-02 04:29:21 -07:00
Jan Wassenberg 1e3c853e80 Add ParallelFor wrapper function and one new mode
Move ParallelismType from matmul.h to threading.h
Replace SmallParallelFor with ParallelFor and the new mode

PiperOrigin-RevId: 802038452
2025-09-02 01:40:09 -07:00
Marie White 3737224132 Add in-cluster parallel policy. Update policy to include cluster_idx.
PiperOrigin-RevId: 802016308
2025-09-02 00:16:00 -07:00
Marie White 27cb8e12d9 Handle non-threading parallel policy.
PiperOrigin-RevId: 802012517
2025-09-02 00:02:57 -07:00
Marie White 0d2e74d74a Add MMOptions as an argument to Matmul.
PiperOrigin-RevId: 802008198
2025-09-01 23:46:39 -07:00
Marie White 00b70f69c5 Include parallelism type in DoMatMul. Also remove package handling.
PiperOrigin-RevId: 800902568
2025-08-29 08:04:52 -07:00
Jan Wassenberg 0ae8646731 Fix remainder handling for Paligemma
No longer attempt to skip the remainder handling because B might also be a non-padded view.

PiperOrigin-RevId: 800890805
2025-08-29 07:25:52 -07:00
Marie White 973e284ed6 Refactor Matmul to use a policy class for parallelization.
PiperOrigin-RevId: 800864489
2025-08-29 05:40:39 -07:00
Jan Wassenberg 7288891439 Remove F64 partial storage in matmul.
Also remove no longer used kMaxN; row_ptrs only used for C

PiperOrigin-RevId: 800774757
2025-08-29 00:12:08 -07:00
Jan Wassenberg 31c09cca4c f32 LoopKC: 1.37x(M=512), 1.19(M=128) single-K F32,BF16 matmul speedup on SKX
Add a special case for A=F32,B=BF16, used when there is no native bf16 dot product.

dot-inl: ensure bf16,f32 and f32,bf16 both get promoted to float before f64 summation
matmul.cc: update autotuning to reflect actual A size
matmul_test: add all combinations of bf16/f32, report all results, not just first difference, check non-vector-aligned K
PiperOrigin-RevId: 800487817
2025-08-28 08:55:50 -07:00
Jan Wassenberg 71406cf6d0 More profiler interface fixes: hwy:: plus avoid ADD_ZONE
PiperOrigin-RevId: 794493165
2025-08-13 03:15:48 -07:00
Jan Wassenberg eef564e8f0 Prepare profiler annotations for new API
PiperOrigin-RevId: 792808391
2025-08-08 16:51:29 -07:00
Jan Wassenberg 701841897b Default to disabling per-socket parallelization
weights: default to Read for small-batch (only look at qbatch, not the larger prefill tbatch)
PiperOrigin-RevId: 790787643
2025-08-04 09:49:14 -07:00
Jan Wassenberg d1638587f0 1.14x batch decode speedup: parallelize RMSNorm ops
Activations was over-parallelized, use single pool instead.
Also improve profiler zone annotations,
pass through worker args (for tracking concurrency), now non-optional.

PiperOrigin-RevId: 788790976
2025-07-30 00:55:45 -07:00
Jan Wassenberg 349c86f2d9 Fix bench_matmul perf regression: A input should be padded
PiperOrigin-RevId: 781976414
2025-07-11 07:36:36 -07:00
Jan Wassenberg 0f70f285e0 1.1x prefill and decode speedup (attention/activations)
Optimizations
- Better load-balancing in attention threading
(Previously, clusters were limited by #heads)
- Add MulByConstTo to avoid zero-init
- Parallel activations

Cleanup
- Prepare for RowPtr in A or B
- Pass through thread_id to ops
- Avoid warning in bench_matmul

PiperOrigin-RevId: 773723423
2025-06-20 08:59:53 -07:00
Jan Wassenberg 4f5785b0fd Update instrumentation for new Highway wall-time profiler
Pass the thread index through and use new zone_id.

PiperOrigin-RevId: 773344242
2025-06-19 07:46:04 -07:00
Jan Wassenberg 2c72ff2aa5 Fix MatMul issue caused by autotuning bucketing, refs #608, thanks @ufownl
PiperOrigin-RevId: 771077158
2025-06-13 06:58:42 -07:00
Jan Wassenberg bd98b43cea Rename RowPtr->StridedView, CRows->RowPtrs
PiperOrigin-RevId: 770046362
2025-06-11 02:30:53 -07:00
Jan Wassenberg 6ee628ba38 Further cleanup: separate MatMulEnv arg
move row_ptrs into MatMulEnv
Consistent arg order: layer, activations, kv_cache, env

PiperOrigin-RevId: 767886386
2025-06-05 20:48:32 -07:00
Jan Wassenberg 3a266c662c Split gemma-inl into separate source files
weights, mat: zero-initialize padding, required since the MatMul "avoid B decompress" optimization.

PiperOrigin-RevId: 767562313
2025-06-05 05:36:44 -07:00
Jan Wassenberg 6897313080 3x speedup of EmbedImagePatches - GEMM, not GEMV.
Required fixes to handling of non-vector aligned A.
Also move row ptrs to MatMulEnv.

PiperOrigin-RevId: 767029036
2025-06-04 01:18:52 -07:00
Jan Wassenberg 9efdcfd45c 1.07x batch decode speedup: more BF16 weights and activations
BF16 att_sums and ffw_out
Support BF16 B views without decompression
Support arbitrary types in MulByConstAndAdd, AddFrom

Also update profiler annotations in ops-inl.h

PiperOrigin-RevId: 766995010
2025-06-03 23:30:18 -07:00
Jan Wassenberg 839a642992 Fix paligemma_test, refs #588
Detect PaliGemma models from layer names
Remove unused allocator arg from CreateInvTimescale
matmul: only warn once about dim divisibility
Print config also in tests if --verbosity 2
PiperOrigin-RevId: 766605131
2025-06-03 04:45:22 -07:00
Jan Wassenberg cf4d7ceb82 1.16x decode speedup: remove last MatVec in Attention
Precompute row pointers.
Remove no longer used MHA support; QStride -> qkv_dim.
Remove RowPtr from MatMul interface, use only MatPtrT.
Require opt-in define for NUQ to speed up builds.
Also fix io.cc on Windows.

PiperOrigin-RevId: 766228108
2025-06-02 09:40:29 -07:00
Jan Wassenberg 0023ff8770 Add support for arbitrary output row pointers
Useful for writing directly to KV cache.

PiperOrigin-RevId: 765615147
2025-05-31 10:55:54 -07:00
Jan Wassenberg 38a08d8095 Replace last ConstMat with MatPtr
This is to reduce the number of MatMul overloads in preparation for de-templatizing.

PiperOrigin-RevId: 758288589
2025-05-13 10:55:22 -07:00
Jan Wassenberg 2038dfd9cc Minor: rename compression/shared -> types.h
PiperOrigin-RevId: 758199851
2025-05-13 06:53:21 -07:00
Jan Wassenberg d538a6d6c6 Cleanup: remove unused kCyclic, remove 2 suffix
Also remove now unused allocator arg and fix warnings (cast, struct/class mismatch)

PiperOrigin-RevId: 758098495
2025-05-13 01:06:41 -07:00
Jan Wassenberg 45ad847a41 Replace RowVectorBatch with MatStorageT
KVCache: add ctor required for MatStorageT, remove Create; bf_pre_ffw_rms_out -> pre_ffw_rms_out
optimize_test: larger vocab_size requires more steps
shared.h: Remove unused u128 type
correctly set Activation matrix rows, avoid passing as arg
ops: pass Mat instead of pointers/sizes; vectorize LayerNorm; support any weight type
mat: add OverrideRows, used by SetBatchSize
PiperOrigin-RevId: 757790736
2025-05-12 09:16:12 -07:00
Jan Wassenberg 275135d7e8 Rename-only: remove Allocator2 etc suffixes now that refactoring is complete
PiperOrigin-RevId: 755397220
2025-05-06 09:12:43 -07:00
Jan Wassenberg 8d0882b966 Huge refactor of weight handling and model loading.
Weight handling:
- new ModelStore2 supports both pre-2025 multi-file and single-file formats
- simpler ForEachTensor with TensorArgs
- tensors are constructed with their full suffixed name

I/O:
- support mmap and stride
- Simplified SbsWriter, single insert(); add SbsReader

Misc:
- kMockTokenizer: allow creating with unavailable tokenizer
- configs.h: Simpler enum validity checks via kSentinel
- matmul.h: remove unused enable_bind (now in allocator.h)
- tensor_info: single TensorInfoRegistry class, rename from tensor_index.h

Frontends:
- Replace Allocate/CreateGemma with ctor(LoaderArgs, MatMulEnv&)
- Deduce model/weight type, remove --model and parsing
- Replace most common.h includes with configs.h
- Remove --compressed_weights, use --weights instead
- Remove ModelInfo, replaced by ModelConfig.

Backprop:
- Reduce max loss, remove backward_scalar_test (timeout)
- Update thresholds because new RandInit changes rng eval order and thus numerics
PiperOrigin-RevId: 755317484
2025-05-06 04:44:21 -07:00
Jan Wassenberg 8532da47f7 Major refactor of allocator/args:
use new ThreadingContext2 instead of monostate/init in each frontend
Add ThreadingArgs(replaces AppArgs)

backprop: use Packed() accessor and MakePacked factory and row-based access to allow for stride
compress_weights: remove, moving to py-only exporter instead

Move MatPtr to mat.h and revise interface:
- Generic MatOwner
- rename accessors to Packed*
- support stride/row accessors, fix RowPtr stride

Add TypeBits(Type)
Move GenerateMat to test_util-inl for sharing between matmul test/bench
Move internal init to gemma.cc to avoid duplication
Rename GemmaEnv model_ to gemma_ for disambiguating vs upcoming ModelStorage
Remove --compressed_weights, use --weights instead.
tensor_index: add ExtentsFromInfo and TensorIndexLLM/Img
Allocator: use normal unique_ptr for AllocBytes so users can call directly
threading: use -> because AlignedPtr no longer assumes arrays
PiperOrigin-RevId: 745918637
2025-04-10 01:29:54 -07:00
Jan Wassenberg 76a81ac2d6 Fix unaligned buffer causing crash on GCC. Thanks @ufownl, fixes #508
PiperOrigin-RevId: 741590339
2025-03-28 11:25:33 -07:00
Jan Wassenberg 2bdf26d81d Support bf16 output of Matmul
Adds Stride to ConstMat, to support decompression of C output for test
matmul_test: add line numbers to output
Also ignore "N is not a multiple of nc" when N==nc
PiperOrigin-RevId: 731096662
2025-02-25 17:53:20 -08:00
Jan Wassenberg f9d93e4a42 Matmul rewrite: fp64 sums, hierarchical parallelization, cache-blocking, autotuning
Remove empty matmul_unit_test.
Up to 25 TFLOP/s on 2xZen4 for 512,3072,24576.

PiperOrigin-RevId: 729123576
2025-02-20 08:33:46 -08:00
Jan Wassenberg f74d496879 Threading/infra improvements.
* Add Parallelize*Range helpers and partitioning helpers
* Refactor Pinning class, store original affinity (required to construct another NestedPools after pinning happened)

Compress:
* prevent Compress printing stats in tests
* zero-pad tensors

Matmul:
* add matmul_unit_test (TODO) and bench_matmul
* matmul_test: change norm to row vectors (that is what is added) and include bf16 rounding error
* Prepare for L2/L3 retrieval
PiperOrigin-RevId: 700603811
2024-11-27 01:12:00 -08:00
Jan Wassenberg 868b01601f Simpler MatMul interface, vocab types, Tristate for use_spinning
Add Extents2D, Range2D vocab types
Matmul uses ConstMat for inputs and RowPtr for output
Move RowVectorBatch to basics.h
Separate threading.cc
Fix topology string: report cores not LPs, and #HT
Move QStride/IsMHA into LayerConfig
ImageTokens does not require make_unique.
matmul_test: no longer require template args
PiperOrigin-RevId: 692963605
2024-11-04 07:48:29 -08:00
Jan Wassenberg 02ce1e344f Use NestedPools, add NUMA infra
Improved threading.h, fix thread counts for single package/cluster systems
Temporarily forces to a single socket. Prefill 29.28 tps, decode 6.92.

Also fix benchmarks.cc build, update tensor allocator to Allocator

PiperOrigin-RevId: 687307167
2024-10-18 08:11:18 -07:00
Jan Wassenberg 8c0a8834c1 Major compression update, arbitrary-len unpack + new Dot
Compression:
* Implement {any packed} x {bf16, f32} 'Load2' and DecompressAndZeroPad
* New compression test for all packed formats, add to GEMMA_TEST_FILES, remove from sfp/nuq_test
* Decompress->DecompressAndZeroPad, use PackedSpan for args with bounds checking
* NUQ: support arbitrary-length enc/dec
* New compression/shared, remove sfp.h and nuq.h
* Move Store2 into Traits and provide Compress2 wrapper
* Remove unused Decompress()-with-pool overload
* Simplify CompressedArrayLen, rename to CompressedArrayElements
* Remove unused DistortionStats b_l1_

Misc:
* Add compensated and Kahan dot, support any length
* Use same Dot function everywhere
* Move exact arithmetic functions into fp_arith
* use FloatPtr and MatPtr typedefs in tests; less stack usage
* Rename args to packed/raw
* Remove Traits::Name, instead TypeName<T>()
* Move kMaxSFP and kClusters/kGroupSize into Sfp/NuqStream
PiperOrigin-RevId: 672868468
2024-09-10 02:22:19 -07:00
Jan Wassenberg 301dc8067a Major MatMul update, 1.9-2.3x speedup on Zen4 via bf16 mul
Supports converting all weight/activation formats to native MulT (bf16/f32)

Also:
- ConstMat/MutableMat for const correctness
- Move RowVectorBatch to allocator.h so it can be used from Matmul
- Add matmul.h so MatMulEnv can be used from Activations
- Remove kMaxThreads, detect from PerClusterPools
- Build fix: -inl.h files must be textual_hdrs, and highway.h should precede -inl.h

```
zen4 new
64, 24576, 3072, add=0, MatTA=bf16, MatTB=sfp:   616.6 GFLOPS.
64, 3072, 24576, add=0, MatTA=bf16, MatTB=sfp:   460.7 GFLOPS.
64, 24576, 3072, add=0, MatTA=f32, MatTB=sfp:    598.6 GFLOPS.
64, 3072, 24576, add=0, MatTA=f32, MatTB=sfp:    435.6 GFLOPS.

zen4 old
64, 24576, 3072, add=0, MatTA=f32, MatTB=sfp:    257.5 GFLOPS.
64, 3072, 24576, add=0, MatTA=f32, MatTB=sfp:    231.9 GFLOPS.
```

PiperOrigin-RevId: 663729812
2024-08-16 07:52:20 -07:00
Jan Wassenberg 8e028632f7 0.98x prefill: refactor in prep for cache blocking.
Slower because we now init tiles of C and accumulate into them.

Also remove unused var in optimize_test and use BF16 typedef.

PiperOrigin-RevId: 662115916
2024-08-12 09:26:29 -07:00
Jan Wassenberg a24eda8d02 Split matmul into matvec; add large matrix benchmark
Rename var names to row/col for more clarity.
Better estimate error tolerance via max abs col sum.

PiperOrigin-RevId: 657601791
2024-07-30 08:29:11 -07:00
Jan Wassenberg 6ea4232b2e MatMul cleanup: Mat struct, simplify args.
Add large benchmark to test, use 4 threads, skip some targets.
Also use Traits::Name instead of typeid.

PiperOrigin-RevId: 657496185
2024-07-30 01:55:50 -07:00
Jan Wassenberg f27683152c 1.05x prefill speedup: matvec -> matmul for !MHA
Also add C_stride and make shape normal non-template arguments.

PiperOrigin-RevId: 657285945
2024-07-29 12:18:06 -07:00
Jan Wassenberg 2721f54446 Add offset arg to MatMul, rename, Matmul for logits = ~1.1x decode speedup
PiperOrigin-RevId: 657167257
2024-07-29 05:34:26 -07:00
Jan Wassenberg 85cac13fb1 Split up ops.h into ops/ops-inl and matmul-inl
PiperOrigin-RevId: 654068303
2024-07-19 11:21:48 -07:00
Renamed from gemma/ops.h (Browse further)