Commit Graph

13 Commits

Author SHA1 Message Date
Phil Culliton 8a6edff319 Base interleaved handling for 4.5-bit NUQ, specifically Enc, DecompressAndZeroPad, and Dec2. Includes tests.
PiperOrigin-RevId: 721821577
2025-01-31 10:35:32 -08:00
Jan Wassenberg 8c0a8834c1 Major compression update, arbitrary-len unpack + new Dot
Compression:
* Implement {any packed} x {bf16, f32} 'Load2' and DecompressAndZeroPad
* New compression test for all packed formats, add to GEMMA_TEST_FILES, remove from sfp/nuq_test
* Decompress->DecompressAndZeroPad, use PackedSpan for args with bounds checking
* NUQ: support arbitrary-length enc/dec
* New compression/shared, remove sfp.h and nuq.h
* Move Store2 into Traits and provide Compress2 wrapper
* Remove unused Decompress()-with-pool overload
* Simplify CompressedArrayLen, rename to CompressedArrayElements
* Remove unused DistortionStats b_l1_

Misc:
* Add compensated and Kahan dot, support any length
* Use same Dot function everywhere
* Move exact arithmetic functions into fp_arith
* use FloatPtr and MatPtr typedefs in tests; less stack usage
* Rename args to packed/raw
* Remove Traits::Name, instead TypeName<T>()
* Move kMaxSFP and kClusters/kGroupSize into Sfp/NuqStream
PiperOrigin-RevId: 672868468
2024-09-10 02:22:19 -07:00
Jan Wassenberg 07c34cb18a Further nuq_test speedups to prevent timeout
PiperOrigin-RevId: 670863385
2024-09-04 00:49:44 -07:00
Jan Wassenberg 9661b81c4b Fix NUQ for SVE - incorrect nibble packing
Also speed up test

PiperOrigin-RevId: 670625545
2024-09-03 10:59:01 -07:00
Jan Wassenberg aa11ddf5fc 1.22x NUQ compress speedup, fix out of bounds access, improve numerics
Also clarify the cost computation and move toward non-group-multiple num.

PiperOrigin-RevId: 670544245
2024-09-03 07:10:56 -07:00
Jan Wassenberg 4033ed9e78 Avoid duplication of RMSNorm, support all activation/weight types
Add test for RMSNorm
Rename VectorizedRopeAndMulBy -> RopeAndMulBy

Move test_util to util/

PiperOrigin-RevId: 668332927
2024-08-28 01:26:55 -07:00
Jan Wassenberg 5c3e5f7038 Remove no longer required stats.h - use Highway version instead
PiperOrigin-RevId: 640440379
2024-06-05 01:37:48 -07:00
Jan Wassenberg a939b5fc9f Update distortion.h to weighted average, add distortion_test.
More thorough checks in sfp_test and nuq_test.
nuq_test: use deterministic input generator.

PiperOrigin-RevId: 625602019
2024-04-17 01:44:19 -07:00
Jan Wassenberg a982ec1287 Move code to gemma/ so we can remove error-prone copybara: comments.
Also fix includes and Lint warnings.

PiperOrigin-RevId: 623127487
2024-04-09 04:45:42 -07:00
Jan Wassenberg 61e031fe98 Towards building tests without GUnit Refs #29
PiperOrigin-RevId: 618032987
2024-03-21 19:28:02 -07:00
Jan Wassenberg 24add61dd9 Fix SFP/NUQ for bf16 rounding in Highway
SFP: Avoid rounding twice, and more robust TestDot.
NUQ: also more robust SNR, minor touchups to header.

PiperOrigin-RevId: 618030096
2024-03-21 19:06:19 -07:00
Jan Wassenberg bb9b023502 Support Bazel builds. Fixes #16
Also fix nuq/sfp-inl: warning, cast, and disable SCALAR

PiperOrigin-RevId: 612704056
2024-03-04 22:07:25 -08:00
Austin Huang e29cd566cf initial commit 2024-02-21 03:31:22 +00:00