Commit Graph

18 Commits

Author SHA1 Message Date
Jan Wassenberg f9d93e4a42 Matmul rewrite: fp64 sums, hierarchical parallelization, cache-blocking, autotuning
Remove empty matmul_unit_test.
Up to 25 TFLOP/s on 2xZen4 for 512,3072,24576.

PiperOrigin-RevId: 729123576
2025-02-20 08:33:46 -08:00
Jan Wassenberg a60b564b88 Infra improvements (2)
ops.h: move CreateInvTimescale to allow calling without depending on gemma
Pass around MatMulEnv instead of pools to avoid re-creating the env
profiler.h can now be used outside SIMD code
allocator: add StepBytes and QuantumSteps
rename worker thread with package/cluster in the name
threading: add Visit* to IndexRange
PiperOrigin-RevId: 718766704
2025-01-23 01:55:19 -08:00
Ray Smith 9d40f0117e Added ability to load/save a complete model file, including tokenizer.
PiperOrigin-RevId: 707914366
2024-12-19 07:59:41 -08:00
Jan Wassenberg f74d496879 Threading/infra improvements.
* Add Parallelize*Range helpers and partitioning helpers
* Refactor Pinning class, store original affinity (required to construct another NestedPools after pinning happened)

Compress:
* prevent Compress printing stats in tests
* zero-pad tensors

Matmul:
* add matmul_unit_test (TODO) and bench_matmul
* matmul_test: change norm to row vectors (that is what is added) and include bf16 rounding error
* Prepare for L2/L3 retrieval
PiperOrigin-RevId: 700603811
2024-11-27 01:12:00 -08:00
Jan Wassenberg 868b01601f Simpler MatMul interface, vocab types, Tristate for use_spinning
Add Extents2D, Range2D vocab types
Matmul uses ConstMat for inputs and RowPtr for output
Move RowVectorBatch to basics.h
Separate threading.cc
Fix topology string: report cores not LPs, and #HT
Move QStride/IsMHA into LayerConfig
ImageTokens does not require make_unique.
matmul_test: no longer require template args
PiperOrigin-RevId: 692963605
2024-11-04 07:48:29 -08:00
Daniel Keysers b7eff19be4 Update expected ranges in dot_test.
PiperOrigin-RevId: 685591625
2024-10-13 23:47:20 -07:00
Daniel Keysers 1eb9ce19dd Update expected ranges in dot_test.
PiperOrigin-RevId: 684515143
2024-10-10 11:31:19 -07:00
Jan Wassenberg 6ab3ff5bde Minor cleanup, Windows+Bazel build fixes
add app.h comment
compress-inl: remove unused typedef
gemma-inl: add missing HWY_ATTR and cast
separate sum-inl.h and basics.h headers
replace more hwy::bfloat16_t with BF16
update include pragmas
update dot_test thresholds
update Highway version in Bazel for HWY_RCAST_ALIGNED fix
PiperOrigin-RevId: 684464326
2024-10-10 09:05:06 -07:00
Jan Wassenberg 2c28b18eb0 Add NestedPools: one per socket/cluster
Use in dot_test
app.h: add new flags and rename num_threads to max_threads
matmul: Parallelize MatMulSlow and enable spinning, more large/fewer medium test cases
PiperOrigin-RevId: 683216386
2024-10-07 09:40:19 -07:00
Jan Wassenberg 5a71d819cb Also enable f64 dot/sum for <f32 inputs
Add bf16 support to Dot/SumKernelDouble in the same way as *Compensated.

PiperOrigin-RevId: 682308683
2024-10-04 07:12:10 -07:00
Jan Wassenberg 5e812f07f5 Use f64 Dot and sum in softmax - faster than Cascaded
Also let the kernel specify the Raw and State types,
rename WeightT/VecT -> WT/VT.

PiperOrigin-RevId: 680464427
2024-09-30 01:22:09 -07:00
Jan Wassenberg 47eb80a90e Add double-precision dot variant
PiperOrigin-RevId: 679243590
2024-09-26 12:09:10 -07:00
Daniel Keysers 2290eb7d3f Reduce flakiness of dot_test.
PiperOrigin-RevId: 679049273
2024-09-26 01:54:27 -07:00
Jan Wassenberg e70e686805 Add forward and backward error
PiperOrigin-RevId: 678297584
2024-09-24 10:10:29 -07:00
Jan Wassenberg 35fdf848c7 Cascaded summation for Softmax
This can affect generation results after a few hundred tokens.

Also remove profiler from DecompressAndCall, use Add instead of +=,
use PackedSpan for args and remove alignment requirement.
Changing accumulation order in AssimilateCascadedSums updates dot_test thresholds.

PiperOrigin-RevId: 676891797
2024-09-20 10:31:23 -07:00
Jan Wassenberg bb6b398df3 Add pairwise sum dot products for testing
Also add wrapper function for threshold comparison.

PiperOrigin-RevId: 676749760
2024-09-20 01:48:52 -07:00
Jan Wassenberg 8c0a8834c1 Major compression update, arbitrary-len unpack + new Dot
Compression:
* Implement {any packed} x {bf16, f32} 'Load2' and DecompressAndZeroPad
* New compression test for all packed formats, add to GEMMA_TEST_FILES, remove from sfp/nuq_test
* Decompress->DecompressAndZeroPad, use PackedSpan for args with bounds checking
* NUQ: support arbitrary-length enc/dec
* New compression/shared, remove sfp.h and nuq.h
* Move Store2 into Traits and provide Compress2 wrapper
* Remove unused Decompress()-with-pool overload
* Simplify CompressedArrayLen, rename to CompressedArrayElements
* Remove unused DistortionStats b_l1_

Misc:
* Add compensated and Kahan dot, support any length
* Use same Dot function everywhere
* Move exact arithmetic functions into fp_arith
* use FloatPtr and MatPtr typedefs in tests; less stack usage
* Rename args to packed/raw
* Remove Traits::Name, instead TypeName<T>()
* Move kMaxSFP and kClusters/kGroupSize into Sfp/NuqStream
PiperOrigin-RevId: 672868468
2024-09-10 02:22:19 -07:00
Jan Wassenberg 2308514e5a Experiment with compensated dot product.
ULP difference vs exact is 0..1, vs 200-5000 for previous.
Runtime overhead is 2.5-4x for f32 input.

PiperOrigin-RevId: 668084019
2024-08-27 12:05:35 -07:00