Commit Graph

467 Commits

Author SHA1 Message Date
Jan Wassenberg 73c3627b67 Add tensor stats and output
tensor_info: add missing header
io: fix mode
weights.h: add layer_idx to LayerWeightsPtrs
PiperOrigin-RevId: 843531051
2025-12-11 22:52:46 -08:00
Martin Stolle 78deacc357 Make attention configurable on the command line.
PiperOrigin-RevId: 842760721
2025-12-10 09:34:06 -08:00
Martin Stolle 2441ff01bf internal change
PiperOrigin-RevId: 842749037
2025-12-10 09:01:15 -08:00
Martin Stolle 9689fc82f9 internal change
PiperOrigin-RevId: 842205671
2025-12-09 06:17:08 -08:00
Krzysztof Rymski 64d700cab5 Internal changes
PiperOrigin-RevId: 842194766
2025-12-09 05:42:03 -08:00
Martin Stolle 14a9ecf21d Factor out SumHeads
PiperOrigin-RevId: 842138081
2025-12-09 02:23:16 -08:00
Martin Stolle 1014ae9e2a Adding a simple test for GemmaAttention
PiperOrigin-RevId: 842135414
2025-12-09 02:13:03 -08:00
Martin Stolle b510ba2ab2 Improve clarity of indices II
Sorry, didn't see this one before.

PiperOrigin-RevId: 840218378
2025-12-04 06:33:33 -08:00
Martin Stolle 9348048885 Clean up toPtrs to delegate to toPtr
PiperOrigin-RevId: 840214969
2025-12-04 06:22:04 -08:00
Martin Stolle d2090fddf3 Improve clarity of indices
PiperOrigin-RevId: 839805634
2025-12-03 10:11:21 -08:00
Jan Wassenberg a084d33e41 Fix Gemma3 image: ensure A matrix is packed, preallocate
Also ignore -2 tokens

PiperOrigin-RevId: 838869988
2025-12-01 11:47:23 -08:00
Krzysztof Rymski 6e5e4123f1 Internal changes
PiperOrigin-RevId: 837775282
2025-11-28 02:37:06 -08:00
Krzysztof Rymski c153d5255b Internal changes
PiperOrigin-RevId: 837001762
2025-11-26 01:05:35 -08:00
Martin Stolle 8696f6dd17 Clarify indices
PiperOrigin-RevId: 836235539
2025-11-24 08:27:59 -08:00
Jan Wassenberg 37a25c9ffe Fix warning (signed vs unsigned)
PiperOrigin-RevId: 836106478
2025-11-24 00:51:17 -08:00
Charles Zhao 0e5f4cbf1b Implement Continus Batching.
(1) A function GenerateTWithContinuousBatching is added to use continuous batching when enabled.

(2) The ContinuousQBatch is added as a subclass of QBatch to manage prefill, insert, used-kv-cache-collection.

(3) Also expanded the unit test to more diverse cases.

PiperOrigin-RevId: 836090261
2025-11-23 23:54:02 -08:00
Martin Stolle 88a03b7ec4 Added access to softmax attention internals to regular attention
PiperOrigin-RevId: 835244205
2025-11-21 09:01:01 -08:00
Martin Stolle 5a500872b8 Internal change
PiperOrigin-RevId: 835115693
2025-11-21 01:17:45 -08:00
Martin Stolle 49d420aeaf Add some comments.
PiperOrigin-RevId: 834173319
2025-11-19 01:09:15 -08:00
The gemma.cpp Authors b8f6be72b1 Improves autodetection of Gemma3-1B.
Uses the key_norm and query_norm layers to disambiguate between the Gemma2-2B and Gemma3-1B models.
Since Gemma3-1B is not multimodal, ViT is not an effective disambiguator. KQ normalization is a structural disambiguator between gemma2 and gemma3.

PiperOrigin-RevId: 833213331
2025-11-17 01:12:50 -08:00
Jan Wassenberg 3e18db17f4 Avoid hard-coding kPatchSize. Thanks @Somet2mes for reporting. Fixes #762.
PiperOrigin-RevId: 829308896
2025-11-07 00:32:31 -08:00
Charles Zhao f8131339a7 Refactor for continous batching. This cl does not change the current behavior of the code. It only extract two functions that will later be called for adding continuous batching.
PiperOrigin-RevId: 829104661
2025-11-06 14:20:17 -08:00
Martin Stolle 35e9f9f05f Introduce attention implementation configurability.
PiperOrigin-RevId: 828971705
2025-11-06 08:43:41 -08:00
Jan Wassenberg 091b4567c9 Minor: ParallelismStrategy->Parallelism
PiperOrigin-RevId: 828936578
2025-11-06 06:56:10 -08:00
Jan Wassenberg a344a70c59 Change (old) attention behavior to disallow wraparound, enforced via assertion.
Shared kU64PerLine constant

