Jan Wassenberg
4033ed9e78
Avoid duplication of RMSNorm, support all activation/weight types
Add test for RMSNorm
Rename VectorizedRopeAndMulBy -> RopeAndMulBy
Move test_util to util/
PiperOrigin-RevId: 668332927
2024-08-28 01:26:55 -07:00
Jan Wassenberg
b6d0ca8a14
Minor followup: remainder handling is a single iteration
Also add profiler annotations.
PiperOrigin-RevId: 667883774
2024-08-27 01:19:44 -07:00
Apoorv Reddy
48d0801fb0
Vectorize Rope for qkv dim not evenly divisible by number of lanes.
PiperOrigin-RevId: 665776602
2024-08-21 02:22:22 -07:00
Apoorv Reddy
c6eb3b6f0d
VectorizedRopeAndMulBy.
~8x reduction (tested on a few prompts) in Rope.
~3.8% prefill latency improvement.
~2.6% decode latency improvement.
PiperOrigin-RevId: 664650108
2024-08-18 23:17:01 -07:00
Jan Wassenberg
2ebbe4076f
1.03-1.08x decode speedup: precompute Rope theta, fuse
Split attention into functions, move into class.
Fuse Rope and MulBy, allow non-in-place version to avoid copy from q to KV.
Sink if() into MaybeLogitsSoftCap.
PiperOrigin-RevId: 661168418
2024-08-09 01:23:24 -07:00
Jan Wassenberg
85cac13fb1
Split up ops.h into ops/ops-inl and matmul-inl
PiperOrigin-RevId: 654068303
2024-07-19 11:21:48 -07:00