Commit Graph

8 Commits

Author SHA1 Message Date
Daniel Keysers d83ad76679 Rename one variable in SampleTopK and update TestSampleTopK.
PiperOrigin-RevId: 680897787
2024-10-01 00:51:33 -07:00
Daniel Keysers 03f0ee2323 Add tests for SampleTopK that highlight existing problems and fix those:
- Sampling was not correct for k>1 and temperature=0.
- Sampling was not correct for only negative logits.

Also restructure the code a bit for better readability and add some asserts for things that shouldn't happen.

PiperOrigin-RevId: 676043267
2024-09-18 10:32:01 -07:00
Daniel Keysers 892f3bbcbe Implement scalar version of LayerNorm
PiperOrigin-RevId: 675085495
2024-09-16 03:54:10 -07:00
Jan Wassenberg 8c0a8834c1 Major compression update, arbitrary-len unpack + new Dot
Compression:
* Implement {any packed} x {bf16, f32} 'Load2' and DecompressAndZeroPad
* New compression test for all packed formats, add to GEMMA_TEST_FILES, remove from sfp/nuq_test
* Decompress->DecompressAndZeroPad, use PackedSpan for args with bounds checking
* NUQ: support arbitrary-length enc/dec
* New compression/shared, remove sfp.h and nuq.h
* Move Store2 into Traits and provide Compress2 wrapper
* Remove unused Decompress()-with-pool overload
* Simplify CompressedArrayLen, rename to CompressedArrayElements
* Remove unused DistortionStats b_l1_

Misc:
* Add compensated and Kahan dot, support any length
* Use same Dot function everywhere
* Move exact arithmetic functions into fp_arith
* use FloatPtr and MatPtr typedefs in tests; less stack usage
* Rename args to packed/raw
* Remove Traits::Name, instead TypeName<T>()
* Move kMaxSFP and kClusters/kGroupSize into Sfp/NuqStream
PiperOrigin-RevId: 672868468
2024-09-10 02:22:19 -07:00
Jan Wassenberg 5c0da8c8c3 Minor cleanup/fixes:
- optimize_test simplify prompt check
- Fix SFP arg case
- Fix includes
- Align inputs in test
- IsInside: add DASSERT
- Fix PerClusterPool NumThreads

PiperOrigin-RevId: 672530385
2024-09-09 06:58:09 -07:00
Jan Wassenberg 4033ed9e78 Avoid duplication of RMSNorm, support all activation/weight types
Add test for RMSNorm
Rename VectorizedRopeAndMulBy -> RopeAndMulBy

Move test_util to util/

PiperOrigin-RevId: 668332927
2024-08-28 01:26:55 -07:00
Apoorv Reddy c6eb3b6f0d VectorizedRopeAndMulBy.
~8x reduction (tested on few prompts) in Rope.
~3.8% prefill latency improvement.
~2.6% decode latency improvement.

PiperOrigin-RevId: 664650108
2024-08-18 23:17:01 -07:00
Jan Wassenberg 85cac13fb1 Split up ops.h into ops/ops-inl and matmul-inl
PiperOrigin-RevId: 654068303
2024-07-19 11:21:48 -07:00