PiperOrigin-RevId: 828072451
2025-11-04 11:52:40 -08:00
Charles Zhao 3a63a12624 Allow prefill only run by allowing max_prompt_size == seq_len
PiperOrigin-RevId: 827415258
2025-11-03 03:17:54 -08:00
Phil Culliton ab87807a4c Pre-compress query activations to BF16 before FlashAttention.
PiperOrigin-RevId: 826524997
2025-10-31 09:49:44 -07:00
Ray Smith 8a100c1e8d Added access to flash attention internals to TileFlashAttention4
PiperOrigin-RevId: 826011137
2025-10-30 06:50:05 -07:00
Phil Culliton 116cd6eff6 BF16 mixed-mode flash attention
PiperOrigin-RevId: 825433929
2025-10-29 01:48:28 -07:00
Jan Wassenberg 4bd465ffd3 Also update attention.h to type-erased query_norm_scale
PiperOrigin-RevId: 825014334
2025-10-28 06:48:33 -07:00
Jan Wassenberg 3cc0139ebb Fix excessive KC/MC from prior change
This could lead to stack overflow in B_storage.

Also do not require specific type for query_norm_scale,
update batch sizes for attention tensors,
more verbose Mat shape/type checks.

PiperOrigin-RevId: 824987689
2025-10-28 05:33:01 -07:00
Biruk Mammo 5a05857deb [Gemma.cpp] Allows non-owned arguments for attention methods.
* Adds and uses a new `AttentionActivationPtrs` that holds non-owning `MatPtrs`. Acts as a view into `AttentionActivations`.
* Updates `QBatch` to hold  non-owning `MatPtr`s to the kv caches.
* Enables the `MatPtrT` default constructor for simpler initializations.
* Pulls out and passes `LayerWeightsPtrs::query_norm_scale` directly. While `LayerWeightsPtrs` already held non-owning `MatPtr`s, this change avoids the need to find and construct several empty weight tensors just to construct one `query_norm_scale` tensor.

PiperOrigin-RevId: 824584177
2025-10-27 10:43:25 -07:00
Theotime Combes 1bdde1af3c Add config flag for global timescale & rely on config to deduce wrapping
PiperOrigin-RevId: 823512377
2025-10-24 06:54:56 -07:00
Jan Wassenberg 3ed403e287 Major cleanup of profiler zones, add Caller annotation for all pool.Run
Pass ThreadingContext instead of Pools/Profiler individually, for access to Zones
Add GCPP_ZONE helper
Add Caller argument to pool.Run to enable new stats
Remove most direct dependencies on ThreadPool, prefer ParallelFor

PiperOrigin-RevId: 822934530
2025-10-23 01:54:24 -07:00
Phil Culliton 503aaddd65 Add 8-bit integer quantization (I8Stream) to Gemma.cpp.
PiperOrigin-RevId: 819787856
2025-10-15 09:25:20 -07:00
Ray Smith ee18916abf Removed the PROFILER_ZONE from the most highly called functions to reduce the overhead.
PiperOrigin-RevId: 819739402
2025-10-15 07:10:04 -07:00
Ray Smith fb6fa793f4 Added a global (to gemma) zones list to enable most call sites to PROFILER_ZONE3 to avoid the sychronization required for the static const initialization of the zone handle.
Improved flash_attention to enable profiling using the new zones.

PiperOrigin-RevId: 819235421
2025-10-14 08:30:58 -07:00
Jan Wassenberg 035273c184 tune pool kSpin mode in threading_context
Previously, this happened concurrently with the matmul autotune, which could lead to incorrect outcomes.

threading: de-singleton Pinning (no longer stores affinity); pass PoolWorkerMapping; fix Pool dtor order
Also enable SPR target (Zen4 is AMD-only),
update Highway version for renamed Thread()->GlobalIdx().
PiperOrigin-RevId: 816223017
2025-10-07 08:36:26 -07:00
Ray Smith 684a0444e9 Reduced parallelism for TransposeQ, making each thread read and write within its own cache lines
PiperOrigin-RevId: 814241032
2025-10-02 08:15:16 -07:00
Ray Smith 14244664c8 Avoid transposing Q when it isn't needed
PiperOrigin-RevId: 814187984
2025-10-02 05:16:35 -07:00
Jan Wassenberg fe5a39990e Improve FlashAttention threading:
kFlat for RMSNorm (hierarchical is excessive),
profiler zone naming improvements.

