Commit Graph

7 Commits

Author SHA1 Message Date
Ray Smith 8a100c1e8d Added access to flash attention internals to TileFlashAttention4
PiperOrigin-RevId: 826011137
2025-10-30 06:50:05 -07:00
Jan Wassenberg 3cc0139ebb Fix excessive KC/MC from prior change
This could lead to stack overflow in B_storage.

Also do not require specific type for query_norm_scale,
update batch sizes for attention tensors,
more verbose Mat shape/type checks.

PiperOrigin-RevId: 824987689
2025-10-28 05:33:01 -07:00
Biruk Mammo 5a05857deb [Gemma.cpp] Allows non-owned arguments for attention methods.
* Adds and uses a new `AttentionActivationPtrs` that holds non-owning `MatPtrs`. Acts as a view into `AttentionActivations`.
* Updates `QBatch` to hold  non-owning `MatPtr`s to the kv caches.
* Enables the `MatPtrT` default constructor for simpler initializations.
* Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids the need to find and construct several empty weight tensors just to construct one `query_norm_scale` tensor.

PiperOrigin-RevId: 824584177
2025-10-27 10:43:25 -07:00
Jan Wassenberg 3ed403e287 Major cleanup of profiler zones, add Caller annotation for all pool.Run
Pass ThreadingContext instead of Pools/Profiler individually, for access to Zones
Add GCPP_ZONE helper
Add Caller argument to pool.Run to enable new stats
Remove most direct dependencies on ThreadPool, prefer ParallelFor

PiperOrigin-RevId: 822934530
2025-10-23 01:54:24 -07:00
Ray Smith fb6fa793f4 Added a global (to gemma) zones list to enable most call sites to PROFILER_ZONE3 to avoid the sychronization required for the static const initialization of the zone handle.
Improved flash_attention to enable profiling using the new zones.

PiperOrigin-RevId: 819235421
2025-10-14 08:30:58 -07:00
Ray Smith 2f6cbde8ff Added a smaller tile size to flash attention for smaller batch sizes
PiperOrigin-RevId: 813226193
2025-09-30 05:49:20 -07:00
Ray Smith f10ac41a20 Added flash attention, with both a single-q function, and a register-tiled function.
The register-tiled version achieves a speed-up by a factor of about 9.7 over the previous attention function on an AVX3-enabled machine.

PiperOrigin-RevId: 804913784
2025-09-09 08:05:26 -07:00