Zoltan Szabadka
465998d25a
Add support for custom sampling function to runtime config.
...
With this addition, the ComputeCrossEntropy function can be moved
to its own library, because we can now compute it using only the
public API functions from gemma.h.
2024-06-07 11:45:07 +00:00
Copybara-Service
f7ac7092d6
Merge pull request #212 from szabadka:adam2
...
PiperOrigin-RevId: 641182573
2024-06-07 02:25:18 -07:00
Jan Wassenberg
e3f4374e81
Fix fix for weight type define, refs #198
...
GEMMA_WEIGHT_T is indeed the correct flag for the C++ compiler,
but the readme references CMake, and there the correct flag name is WEIGHT_TYPE.
PiperOrigin-RevId: 641170380
2024-06-07 01:32:25 -07:00
Jan Wassenberg
8dc0e5ea83
Fix reference to GEMMA_WEIGHT_T. Refs #198
...
PiperOrigin-RevId: 641161403
2024-06-07 00:54:30 -07:00
Zoltan Szabadka
c004799cdc
Add Adam optimizer.
...
Drive-by: Fix compilation errors and tests for backprop functions.
2024-06-06 18:41:36 +00:00
Jan Wassenberg
12707ade80
Toward only using compressed weights:
...
CompressedLayer should all be f32 when weights are f32.
PiperOrigin-RevId: 640954519
2024-06-06 11:00:23 -07:00
Paul Chang
6c0be20fa6
Fix Softmax on SVE
...
PiperOrigin-RevId: 640947138
2024-06-06 10:39:30 -07:00
The gemma.cpp Authors
39d4115717
Implement mixed mode matmul: f32 * bf16
...
PiperOrigin-RevId: 640940962
2024-06-06 10:21:46 -07:00
Jan Wassenberg
57c2cd8b52
Simplifications: remove GemmaInterface and GemmaImpl
...
Split common and weights into separate lib
Remove common-inl (does not have to be SIMD code), activations.cc
Centralize switch(Model) to avoid duplication
Move CompressWeightsT to compress_weights.cc
Move LoadWeights to weights.cc
PiperOrigin-RevId: 640869202
2024-06-06 05:54:21 -07:00
Jan Wassenberg
5c3e5f7038
Remove no longer required stats.h - use Highway version instead
...
PiperOrigin-RevId: 640440379
2024-06-05 01:37:48 -07:00
Paul Chang
175e389c3c
revert back to HWY_ASSERT for lane constraints, qualify hn::Add
...
PiperOrigin-RevId: 640193239
2024-06-04 10:10:18 -07:00
Phil Culliton
e71d82ead9
Fix for GenerateZeroMat call in TestTiledMatMul
...
PiperOrigin-RevId: 640180868
2024-06-04 09:32:23 -07:00
Zelalem Aweke
9e213b3d96
Use system topology to pin threads across clusters.
...
PiperOrigin-RevId: 640151974
2024-06-04 07:50:32 -07:00
Jan Wassenberg
4f9155d8c6
Add bf16 matmul support, update naming+test
...
Avoid int32, which can easily overflow for large matrices.
Also fix IDE warning in sfp-inl.
PiperOrigin-RevId: 640149845
2024-06-04 07:41:46 -07:00
Copybara-Service
25d9c8ff30
Merge pull request #203 from szabadka:backprop5
...
PiperOrigin-RevId: 640133430
2024-06-04 06:33:08 -07:00
Zoltan Szabadka
be1d58d4fa
Fix bazel build
2024-06-04 11:13:19 +00:00
Zoltan Szabadka
cd41a4548e
Add missing include
2024-06-04 10:29:12 +00:00
Zoltan Szabadka
df01700b54
Move the backpropagation code to its own directory
2024-06-04 10:20:16 +00:00
Zoltan Szabadka
3b4fa4a0e3
Use HWY_EXPORT_AND_DYNAMIC_DISPATCH_T where possible.
2024-06-04 09:18:56 +00:00
Zoltan Szabadka
8567978541
Address review comments
2024-06-04 08:37:54 +00:00
Zoltan Szabadka
7e639856da
Fix compilation and tests for gcc
2024-06-04 08:37:54 +00:00
Zoltan Szabadka
36e4d8bbfe
Add first version of backpropagation support.
...
This is still in progress / experimental: it is currently
implemented only for normal gemma MQA attention layers, and no
parallelism has been added yet for the backward pass.
Since we need to remember all activations from all layers, the
forward pass was also reimplemented with a new activation data
structure.
2024-06-04 08:37:49 +00:00
Paul Chang
ed8f39c058
Refactor GemmaImpl dispatch to use Highway 1.2's HWY_DYNAMIC_DISPATCH_T
...
PiperOrigin-RevId: 639793810
2024-06-03 08:32:29 -07:00
Jan Wassenberg
a44cbdadc2
Update to Highway 1.2 for topology/VQSelect
...
Also fix unused-warning in compress-inl.
PiperOrigin-RevId: 639116915
2024-05-31 12:29:10 -07:00
Paul Chang
5feacf120c
static_assert shape constraints in MatMul 4x4
...
PiperOrigin-RevId: 639069345
2024-05-31 10:02:45 -07:00
Phil Culliton
c616abe628
Unrolled / tiled 4x4 MatMul
...
PiperOrigin-RevId: 638384686
2024-05-29 13:02:35 -07:00
Paul Chang
419dc34ed5
Generic MHA/MQA/GQA implementation
...
PiperOrigin-RevId: 636937885
2024-05-24 09:05:53 -07:00
Copybara-Service
93c0088646
Merge pull request #194 from szabadka:softmax-fix
...
PiperOrigin-RevId: 636848144
2024-05-24 02:48:17 -07:00
Zoltan Szabadka
542ad0973a
Fix normalization in Softmax function.
2024-05-24 08:58:31 +00:00
Apoorv Reddy
1aaf3b3aae
Documenting the RoPE implementation.
...
PiperOrigin-RevId: 636175297
2024-05-22 08:26:29 -07:00
Paul Chang
c0643577c3
Minor internal refactoring.
...
PiperOrigin-RevId: 635852078
2024-05-21 10:29:59 -07:00
Copybara-Service
59a1f87d63
Merge pull request #189 from google:pculliton-kaggle-ci
...
PiperOrigin-RevId: 635811297
2024-05-21 08:13:43 -07:00
Apoorv Reddy
7f4b85d00b
Add MMLU eval to github
...
PiperOrigin-RevId: 635495178
2024-05-20 10:20:53 -07:00
pculliton
cf347dfe35
Adds Kaggle testing to CI workflow
...
Using a restricted Kaggle account, this code:
- Adds an Ubuntu 20.04 build (required for glibc compat with Kaggle infra)
- Uploads the ubuntu-20.04 build and supporting library to a Kaggle dataset using a fork of `push-kaggle-dataset`
- Creates a new version of a Kaggle notebook that loads artifacts from the Kaggle Model Hub, along with the newly updated dataset, and validates a 2b-it-sfp model.
- Runs the notebook and throws an error if the process does not complete, raises an exception, or produces an invalid response.
TODO: add tests / capabilities to the smoke tests used by the notebook.
2024-05-17 16:06:03 -04:00
Paul Chang
cfce314715
Make BlobWriter::Add() accept const void*
...
PiperOrigin-RevId: 634780483
2024-05-17 08:11:06 -07:00
Paul Chang
82623bdc7f
Refer to --weights rather than --compressed_weights to simplify CLI docs
...
PiperOrigin-RevId: 634391135
2024-05-16 07:51:49 -07:00
Apoorv Reddy
8e641eb4cd
Add TTFT to TimingInfo
...
PiperOrigin-RevId: 634378994
2024-05-16 07:16:53 -07:00
Apoorv Reddy
eb0b96e0a8
Pass most runtime parameters using const RuntimeConfig&
...
PiperOrigin-RevId: 633572507
2024-05-14 07:04:53 -07:00
Apoorv Reddy
f1eab987d8
Store tokens/sec in auxiliary struct TimingInfo.
...
PiperOrigin-RevId: 633108908
2024-05-13 00:04:19 -07:00
Jan Wassenberg
22fe9809ac
Fix SVE build: add missing hn::
...
PiperOrigin-RevId: 632481097
2024-05-10 06:49:26 -07:00
Jan Wassenberg
c5c9fc300c
Enable even/odd for SFP. Refs #166
...
Disable it for float32 because there is not enough benefit.
PiperOrigin-RevId: 631788326
2024-05-08 07:09:06 -07:00
Paul Chang
bacba351d4
Support additional scaling
...
PiperOrigin-RevId: 631429113
2024-05-07 08:16:25 -07:00
Jan Wassenberg
f6d02b2870
Fix RecurrentGemma (refs #166) - one Dot was ignoring scale.
...
Remove extra Dot() overload
MatVecAdd always adds, use MatVecT<kAdd> if conditional.
Remove unused MatVecAddLoop and MatVecLoop
No longer tsan-verify even_odd
PiperOrigin-RevId: 631377279
2024-05-07 04:40:42 -07:00
Jan Wassenberg
b5a9ade75f
2x speedup of SFP decode (1.4x overall) on AVX3_DL+.
...
Thanks @nzmichaelh for suggesting table lookups!
PiperOrigin-RevId: 631337524
2024-05-07 01:46:43 -07:00
Copybara-Service
18f6d43fcc
Merge pull request #169 from xinpingwang:cmake-install
...
PiperOrigin-RevId: 630425203
2024-05-03 10:16:46 -07:00
Wang Xinping
2c038e1285
work with cmake install
2024-05-03 23:44:12 +08:00
Copybara-Service
8ed22e52bf
Merge pull request #177 from szabadka:gemma2
...
PiperOrigin-RevId: 630388843
2024-05-03 07:52:27 -07:00
Zoltan Szabadka
19017fdb6d
Fix expression in DASSERT()
2024-05-03 13:54:20 +00:00
Phil Culliton
28ca001d5e
Matmul and test functions
...
PiperOrigin-RevId: 630373984
2024-05-03 06:39:36 -07:00
Zoltan Szabadka
429eb78512
Remove unused vars.
2024-05-03 13:37:17 +00:00