Commit Graph

12 Commits

Author SHA1 Message Date
Krzysztof Rymski df162ead7c Implementation of tiled attention with bf16 and circular buffers which reduces memory requirements by 4x on longer context on gemma models.
It also supports better parallelism for small batch sizes / small models.
It also is able to utilize VDPBF16PS for nice 2x improvement on avx512

PiperOrigin-RevId: 874517319
2026-02-24 03:26:49 -08:00
Ray Smith 76d7951242 Added wheat_from_chaff_test to test the ability of a model to find a needle in a haystack of data.
Replaced flag with attention_impl to control which attention to run.

PiperOrigin-RevId: 869694868
2026-02-13 06:05:30 -08:00
Balazs Racz 596bdfe5af Separate monolithic gemma_lib library into more specific cc_library targets.
Creates new cc_library targets for :attention, :tensor_stats and :activations. Eliminates cyclic dependencies between these libraries.

PiperOrigin-RevId: 845690136
2025-12-17 03:31:16 -08:00
Krzysztof Rymski 6e5e4123f1 Internal changes
PiperOrigin-RevId: 837775282
2025-11-28 02:37:06 -08:00
Phil Culliton ab87807a4c Pre-compress query activations to BF16 before FlashAttention.
PiperOrigin-RevId: 826524997
2025-10-31 09:49:44 -07:00
Ray Smith 8a100c1e8d Added access to flash attention internals to TileFlashAttention4
PiperOrigin-RevId: 826011137
2025-10-30 06:50:05 -07:00
Jan Wassenberg 3cc0139ebb Fix excessive KC/MC from prior change
This could lead to stack overflow in B_storage.

Also do not require specific type for query_norm_scale,
update batch sizes for attention tensors,
more verbose Mat shape/type checks.

PiperOrigin-RevId: 824987689
2025-10-28 05:33:01 -07:00
Biruk Mammo 5a05857deb [Gemma.cpp] Allows non-owned arguments for attention methods.
* Adds and uses a new `AttentionActivationPtrs` that holds non-owning `MatPtrs`. Acts as a view into `AttentionActivations`.
* Updates `QBatch` to hold  non-owning `MatPtr`s to the kv caches.
* Enables the `MatPtrT` default constructor for simpler initializations.
* Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids the need to find and construct several empty weight tensors just to construct one `query_norm_scale` tensor.

PiperOrigin-RevId: 824584177
2025-10-27 10:43:25 -07:00
Jan Wassenberg 3ed403e287 Major cleanup of profiler zones, add Caller annotation for all pool.Run
Pass ThreadingContext instead of Pools/Profiler individually, for access to Zones
Add GCPP_ZONE helper
Add Caller argument to pool.Run to enable new stats
Remove most direct dependencies on ThreadPool, prefer ParallelFor

PiperOrigin-RevId: 822934530
2025-10-23 01:54:24 -07:00
Ray Smith fb6fa793f4 Added a global (to gemma) zones list to enable most call sites to PROFILER_ZONE3 to avoid the sychronization required for the static const initialization of the zone handle.
Improved flash_attention to enable profiling using the new zones.

PiperOrigin-RevId: 819235421
2025-10-14 08:30:58 -07:00
Ray Smith 2f6cbde8ff Added a smaller tile size to flash attention for smaller batch sizes
PiperOrigin-RevId: 813226193
2025-09-30 05:49:20 -07:00
Ray Smith f10ac41a20 Added flash attention, with both a single-q function, and a register-tiled function.
The register-tiled version achieves a speed-up by a factor of about 9.7 over the previous attention function on an AVX3-enabled machine.

PiperOrigin-RevId: 804913784
2025-09-09 08:05:26 -07:00