Commit Graph

9 Commits

Author SHA1 Message Date
Jan Wassenberg cb188d4a0e Fix RowT issue and improve Griffin (currently still broken)
Use type-safe MatPtrT via dynamic_cast, avoid/remove unsafe RowT
activations: Griffin tensors are now padded
Griffin: add batching support, fix conv1d_cache allocation
weights: bundle to TensorToRead, add kNoPad flag, fix SplitW1
const-correct fix for ForEachTensor
blob_store: move BlobIO2 to .cc and rename BlobIO
PiperOrigin-RevId: 760610094
2025-05-19 07:02:10 -07:00
Jan Wassenberg d538a6d6c6 Cleanup: remove unused kCyclic, remove 2 suffix
Also remove now unused allocator arg and fix warnings (cast, struct/class mismatch)

PiperOrigin-RevId: 758098495
2025-05-13 01:06:41 -07:00
Jan Wassenberg 45ad847a41 Replace RowVectorBatch with MatStorageT
KVCache: add ctor required for MatStorageT, remove Create; bf_pre_ffw_rms_out -> pre_ffw_rms_out
optimize_test: larger vocab_size requires more steps
shared.h: Remove unused u128 type
correctly set Activation matrix rows, avoid passing as arg
ops: pass Mat instead of pointers/sizes; vectorize LayerNorm; support any weight type
mat: add OverrideRows, used by SetBatchSize
PiperOrigin-RevId: 757790736
2025-05-12 09:16:12 -07:00
Jan Wassenberg 160a5824fb Cleanup: include fixes/comments, fix leak, vector reserve
Also remove unused RowSpan
configs.cc: Assign prompt wrapping to ModelConfig
configs.h: simplify EnumValid via sentinel

PiperOrigin-RevId: 750278497
2025-04-22 12:01:46 -07:00
Apoorv Reddy 780e376023 Add KVCache.DeepCopy() . Will be useful for implementing sampling functionality like beam sampling, parallel sampling, CoT Decoding (à la https://arxiv.org/abs/2402.10200)
PiperOrigin-RevId: 725156316
2025-02-10 04:10:29 -08:00
Daniel Keysers e54d9cbddd Fix Griffin model:
- use HalfRope position encodings
- zero-initialize the caches for each Generate at position 0

The lack of the latter made the tests in gemma_test dependent on each other.

PiperOrigin-RevId: 694509054
2024-11-08 08:30:53 -08:00
Ray Smith 0d68555f87 Eliminated TConfig.
Changed CompressedLayer and CompressedWeights to be constructed with an instance of a LayerConfig and WeightsConfig respectively.
Added CompressedModel to remove ByteStorageT and get rid of most of the type casting, as well as allowing the default destructor to be used and work properly.
Adjusted WeightsWrapper and ForwardLayer etc to match.
The only remaining template arg is the weight type.
This enables all the instantiations to be deleted, apart from one per type.
It also enables (but not yet done) the config to be stored in the blob file instead of having to be specified separately.
Reduces the size of the gemma_lib and weights shared libraries by a factor of 4.3 and 3.2 respectively.

PiperOrigin-RevId: 686870060
2024-10-17 05:04:22 -07:00
Jan Wassenberg aaf51898b6 Major revamp #2 of Prefill: fix token order, parallel for multi-query
- Allocate only the required KV caches and activation batch size
- Add flags for batch sizes
- Const-correct interface: Span of const int.
- Also clean up the KVCache arg to a span.
- Move kPrefillBatchSize into RuntimeConfig and remove related global constants.

PiperOrigin-RevId: 655893197
2024-07-25 03:28:55 -07:00
Jan Wassenberg 09a7e75ead Prep for sharding gemma.cc: split into kv_cache, tokenizer.
Move activations.h to backprop/ to make space for another activations.h.

PiperOrigin-RevId: 648744500
2024-07-02 09:31:06 -07:00