Commit Graph

434 Commits

Author SHA1 Message Date
Zoltan Szabadka 465998d25a Add support for custom sampling function to runtime config.
With this addition the ComputeCrossEntropy function can be moved
to its own library, because now we can compute it using only the
public API functions from gemma.h
2024-06-07 11:45:07 +00:00
Copybara-Service f7ac7092d6 Merge pull request #212 from szabadka:adam2
PiperOrigin-RevId: 641182573
2024-06-07 02:25:18 -07:00
Jan Wassenberg e3f4374e81 Fix fix for weight type define, refs #198
GEMMA_WEIGHT_T is indeed the correct flag for the C++ compiler,
but the readme references CMake, and there the correct flag name is WEIGHT_TYPE.

PiperOrigin-RevId: 641170380
2024-06-07 01:32:25 -07:00
Jan Wassenberg 8dc0e5ea83 Fix reference to GEMMA_WEIGHT_T. Refs #198
PiperOrigin-RevId: 641161403
2024-06-07 00:54:30 -07:00
Zoltan Szabadka c004799cdc Add Adam optimizer.
Drive-by: Fix compilation errors and tests for backprop functions.
2024-06-06 18:41:36 +00:00
Jan Wassenberg 12707ade80 Toward only using compressed weights:
CompressedLayer should all be f32 when weights are f32.

PiperOrigin-RevId: 640954519
2024-06-06 11:00:23 -07:00
Paul Chang 6c0be20fa6 Fix Softmax on SVE
PiperOrigin-RevId: 640947138
2024-06-06 10:39:30 -07:00
The gemma.cpp Authors 39d4115717 Implement mixed mode matmul: f32 * bf16
PiperOrigin-RevId: 640940962
2024-06-06 10:21:46 -07:00
Jan Wassenberg 57c2cd8b52 Simplifications: remove GemmaInterface and GemmaImpl
Split common and weights into separate lib
Remove common-inl (does not have to be SIMD code), activations.cc
Centralize switch(Model) to avoid duplication
Move CompressWeightsT to compress_weights.cc
Move LoadWeights to weights.cc

PiperOrigin-RevId: 640869202
2024-06-06 05:54:21 -07:00
Jan Wassenberg 5c3e5f7038 Remove no longer required stats.h - use Highway version instead
PiperOrigin-RevId: 640440379
2024-06-05 01:37:48 -07:00
Paul Chang 175e389c3c revert back to HWY_ASSERT for lane constraints, qualify hn::Add
PiperOrigin-RevId: 640193239
2024-06-04 10:10:18 -07:00
Phil Culliton e71d82ead9 Fix for GenerateZeroMat call in TestTiledMatMul
PiperOrigin-RevId: 640180868
2024-06-04 09:32:23 -07:00
Zelalem Aweke 9e213b3d96 Use system topology to pin threads across clusters.
PiperOrigin-RevId: 640151974
2024-06-04 07:50:32 -07:00
Jan Wassenberg 4f9155d8c6 Add bf16 matmul support, update naming+test
Avoid int32, which can easily overflow for large matrices.
Also fix IDE warning in sfp-inl.

PiperOrigin-RevId: 640149845
2024-06-04 07:41:46 -07:00
Copybara-Service 25d9c8ff30 Merge pull request #203 from szabadka:backprop5
PiperOrigin-RevId: 640133430
2024-06-04 06:33:08 -07:00
Zoltan Szabadka be1d58d4fa Fix bazel build 2024-06-04 11:13:19 +00:00
Zoltan Szabadka cd41a4548e Add missing include 2024-06-04 10:29:12 +00:00
Zoltan Szabadka df01700b54 Move the backpropagation code to its own directory 2024-06-04 10:20:16 +00:00
Zoltan Szabadka 3b4fa4a0e3 Use HWY_EXPORT_AND_DYNAMIC_DISPATCH_T where possible. 2024-06-04 09:18:56 +00:00
Zoltan Szabadka 8567978541 Adress review comments 2024-06-04 08:37:54 +00:00
Zoltan Szabadka 7e639856da Fix compilation and tests for gcc 2024-06-04 08:37:54 +00:00
Zoltan Szabadka 36e4d8bbfe Add first version of backpropagation support.
This is still in progress / experimental, currently it is only
implemented for normal gemma MQA attention layers, and no
parallelism is added yet for backward pass.

Since we need to remember all activations from all layers, the
forward pass was also reimplemented with a new activation data
structure.
2024-06-04 08:37:49 +00:00
Paul Chang ed8f39c058 Refactor GemmaImpl dispatch to use Highway 1.2's HWY_DYNAMIC_DISPATCH_T
PiperOrigin-RevId: 639793810
2024-06-03 08:32:29 -07:00
Jan Wassenberg a44cbdadc2 Update to Highway 1.2 for topology/VQSelect
Also fix unused-warning in compress-inl.