PiperOrigin-RevId: 814144012
2025-10-02 02:37:05 -07:00
Ray Smith 6098a022b3 Increased parallelism for RMSNormAndPositionalEncoding
PiperOrigin-RevId: 813738994
2025-10-01 07:11:14 -07:00
Ray Smith 2f6cbde8ff Added a smaller tile size to flash attention for smaller batch sizes
PiperOrigin-RevId: 813226193
2025-09-30 05:49:20 -07:00
Ray Smith 4974f24832 Fixed bug with softcap in single flash attention
PiperOrigin-RevId: 813164938
2025-09-30 02:17:58 -07:00
Nitin Gangahar 667a3f117a Utilize multiple cores to read weight batches.
PiperOrigin-RevId: 811893059
2025-09-26 11:28:33 -07:00
Charles Zhao 4f0c633248 (1) Added QueryResultAndMetrics and BatchQueryModelWithMetrics to also return TimingInfo besides query results.
PiperOrigin-RevId: 810634261
2025-09-23 17:02:29 -07:00
Jan Wassenberg fac8aac4cb Internal change
PiperOrigin-RevId: 809975026
2025-09-22 05:37:03 -07:00
Jan Wassenberg 501fdf000e Remove no longer used MatVec
PiperOrigin-RevId: 809059409
2025-09-19 09:03:22 -07:00
Jan Wassenberg f3bc1c17da 1.03x speedup: fused FFN
matmul-inl: support CView=StridedView or RowPtrs; rename to C_MC_NC
matmul.cc: Allow 1 more rep for MC/NC to allow half-sized tiles, which helps.
PiperOrigin-RevId: 807291701
2025-09-15 10:26:37 -07:00
Ray Smith c9b8479f7d Added zero-initialization to att_out.
Re-enabled flash attention when HWY_NATIVE_DOT_BF16 is not available.

PiperOrigin-RevId: 806284756
2025-09-12 07:48:23 -07:00
Jan Wassenberg 2695aab5d2 Temporarily disable flash pending msan fix
PiperOrigin-RevId: 805350234
2025-09-10 07:25:41 -07:00
Jan Wassenberg ba6131311a Fix gemma_batch_bench for flash attention
q_T rows do not change.
Also repeat prefill to reflect perf after autotuning.

PiperOrigin-RevId: 805319377
2025-09-10 05:32:34 -07:00
Ray Smith f10ac41a20 Added flash attention, with both a single-q function, and a register-tiled function.
The register-tiled version achieves a speed-up by a factor of about 9.7 over the previous attention function on an AVX3-enabled machine.

PiperOrigin-RevId: 804913784
2025-09-09 08:05:26 -07:00
Jan Wassenberg 461a9c7d1b Matmul refactoring towards fusion
MMLoops: move dispatch code out, use overloads
split build target into matmul_env (for MatMulEnv/MMOptions)
weights: no longer call BindB
Fix potential out of bounds in gemma_batch_bench
PiperOrigin-RevId: 804895985
2025-09-09 07:13:38 -07:00
Jan Wassenberg a5ab99e4ba Memory use reduction: smaller/single MMStorage
PiperOrigin-RevId: 804865029
2025-09-09 05:32:46 -07:00
Jan Wassenberg 6e52a835c6 Faster startup on tsan: use hierarchical parallelism for BF16 conversion
Also re-enable profiler zones

PiperOrigin-RevId: 804273899
2025-09-07 22:50:31 -07:00
Jan Wassenberg cbe24eac51 1.15x speedup: parallel sampling, enabled by new RNG
Also pass pos to SampleFunc, for seeding the RNG.

PiperOrigin-RevId: 803453518
2025-09-05 07:24:02 -07:00
Jan Wassenberg 2b4c16e243 Remove Griffin support
Also add IsObsolete helper

PiperOrigin-RevId: 803376921
2025-09-05 02:35:40 -07:00
Jan Wassenberg 56186193c1 Replace mt19937 with new generator to enable parallel sampling
Split it into immutable AesCtrEngine and RngStream
Also add RowSpan and Logits span

PiperOrigin-RevId: 803336423
2025-09-04 23:49:10 -07:00
Jan Wassenberg 5d1693e806 Internal change
PiperOrigin-RevId: 803083229
2025-09-04 10:31:20 -07:00
Jan Wassenberg 4be4799727 Remove kMaxPackages and per-package-related code
matmul: remove kMaxClusters, dynamic allocation
PiperOrigin-RevId: 802950348
2025-09-04 03:33:12 -07:00
Jan Wassenberg 7263ab8445 MatMul simplification, threading strategy improvements
remove MatMul f32 special case (smaller code),
types: Add u32/u64 for use by Activations
move renamed ParallelismStrategy to threading_context so can pass ctx
ensure worker index is unique across clusters
matmul.h: const member functions for renamed policy classes (easier to call)
PiperOrigin-RevId: 802848086
2025-09-03 21:45:07 -07:00
Jan Wassenberg b7b3d353db Simplify MatMul: remove F32 special case (build time)
Also move kMaxM into separate kMaxBatchSize

