Commit Graph

23 Commits

Author SHA1 Message Date
Jan Wassenberg 56186193c1 Replace mt19937 with new generator to enable parallel sampling
Split it into immutable AesCtrEngine and RngStream
Also add RowSpan and Logits span

PiperOrigin-RevId: 803336423
2025-09-04 23:49:10 -07:00
Jan Wassenberg 5d1693e806 Internal change
PiperOrigin-RevId: 803083229
2025-09-04 10:31:20 -07:00
Phil Culliton d044801c1d Internal change
PiperOrigin-RevId: 794620076
2025-08-13 09:47:45 -07:00
Jan Wassenberg 33fbac0880 Exporter updates/fixes
PiperOrigin-RevId: 791046073
2025-08-04 22:36:33 -07:00
Jan Wassenberg ac0d751d20 Rename GetModelConfig->Config
PiperOrigin-RevId: 788506480
2025-07-29 10:18:12 -07:00
Jan Wassenberg e76e29ce11 De-singleton ThreadingContext so callers can pass in their own
weights.cc: fix BindB argument for bf16 tensors
threading_test: enable autotune
PiperOrigin-RevId: 785763618
2025-07-22 02:08:46 -07:00
Jan Wassenberg a04cc287b2 Move MatMulEnv out of Gemma to enable concurrent calls
Also update benchmark_helper config print: add profiler, remove free mem

PiperOrigin-RevId: 774662974
2025-06-23 01:20:09 -07:00
Jan Wassenberg e5c81f64a1 Major refactor: clarify query_idx (global) vs qi. Refs #607
Fix missing pos increment for last prefill and check that in gemma_test.
Thanks to @ufownl for pointing this out.

Change argument lists to QBatch with accessors.
Increase default seq_len to 8k.

PiperOrigin-RevId: 771937385
2025-06-16 02:42:02 -07:00
Jan Wassenberg c027a45a2e MatPtr-ify KV, shared div_seq_len, --seq_len flag
PiperOrigin-RevId: 770194455
2025-06-11 09:49:38 -07:00
Daniel Keysers 9f74a1a098 Fix a problem in run_example.py
PiperOrigin-RevId: 767017932
2025-06-04 00:42:57 -07:00
Jan Wassenberg 3890eb5412 Remove backprop/
Also remove MatPtrT::Packed(); use PackedScale1 instead where const, or Row(0).

PiperOrigin-RevId: 764243198
2025-05-28 07:01:17 -07:00
Jan Wassenberg 2038dfd9cc Minor: rename compression/shared -> types.h
PiperOrigin-RevId: 758199851
2025-05-13 06:53:21 -07:00
Jan Wassenberg 45ad847a41 Replace RowVectorBatch with MatStorageT
KVCache: add ctor required for MatStorageT, remove Create; bf_pre_ffw_rms_out -> pre_ffw_rms_out
optimize_test: larger vocab_size requires more steps
shared.h: Remove unused u128 type
correctly set Activation matrix rows, avoid passing as arg
ops: pass Mat instead of pointers/sizes; vectorize LayerNorm; support any weight type
mat: add OverrideRows, used by SetBatchSize
PiperOrigin-RevId: 757790736
2025-05-12 09:16:12 -07:00
Jan Wassenberg cf7dd80c17 Minor: mark command line flags as required
PiperOrigin-RevId: 757775369
2025-05-12 08:30:44 -07:00
Jan Wassenberg 252a4e955e Remove support for Gemma 1 and PaliGemma 1 models, superseded by (Pali)Gemma 2.
PiperOrigin-RevId: 756671308
2025-05-09 02:17:27 -07:00
Jan Wassenberg 275135d7e8 Rename-only: remove Allocator2 etc suffixes now that refactoring is complete
PiperOrigin-RevId: 755397220
2025-05-06 09:12:43 -07:00
Jan Wassenberg 8d0882b966 Huge refactor of weight handling and model loading.
Weight handling:
- new ModelStore2 supports both pre-2025 multi-file and single-file formats
- simpler ForEachTensor with TensorArgs
- tensors are constructed with their full suffixed name

I/O:
- support mmap and stride
- Simplified SbsWriter, single insert(); add SbsReader

Misc:
- kMockTokenizer: allow creating with unavailable tokenizer
- configs.h: Simpler enum validity checks via kSentinel
- matmul.h: remove unused enable_bind (now in allocator.h)
- tensor_info: single TensorInfoRegistry class, rename from tensor_index.h

Frontends:
- Replace Allocate/CreateGemma with ctor(LoaderArgs, MatMulEnv&)
- Deduce model/weight type, remove --model and parsing
- Replace most common.h includes with configs.h
- Remove --compressed_weights, use --weights instead
- Remove ModelInfo, replaced by ModelConfig.

Backprop:
- Reduce max loss, remove backward_scalar_test (timeout)
- Update thresholds because new RandInit changes rng eval order and thus numerics
PiperOrigin-RevId: 755317484
2025-05-06 04:44:21 -07:00
Jan Wassenberg 87a658b1c6 Minor cleanup, on-demand NUQ buffer allocation
threading_context: add profiler
compress-inl: add constexpr, on-demand alloc NUQ buffer
gemma_py: model->gemma
Move ScaleWeights to compress.cc
Move PromptWrapping to configs.h
PiperOrigin-RevId: 748347896
2025-04-16 10:49:43 -07:00
Jan Wassenberg 2e722f14f1 Add mmap support (not yet used)
Also: const-correct ArgsBase,
add assert to mat.h checking element_bytes_
BUILD deps update (:shared provides shared.h, not :sfp)
PiperOrigin-RevId: 746073312
2025-04-10 10:03:40 -07:00
Jan Wassenberg 8532da47f7 Major refactor of allocator/args:
use new ThreadingContext2 instead of monostate/init in each frontend
Add ThreadingArgs(replaces AppArgs)

backprop: use Packed() accessor and MakePacked factory and row-based access to allow for stride
compress_weights: remove, moving to py-only exporter instead

Move MatPtr to mat.h and revise interface:
- Generic MatOwner
- rename accessors to Packed*
- support stride/row accessors, fix RowPtr stride

Add TypeBits(Type)
Move GenerateMat to test_util-inl for sharing between matmul test/bench
Move internal init to gemma.cc to avoid duplication
Rename GemmaEnv model_ to gemma_ for disambiguating vs upcoming ModelStorage
Remove --compressed_weights, use --weights instead.
tensor_index: add ExtentsFromInfo and TensorIndexLLM/Img
Allocator: use normal unique_ptr for AllocBytes so users can call directly
threading: use -> because AlignedPtr no longer assumes arrays
PiperOrigin-RevId: 745918637
2025-04-10 01:29:54 -07:00
Daniel Keysers f173aa776e Add conversion tool for HF safetensors to gemma.cpp for PaliGemma.
PiperOrigin-RevId: 725990158
2025-02-12 03:47:43 -08:00
Oleh Prypin 82ca526c0c Remove `srcs_version` and `python_version` attributes, as they already default to `"PY3"`
PiperOrigin-RevId: 724122259
2025-02-06 16:51:11 -08:00
Daniel Keysers 7af2e70321 Add python wrappers for configs and inference.
Enable building compression/python/compression_test using bazel.
Add default image path for image_test and paligemma_test.

PiperOrigin-RevId: 720583438
2025-01-28 08:22:03 -08:00