Jan Wassenberg
2ebbe4076f
1.03-1.08x decode speedup: precompute Rope theta, fuse
...
Split attention into functions, move into class.
Fuse Rope and MulBy, allow non-in-place version to avoid copy from q to KV.
Sink if() into MaybeLogitsSoftCap.
PiperOrigin-RevId: 661168418
2024-08-09 01:23:24 -07:00
Jan Wassenberg
a24eda8d02
Split matmul into matvec; add large matrix benchmark
...
Rename var names to row/col for more clarity.
Better estimate error tolerance via max abs col sum.
PiperOrigin-RevId: 657601791
2024-07-30 08:29:11 -07:00
Jan Wassenberg
6ea4232b2e
MatMul cleanup: Mat struct, simplify args.
...
Add large benchmark to test, use 4 threads, skip some targets.
Also use Traits::Name instead of typeid.
PiperOrigin-RevId: 657496185
2024-07-30 01:55:50 -07:00
Jan Wassenberg
f27683152c
1.05x prefill speedup: matvec -> matmul for !MHA
...
Also add C_stride and make shape normal non-template arguments.
PiperOrigin-RevId: 657285945
2024-07-29 12:18:06 -07:00
Jan Wassenberg
2721f54446
Add offset arg to MatMul, rename, Matmul for logits = ~1.1x decode speedup
...
PiperOrigin-RevId: 657167257
2024-07-29 05:34:26 -07:00
Jan Wassenberg
85cac13fb1
Split up ops.h into ops/ops-inl and matmul-inl
...
PiperOrigin-RevId: 654068303
2024-07-19 11:21:48 -07:00