PiperOrigin-RevId: 639116915
2024-05-31 12:29:10 -07:00
Paul Chang 5feacf120c static_assert shape constraints in MatMul 4x4
PiperOrigin-RevId: 639069345
2024-05-31 10:02:45 -07:00
Phil Culliton c616abe628 Unrolled / tiled 4x4 MatMul
PiperOrigin-RevId: 638384686
2024-05-29 13:02:35 -07:00
Paul Chang 419dc34ed5 Generic MHA/MQA/GQA implementation
PiperOrigin-RevId: 636937885
2024-05-24 09:05:53 -07:00
Copybara-Service 93c0088646 Merge pull request #194 from szabadka:softmax-fix
PiperOrigin-RevId: 636848144
2024-05-24 02:48:17 -07:00
Zoltan Szabadka 542ad0973a Fix normalization in Softmax function. 2024-05-24 08:58:31 +00:00
Apoorv Reddy 1aaf3b3aae Documenting the RoPE implementation.
PiperOrigin-RevId: 636175297
2024-05-22 08:26:29 -07:00
Paul Chang c0643577c3 Minor internal refactoring.
PiperOrigin-RevId: 635852078
2024-05-21 10:29:59 -07:00
Copybara-Service 59a1f87d63 Merge pull request #189 from google:pculliton-kaggle-ci
PiperOrigin-RevId: 635811297
2024-05-21 08:13:43 -07:00
Apoorv Reddy 7f4b85d00b Add MMLU eval to github
PiperOrigin-RevId: 635495178
2024-05-20 10:20:53 -07:00
pculliton cf347dfe35
Adds Kaggle testing to CI workflow
Using a restricted Kaggle account, this code:
- Adds an Ubuntu 20.04 build (required for glibc compat with Kaggle infra)
- Uploads the ubuntu-20.04 build and supporting library to a Kaggle dataset using a fork of `push-kaggle-dataset`
- Creates a new version of a Kaggle notebook that loads artifacts from the Kaggle Model Hub, along with the newly updated dataset, and validates a 2b-it-sfp model.
- Runs the notebook and throws an error if the process does not complete, raises an exception, or produces an invalid response.

Todo: add tests / capabilities to the smoke tests used by the notebook.
2024-05-17 16:06:03 -04:00
Paul Chang cfce314715 Make BlobWriter::Add() accept const void*
PiperOrigin-RevId: 634780483
2024-05-17 08:11:06 -07:00
Paul Chang 82623bdc7f Refer to --weights rather than --compressed_weights to simplify CLI docs
PiperOrigin-RevId: 634391135
2024-05-16 07:51:49 -07:00
Apoorv Reddy 8e641eb4cd Add TTFT to TimingInfo
PiperOrigin-RevId: 634378994
2024-05-16 07:16:53 -07:00
Apoorv Reddy eb0b96e0a8 Pass most runtime parameters using const RuntimeConfig&
PiperOrigin-RevId: 633572507
2024-05-14 07:04:53 -07:00
Apoorv Reddy f1eab987d8 Store tokens/sec in auxiliary struct TimingInfo.
PiperOrigin-RevId: 633108908
2024-05-13 00:04:19 -07:00
Jan Wassenberg 22fe9809ac Fix SVE build: add missing hn::
PiperOrigin-RevId: 632481097
2024-05-10 06:49:26 -07:00
Jan Wassenberg c5c9fc300c Enable even/odd for SFP. Refs #166
Disable it for float32 because there is not enough benefit.

PiperOrigin-RevId: 631788326
2024-05-08 07:09:06 -07:00
Paul Chang bacba351d4 Support additional scaling
PiperOrigin-RevId: 631429113
2024-05-07 08:16:25 -07:00
Jan Wassenberg f6d02b2870 Fix RecurrentGemma (refs #166) - one Dot was ignoring scale.
Remove extra Dot() overload
MatVecAdd always adds, use MatVecT<kAdd> if conditional.
Remove ununsed MatVecAddLoop and MatVecLoop
No longer tsan-verify even_odd

PiperOrigin-RevId: 631377279
2024-05-07 04:40:42 -07:00
Jan Wassenberg b5a9ade75f 2x speedup of SFP decode (1.4x overall) on AVX3_DL+.
Thanks @nzmichaelh for suggesting table lookups!

PiperOrigin-RevId: 631337524
2024-05-07 01:46:43 -07:00
Copybara-Service 18f6d43fcc Merge pull request #169 from xinpingwang:cmake-install
PiperOrigin-RevId: 630425203
2024-05-03 10:16:46 -07:00
Wang Xinping 2c038e1285 work with cmake install 2024-05-03 23:44:12 +08:00
Copybara-Service 8ed22e52bf Merge pull request #177 from szabadka:gemma2
PiperOrigin-RevId: 630388843
2024-05-03 07:52:27 -07:00
Zoltan Szabadka 19017fdb6d Fix expression in DASSERT() 2024-05-03 13:54:20 +00:00
Phil Culliton 28ca001d5e Matmul and test functions
PiperOrigin-RevId: 630373984
2024-05-03 06:39:36 -07:00
Zoltan Szabadka 429eb78512 Remove unused vars. 2024-05-03 13:37:17 +00:00