PiperOrigin-RevId: 802086590
2025-09-02 04:29:21 -07:00
Jan Wassenberg 1e3c853e80 Add ParallelFor wrapper function and one new mode
Move ParallelismType from matmul.h to threading.h
Replace SmallParallelFor with ParallelFor and the new mode

PiperOrigin-RevId: 802038452
2025-09-02 01:40:09 -07:00
Jan Wassenberg 229bd078a1 1.29x speedup: bf16 C1/C2. Extend most ops to any type, expand test coverage.
Also increase dot_test.cc range for Zen4, and matmul_test tolerance (failing in some configs)

PiperOrigin-RevId: 801789922
2025-09-01 06:34:04 -07:00
Jan Wassenberg 0ae8646731 Fix remainder handling for Paligemma
No longer attempt to skip the remainder handling because B might also be a non-padded view.

PiperOrigin-RevId: 800890805
2025-08-29 07:25:52 -07:00
Marie White 973e284ed6 Refactor Matmul to use a policy class for parallelization.
PiperOrigin-RevId: 800864489
2025-08-29 05:40:39 -07:00
Jan Wassenberg 6c39a2dea4 1.01x speedup: More bf16 activations to reduce DecompressA.
Also move observer call into function, format gemma_args.

PiperOrigin-RevId: 800827400
2025-08-29 03:19:01 -07:00
Jan Wassenberg 7288891439 Remove F64 partial storage in matmul.
Also remove no longer used kMaxN; row_ptrs only used for C

PiperOrigin-RevId: 800774757
2025-08-29 00:12:08 -07:00
Jan Wassenberg 98ddc166db Expand ThreadingContext comments
PiperOrigin-RevId: 800479954
2025-08-28 08:32:10 -07:00
Marie White 6128e758ff Change ffw_out from B16 to F32.
PiperOrigin-RevId: 800330411
2025-08-28 00:01:39 -07:00
Jan Wassenberg 5411fd846d Minor: batched NotifyGenerate, fix comment/dep
PiperOrigin-RevId: 799889802
2025-08-26 23:33:17 -07:00
Jan Wassenberg 86afd53076 1.04x speedup: Parallelize SoftCap
Also require opt-in constexpr flag for observer callbacks, update zones

PiperOrigin-RevId: 799655163
2025-08-26 11:55:20 -07:00
Jan Wassenberg ed2f0bd1b0 Fix pos assertions, refs #665
Ensure the streaming func pos matches the number of calls.
Add two arguments that control pos+1 and pos+=1 behavior.
Also cleanup/add comments.
run: use batch_stream_func, add assert, higher verbosity for MM autotune output
PiperOrigin-RevId: 799511163
2025-08-26 04:50:40 -07:00
Jan Wassenberg 9bf0fe4e37 Internal change
PiperOrigin-RevId: 799509375
2025-08-26 04:44:08 -07:00
Jan Wassenberg d3a5ddf657 Merge pull request #663 from junjihashimoto:feature/api-server
PiperOrigin-RevId: 797731089
2025-08-24 11:57:05 +02:00
Rhett Stucki 73f1140dca Fix an off-by-one error after StreamAndUpdateEOS() to remove the MSAN warning about reading an uninitialized variable in the kv_cache.
The logic for choosing whether or not to attend to the last token during prefill wasn't completely consistent with StreamAndUpdateEOS(), causing an off-by-one error that prevented the kv_cache from being fully populated.

PiperOrigin-RevId: 797614310
2025-08-20 22:59:58 -07:00
Junji Hashimoto 41321611fd feature: add API server and client with Google protocol 2025-08-21 11:32:48 +09:00
Phil Culliton 78573b6718 Internal change. Add deduction for 270M.
PiperOrigin-RevId: 795041810
2025-08-14 08:04:38 -07:00
Phil Culliton d044801c1d Internal change
PiperOrigin-RevId: 794620076
2025-08-13 09:47:45 -07:00
Jan Wassenberg 71406cf6d0 More profiler interface fixes: hwy:: plus avoid ADD_ZONE
PiperOrigin-RevId: 794493165
2025-08-13 03:15:48 -07:00
Jan Wassenberg faa4102992 (Resubmit) Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions.
Used ThreadingContext arg instead of NestedPools.
Use new PROFILER_ZONE3.

PiperOrigin-RevId: 794461159
2025-08-13 01:38:24 -07:00
The gemma.cpp Authors a2d9133f7d Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions.
Used ThreadingContext arg instead of NestedPools.
Use new PROFILER_ZONE3.

PiperOrigin-RevId: 793865287
2025-08-11 17:51:38 -07:00
Jan Wassenberg 4cbf63e6f0 Prepare profiler annotations for new API
Pass hwy::Profiler& to low-level functions.
Used ThreadingContext arg instead of NestedPools.
Use new PROFILER_ZONE3.

PiperOrigin-RevId: 793821255
2025-08-11 15:34:52 -07:00
Jan Wassenberg 4e062d68f7 Update BlobWriter comments, WriteAll->Finalize
PiperOrigin-RevId: 790792133
2025-08-04 10:01:38 -07:00
Jan Wassenberg 701841897b Default to disabling per-socket parallelization
weights: default to Read for small-batch (only look at qbatch, not the larger prefill tbatch)
PiperOrigin-RevId: 790787643
2025-08-04 09:49:14 -07:00
Jan Wassenberg 799c264df3 Pre-tune thread pool before matmul
Also improve profiler annotations - remove near-zero ones and add more for startup

PiperOrigin-RevId: 789352414
2025-07-31 08:45:26 -07:00
Charles Zhao 50ee1a3e92 Write SBS progressively.
(1) Directly write to file in BlobWriter::Add and destruct the MatOwner to release the rams.

(2) Write a fake header to indicate this is V2, and write correct header and directory at the end of the file.

(3) Tested on loading sbs written the old way, and new way, both worked.

PiperOrigin-RevId: 789306837
2025-07-31 06:05:38 -07:00
Jan Wassenberg 8715eda512 Improved layer idx parsing
PiperOrigin-RevId: 788868522
2025-07-30 05:49:45 -07:00
Jan Wassenberg d831ddce5b Fix file mapping: was letting the smart pointer go out of scope
Also save+print the IO mode used.

PiperOrigin-RevId: 788848165
2025-07-30 04:30:10 -07:00
Jan Wassenberg d22ba2ac96 Update layer index parsing and allow tokenizer override
PiperOrigin-RevId: 788797948
2025-07-30 01:22:34 -07:00
Jan Wassenberg d1638587f0 1.14x batch decode speedup: parallelize RMSNorm ops
Activations was over-parallelized, use single pool instead.
Also improve profiler zone annotations,
pass through worker args (for tracking concurrency), now non-optional.

PiperOrigin-RevId: 788790976
2025-07-30 00:55:45 -07:00
Jan Wassenberg ac0d751d20 Rename GetModelConfig->Config
PiperOrigin-RevId: 788506480
2025-07-29 10:18:12 -07:00
Jeremiah Harmsen 33fabd4ed1 Internal change.
PiperOrigin-RevId: 788463042
2025-07-29 08:21:29 -07:00
Jan Wassenberg e76e29ce11 De-singleton ThreadingContext so callers can pass in their own
weights.cc: fix BindB argument for bf16 tensors
threading_test: enable autotune
PiperOrigin-RevId: 785763618
2025-07-22 02:08:46 -07:00
Jan Wassenberg 5474146129 Back to f32 kv_cache, but via typedef
PiperOrigin-RevId: 785422614
2025-07-21 07:05:35 -07:00
Jan Wassenberg 56c9196eb6 Add blob_path to config deduction message
PiperOrigin-RevId: 782188689
2025-07-11 18:58:56 -07:00
Jan Wassenberg 4bc44d5678 Minor: ModelWeightsPtrs -> WeightsPtrs
PiperOrigin-RevId: 781954533
2025-07-11 06:11:51 -07:00
Jan Wassenberg a04cc287b2 Move MatMulEnv out of Gemma to enable concurrent calls
Also update benchmark_helper config print: add profiler, remove free mem

PiperOrigin-RevId: 774662974
2025-06-23 01:20:09 -07:00
Jan Wassenberg 0f70f285e0 1.1x prefill and decode speedup (attention/activations)
Optimizations
- Better load-balancing in attention threading
(Previously, clusters were limited by #heads)
- Add MulByConstTo to avoid zero-init
- Parallel activations

Cleanup
- Prepare for RowPtr in A or B
- Pass through thread_id to ops
- Avoid warning in bench_matmul

PiperOrigin-RevId: 773723423
2025-06-20 08:59:53 -07:00