Merge branch 'google:main' into main

This commit is contained in:
Sascha Ronnie Daoudia 2024-03-10 11:59:40 +01:00 committed by GitHub
commit 0e1aefdac1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 2934 additions and 159 deletions

1
.bazelrc Normal file
View File

@ -0,0 +1 @@
common --enable_bzlmod

View File

@ -1 +1,2 @@
Language: Cpp
BasedOnStyle: Google

206
.clang-tidy Normal file
View File

@ -0,0 +1,206 @@
FormatStyle: file
Checks: "-*,\
abseil-*,\
-abseil-string-find-startswith,\
-abseil-string-find-str-contains,\
bugprone-*,\
-bugprone-argument-comment,\
-bugprone-assert-side-effect,\
-bugprone-bad-signal-to-kill-thread,\
-bugprone-bool-pointer-implicit-conversion,\
-bugprone-branch-clone,\
-bugprone-copy-constructor-init,\
-bugprone-dangling-handle,\
-bugprone-dynamic-static-initializers,\
-bugprone-easily-swappable-parameters,\
-bugprone-exception-escape,\
-bugprone-fold-init-type,\
-bugprone-forward-declaration-namespace,\
-bugprone-forwarding-reference-overload,\
-bugprone-implicit-widening-of-multiplication-result,\
-bugprone-inaccurate-erase,\
-bugprone-incorrect-roundings,\
-bugprone-infinite-loop,\
-bugprone-integer-division,\
-bugprone-lambda-function-name,\
-bugprone-macro-parentheses,\
-bugprone-macro-repeated-side-effects,\
-bugprone-misplaced-operator-in-strlen-in-alloc,\
-bugprone-misplaced-widening-cast,\
-bugprone-move-forwarding-reference,\
-bugprone-multiple-statement-macro,\
-bugprone-narrowing-conversions,\
-bugprone-no-escape,\
-bugprone-not-null-terminated-result,\
-bugprone-parent-virtual-call,\
-bugprone-posix-return,\
-bugprone-redundant-branch-condition,\
-bugprone-reserved-identifier,\
-bugprone-signal-handler,\
-bugprone-signed-char-misuse,\
-bugprone-sizeof-container,\
-bugprone-sizeof-expression,\
-bugprone-spuriously-wake-up-functions,\
-bugprone-string-constructor,\
-bugprone-string-integer-assignment,\
-bugprone-string-literal-with-embedded-nul,\
-bugprone-stringview-nullptr,\
-bugprone-suspicious-enum-usage,\
-bugprone-suspicious-include,\
-bugprone-suspicious-memory-comparison,\
-bugprone-suspicious-memset-usage,\
-bugprone-suspicious-missing-comma,\
-bugprone-suspicious-semicolon,\
-bugprone-suspicious-string-compare,\
-bugprone-swapped-arguments,\
-bugprone-terminating-continue,\
-bugprone-throw-keyword-missing,\
-bugprone-too-small-loop-variable,\
-bugprone-undefined-memory-manipulation,\
-bugprone-undelegated-constructor,\
-bugprone-unhandled-exception-at-new,\
-bugprone-unhandled-self-assignment,\
-bugprone-unused-raii,\
-bugprone-unused-return-value,\
-bugprone-use-after-move,\
-bugprone-virtual-near-miss,\
cert-*,\
-cert-dcl16-c,\
-cert-dcl21-cpp,\
-cert-dcl37-c,\
-cert-dcl50-cpp,\
-cert-dcl51-cpp,\
-cert-dcl54-cpp,\
-cert-dcl58-cpp,\
-cert-err33-c,\
-cert-msc30-c,\
-cert-msc32-c,\
-cert-msc50-cpp,\
-cert-msc51-cpp,\
-cert-oop54-cpp,\
-cert-str34-c,\
-clang-analyzer-*,\
concurrency-*,\
-concurrency-mt-unsafe,\
cppcoreguidelines-*,\
-cppcoreguidelines-avoid-c-arrays,\
-cppcoreguidelines-avoid-const-or-ref-data-members,\
-cppcoreguidelines-avoid-goto,\
-cppcoreguidelines-avoid-magic-numbers,\
-cppcoreguidelines-avoid-non-const-global-variables,\
-cppcoreguidelines-c-copy-assignment-signature,\
-cppcoreguidelines-explicit-virtual-functions,\
-cppcoreguidelines-init-variables,\
-cppcoreguidelines-interfaces-global-init,\
-cppcoreguidelines-macro-usage,\
-cppcoreguidelines-narrowing-conversions,\
-cppcoreguidelines-no-malloc,\
-cppcoreguidelines-non-private-member-variables-in-classes,\
-cppcoreguidelines-owning-memory,\
-cppcoreguidelines-prefer-member-initializer,\
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,\
-cppcoreguidelines-pro-bounds-constant-array-index,\
-cppcoreguidelines-pro-bounds-pointer-arithmetic,\
-cppcoreguidelines-pro-type-const-cast,\
-cppcoreguidelines-pro-type-member-init,\
-cppcoreguidelines-pro-type-reinterpret-cast,\
-cppcoreguidelines-pro-type-static-cast-downcast,\
-cppcoreguidelines-pro-type-union-access,\
-cppcoreguidelines-pro-type-vararg,\
-cppcoreguidelines-slicing,\
-cppcoreguidelines-special-member-functions,\
-cppcoreguidelines-virtual-class-destructor,\
google-*,\
-google-default-arguments,\
-google-explicit-constructor,\
-google-readability-avoid-underscore-in-googletest-name,\
-google-readability-braces-around-statements,\
-google-readability-casting,\
-google-readability-namespace-comments,\
-google-readability-todo,\
-google-runtime-int,\
-google-upgrade-googletest-case,\
misc-*,\
-misc-misplaced-const,\
-misc-new-delete-overloads,\
-misc-non-private-member-variables-in-classes,\
-misc-no-recursion,\
-misc-redundant-expression,\
-misc-uniqueptr-reset-release,\
-misc-unconventional-assign-operator,\
-misc-unused-parameters,\
-misc-unused-using-decls,\
modernize-*,\
-modernize-avoid-c-arrays,\
-modernize-concat-nested-namespaces,\
-modernize-deprecated-headers,\
-modernize-loop-convert,\
-modernize-macro-to-enum,\
-modernize-make-unique,\
-modernize-pass-by-value,\
-modernize-raw-string-literal,\
-modernize-redundant-void-arg,\
-modernize-return-braced-init-list,\
-modernize-unary-static-assert,\
-modernize-use-auto,\
-modernize-use-bool-literals,\
-modernize-use-default-member-init,\
-modernize-use-emplace,\
-modernize-use-equals-default,\
-modernize-use-equals-delete,\
-modernize-use-nodiscard,\
-modernize-use-nullptr,\
-modernize-use-override,\
-modernize-use-trailing-return-type,\
-modernize-use-transparent-functors,\
-modernize-use-using,\
performance-*,\
-performance-faster-string-find,\
-performance-for-range-copy,\
-performance-inefficient-algorithm,\
-performance-inefficient-string-concatenation,\
-performance-inefficient-vector-operation,\
-performance-move-const-arg,\
-performance-no-automatic-move,\
-performance-noexcept-move-constructor,\
-performance-no-int-to-ptr,\
-performance-trivially-destructible,\
-performance-unnecessary-copy-initialization,\
-performance-unnecessary-value-param,\
portability-*,\
readability-*,\
-readability-avoid-const-params-in-decls,\
-readability-braces-around-statements,\
-readability-const-return-type,\
-readability-container-data-pointer,\
-readability-container-size-empty,\
-readability-convert-member-functions-to-static,\
-readability-else-after-return,\
-readability-function-cognitive-complexity,\
-readability-identifier-length,\
-readability-implicit-bool-conversion,\
-readability-inconsistent-declaration-parameter-name,\
-readability-isolate-declaration,\
-readability-magic-numbers,\
-readability-make-member-function-const,\
-readability-named-parameter,\
-readability-non-const-parameter,\
-readability-qualified-auto,\
-readability-redundant-access-specifiers,\
-readability-redundant-control-flow,\
-readability-redundant-declaration,\
-readability-redundant-member-init,\
-readability-redundant-smartptr-get,\
-readability-redundant-string-cstr,\
-readability-redundant-string-init,\
-readability-simplify-boolean-expr,\
-readability-static-accessed-through-instance,\
-readability-static-definition-in-anonymous-namespace,\
-readability-suspicious-call-argument,\
-readability-uppercase-literal-suffix,\
-readability-use-anyofallof
"

View File

@ -12,6 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
# When adding another, also add to copybara's github_check_runs.
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
build_type: ['Release']
preset: ['make', 'windows']
@ -43,7 +44,7 @@ jobs:
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
- name: Build
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }}
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4
- name: Archive production artifacts
uses: actions/upload-artifact@v4
@ -54,3 +55,21 @@ jobs:
${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib
${{ github.workspace }}/build/gemma
${{ github.workspace }}/build/libgemma.a
bazel:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0
with:
egress-policy: audit # cannot be block - runner does git checkout
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.0.0
- uses: bazelbuild/setup-bazelisk@b39c379c82683a5f25d34f0d062761f62693e0b2 # v3.0.0
- uses: actions/cache@ab5e6d0c87105b4c9c2047343972218f562e4319 # v4.0.1
with:
path: ~/.cache/bazel
key: bazel-${{ runner.os }}
- run: bazel build -c opt --cxxopt=-std=c++20 //...

View File

@ -25,21 +25,14 @@ cc_library(
],
deps = [
"//compression:compress",
# copybara:import_next_line:hwy
"//:algo",
# copybara:import_next_line:hwy
"//:dot",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:math",
# copybara:import_next_line:hwy
"//:matvec",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
"//hwy/contrib/sort:vqsort",
"@hwy//:algo",
"@hwy//:dot",
"@hwy//:hwy",
"@hwy//:math",
"@hwy//:matvec",
"@hwy//:profiler",
"@hwy//:thread_pool",
"@hwy//hwy/contrib/sort:vqsort",
],
)
@ -49,8 +42,7 @@ cc_library(
"util/args.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)
@ -61,8 +53,7 @@ cc_library(
],
deps = [
":args",
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)
@ -78,19 +69,13 @@ cc_library(
deps = [
":args",
":transformer_ops",
"//base",
"//compression:compress",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:matvec",
# copybara:import_next_line:hwy
"//:nanobenchmark", # timer
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
":sentencepiece_processor",
"@hwy//:hwy",
"@hwy//:matvec",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@hwy//:thread_pool",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)
@ -104,13 +89,9 @@ cc_binary(
":args",
":gemma_lib",
"//compression:compress",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:nanobenchmark",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:thread_pool",
],
)

View File

@ -20,6 +20,7 @@ project(gemma)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)

View File

@ -100,3 +100,79 @@ be exposed to the build system):
In the medium term both of these will likely be deprecated in favor of handling
options at runtime - allowing for multiple weight compression schemes in a single
build and dynamically resizing the KV cache as needed.
## Using gemma.cpp as a Library (Advanced)
Unless you are doing lower level implementations or research, from an
application standpoint you can think of gemma.h and gemma.cc as the "core" of
the library.
You can regard `run.cc` as an example application that your own application is
substituting for, so the invocations into gemma.h and gemma.cc you see in
`run.cc` are probably the functions you'll be invoking. You can find examples
of the invocations to tokenizer methods and `GenerateGemma` in `run.cc`.
Keep in mind gemma.cpp is oriented at more experimental / prototype / research
applications. If you're targeting production, there's more standard paths via
jax / pytorch / keras for NN deployments.
### Gemma struct contains all the state of the inference engine - tokenizer, weights, and activations
`Gemma(...)` - constructor, creates a gemma model object, which is a wrapper
around four things - the tokenizer object, weights, activations, and the KV cache.
In a standard LLM chat app, you'll probably use a Gemma object directly, in
more exotic data processing or research applications, you might decompose
working with weights, kv cache and activations (e.g. you might have multiple kv
caches and activations for a single set of weights) more directly rather than
only using a Gemma object.
### Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)
You pretty much only do things with the tokenizer, call `Encode()` to go from
string prompts to token id vectors, or `Decode()` to go from token id vector
outputs from the model back to strings.
### The main entrypoint for generation is `GenerateGemma()`
Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
activation values in `model` and 2) invoke StreamFunc - a lambda callback for
each generated token.
Your application defines its own StreamFunc as a lambda callback to do
something every time a token string is streamed from the engine (e.g. print to the
screen, write data to the disk, send the string to a server, etc.). You can see
in `run.cc` the StreamFunc lambda takes care of printing each token to the
screen as it arrives.
Optionally you can define accept_token as another lambda - this is mostly for
constrained decoding type of use cases where you want to force the generation
to fit a grammar. If you're not doing this, you can send an empty lambda as a
no-op which is what `run.cc` does.
### If you want to invoke the neural network forward function directly call the `Transformer()` function
For high-level applications, you might only call `GenerateGemma()` and never
interact directly with the neural network, but if you're doing something a bit
more custom you can call transformer which performs a single inference
operation on a single token and mutates the Activations and the KVCache through
the neural network computation.
### For low level operations, defining new architectures, call `ops.h` functions directly
You use `ops.h` if you're writing other NN architectures or modifying the
inference path of the Gemma model.
## Building with Bazel
The sentencepiece library we depend on requires some additional work to build
with the Bazel build system. First, it does not export its BUILD file, so we
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
support Abseil as a standalone dependency without third_party/ prefixes, similar
to the transforms we apply to Gemma via Copybara.
## Discord
We're also trying out a discord server for discussion here -
https://discord.gg/H5jCBAWxAe

View File

@ -3,12 +3,57 @@ module(
version = "0.1.0",
)
bazel_dep(
name = "rules_license",
version = "0.0.7",
bazel_dep(name = "rules_license", version = "0.0.7")
bazel_dep(name = "googletest", version = "1.14.0")
# Copied from Highway because Bazel does not load them transitively
bazel_dep(name = "bazel_skylib", version = "1.4.1")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "platforms", version = "0.0.7")
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "hwy",
urls = ["https://github.com/google/highway/archive/refs/tags/1.1.0.zip"],
integrity = "sha256-zkJX2SwL4wQ0nHMsURW7MDLEf43vFSnqhSUsUM6eQmY=",
strip_prefix = "highway-1.1.0",
)
bazel_dep(
http_archive(
name = "com_google_sentencepiece",
version = "0.1.96",
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//bazel:sentencepiece.bazel",
patches = ["@//bazel:com_google_sentencepiece.patch"],
patch_args = ["-p1"],
)
# For sentencepiece
http_archive(
name = "darts_clone",
build_file_content = """
licenses(["notice"])
exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "darts_clone",
hdrs = [
"include/darts.h",
],
)
""",
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
urls = [
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
],
)
# ABSL on 2023-10-18
http_archive(
name = "com_google_absl",
sha256 = "f841f78243f179326f2a80b719f2887c38fe226d288ecdc46e2aa091e6aa43bc",
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
urls = ["https://github.com/abseil/abseil-cpp/archive/9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
)

View File

@ -65,15 +65,26 @@ winget install --id Kitware.CMake
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
```
### Step 1: Obtain model weights and tokenizer from Kaggle
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
Visit [the Gemma model page on
Kaggle](https://www.kaggle.com/models/google/gemma) and select `Model Variations
Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp) and select `Model Variations
|> Gemma C++`. On this tab, the `Variation` dropdown includes the options below.
Note bfloat16 weights are higher fidelity, while 8-bit switched floating point
weights enable faster inference. In general, we recommend starting with the
`-sfp` checkpoints.
Alternatively, visit the [gemma.cpp](https://huggingface.co/models?other=gemma.cpp)
models on the Hugging Face Hub. First go to the model repository of the model of interest
(see recommendations below). Then, click the `Files and versions` tab and download the
model and tokenizer files. For programmatic downloading, if you have `huggingface_hub`
installed, you can also download by running:
```
huggingface-cli login # Just the first time
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
```
2B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
@ -98,6 +109,8 @@ weights enable faster inference. In general, we recommend starting with the
### Step 2: Extract Files
If you downloaded the models from Hugging Face, skip to step 3.
After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
take a few minutes):
@ -175,6 +188,14 @@ cmake --build --preset windows -j [number of parallel threads to use]
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
#### Bazel
```sh
bazel build -c opt --cxxopt=-std=c++20 :gemma
```
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
### Step 4: Run
You can now run `gemma` from inside the `build/` directory.

View File

@ -1,24 +1,4 @@
workspace(name = "gemma")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
maybe(
http_archive,
name = "rules_license",
sha256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375fe7c360",
urls = [
"https://github.com/bazelbuild/rules_license/releases/download/0.0.7/rules_license-0.0.7.tar.gz",
],
)
maybe(
http_archive,
name = "com_google_sentencepiece",
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//third_party:sentencepiece.bazel",
patches = ["@//third_party:com_google_sentencepiece.patch"],
patch_args = ["-p1"],
)
# This file marks the root of the Bazel workspace.
# See MODULE.bazel for external dependencies setup.

4
bazel/BUILD Normal file
View File

@ -0,0 +1,4 @@
package(
default_applicable_licenses = ["//:license"],
default_visibility = ["//:__subpackages__"],
)

File diff suppressed because it is too large Load Diff

97
bazel/sentencepiece.bazel Normal file
View File

@ -0,0 +1,97 @@
package(
default_visibility = ["//visibility:public"],
features = [
"layering_check",
"parse_headers",
],
)
licenses(["notice"]) # Apache 2, BSD, MIT
proto_library(
name = "sentencepiece_proto",
srcs = ["src/sentencepiece.proto"],
)
cc_proto_library(
name = "sentencepiece_cc_proto",
deps = [":sentencepiece_proto"],
)
proto_library(
name = "sentencepiece_model_proto",
srcs = ["src/sentencepiece_model.proto"],
)
cc_proto_library(
name = "sentencepiece_model_cc_proto",
deps = [":sentencepiece_model_proto"],
)
genrule(
name = "config_h",
srcs = ["config.h.in"],
outs = ["config.h"],
cmd = "cp $< $@",
)
cc_library(
name = "common",
hdrs = [
"config.h",
"src/common.h",
],
deps = [
"@com_google_absl//absl/base",
],
)
cc_library(
name = "sentencepiece_processor",
srcs = [
"src/bpe_model.cc",
"src/char_model.cc",
"src/error.cc",
"src/filesystem.cc",
"src/model_factory.cc",
"src/model_interface.cc",
"src/normalizer.cc",
"src/sentencepiece_processor.cc",
"src/unigram_model.cc",
"src/util.cc",
"src/word_model.cc",
],
hdrs = [
"src/bpe_model.h",
"src/char_model.h",
"src/filesystem.h",
"src/freelist.h",
"src/model_factory.h",
"src/model_interface.h",
"src/normalizer.h",
"src/sentencepiece_processor.h",
"src/trainer_interface.h",
"src/unigram_model.h",
"src/util.h",
"src/word_model.h",
],
defines = ["_USE_TF_STRING_VIEW"],
includes = [
".",
"src",
],
linkstatic = 1,
deps =
[
":common",
":sentencepiece_cc_proto",
":sentencepiece_model_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@darts_clone",
],
)

View File

@ -1,10 +1,10 @@
# Weight compression, I/O and analysis
package(
default_applicable_licenses = ["//third_party/gemma_cpp:license"],
default_applicable_licenses = ["//:license"],
default_visibility = [
"//learning/gemini/prod/contrib/gemini_cpp:__subpackages__",
"//third_party/gemma_cpp:__subpackages__",
"//:__subpackages__",
],
)
@ -17,10 +17,8 @@ cc_library(
"blob_store.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)
@ -34,8 +32,7 @@ cc_library(
"stats.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)
@ -48,8 +45,7 @@ cc_library(
"sfp-inl.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)
@ -65,15 +61,11 @@ cc_test(
deps = [
":sfp",
":stats",
"//testing/base/public:gunit_main_no_google3",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:hwy_test_util",
# copybara:import_next_line:hwy
"//:nanobenchmark",
# copybara:import_next_line:hwy
"//:thread_pool",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
@ -87,9 +79,8 @@ cc_library(
],
deps = [
":sfp",
# copybara:import_next_line:hwy
"//:hwy",
"//third_party/highway/hwy/contrib/sort:vqsort",
"@hwy//:hwy",
"@hwy//hwy/contrib/sort:vqsort",
],
)
@ -106,13 +97,10 @@ cc_test(
":nuq",
":sfp",
":stats",
"//testing/base/public:gunit_main_no_google3",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:hwy_test_util",
# copybara:import_next_line:hwy
"//:nanobenchmark",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
],
)
@ -131,12 +119,9 @@ cc_library(
":nuq",
":sfp",
":stats",
# copybara:import_next_line:hwy
"//:dot",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:dot",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)
@ -150,12 +135,9 @@ cc_library(
":nuq",
":sfp",
":stats",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:nanobenchmark", # timer
# copybara:import_next_line:hwy
"//:thread_pool",
"//third_party/highway/hwy/contrib/sort:vqsort",
"@hwy//:hwy",
"@hwy//:nanobenchmark", # timer
"@hwy//:thread_pool",
"@hwy//hwy/contrib/sort:vqsort",
],
)

View File

@ -149,7 +149,7 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}
size_t remaining = num - i;
const size_t remaining = num - i;
if (remaining != 0) {
const VF in0 = hn::LoadN(df, in + i, remaining);
const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2);
@ -195,7 +195,7 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}
size_t remaining = num - i;
const size_t remaining = num - i;
if (remaining != 0) {
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
const VF in0 = hn::PromoteLowerTo(df, in16);
@ -287,7 +287,7 @@ struct CompressTraits<NuqStream> {
if (COMPRESS_STATS) {
for (size_t i = 0; i < num; ++i) {
tls.stats.NotifyIn(in[i] * 100 + 500);
tls.stats.NotifyIn(static_cast<int>(lroundf(in[i] * 100.0f + 500.0f)));
}
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
@ -358,7 +358,7 @@ HWY_NOINLINE void Compress(const float* in, size_t num,
});
const double t1 = hwy::platform::Now();
const double mb = num * sizeof(in[0]) * 1E-6;
const double mb = static_cast<double>(num) * sizeof(in[0]) * 1E-6;
const double mbps = mb / (t1 - t0);
fprintf(stderr, "Compress %.1f MB/s\n", mbps);

View File

@ -68,15 +68,15 @@ class DistortionStats {
double GeomeanValueDivL1() const {
if (num_rel_ == 0) return 0.0;
return exp(sum_log_rel_ / num_rel_);
return exp(sum_log_rel_ / static_cast<double>(num_rel_));
}
double PNorm() const {
// p-norms are a compromise between max-norm (penalizes the largest error
// without dilution, but does not notice any other errors) and L1 (all
// errors contribute, but large errors are diluted by smaller ones).
const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3);
const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6);
const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
return 0.5 * (norm3 + norm6);
}

View File

@ -87,7 +87,7 @@ class NuqClustering {
inv_len_[0] = 0.0f; // unused
for (size_t i = 0; i <= kGroupSize; ++i) {
inv_len_[i] = 1.0f / i;
inv_len_[i] = 1.0f / static_cast<float>(i);
}
}
@ -229,7 +229,7 @@ class NuqClustering {
const float sum = cc.SumOfSorted(start, last);
const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
HWY_DASSERT(0 < size && size <= kGroupSize);
centers[k] = sum / size;
centers[k] = sum / static_cast<float>(size);
// We know the range inside sorted_and_i[]; translate to original indices,
// which are stored inside each of the sorted_and_i mantissas.
@ -470,9 +470,7 @@ class NuqCodec {
static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num,
ClusterBuf& buf, const size_t out_capacity,
NuqStream* const out, const size_t out_ofs) {
const hn::Repartition<uint8_t, DF> d8;
const hn::Repartition<uint16_t, DF> d16;
using V8 = hn::Vec<decltype(d8)>;
using V16 = hn::Vec<decltype(d16)>;
const size_t N16 = hn::Lanes(d16);

View File

@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// SFP uses ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
@ -23,9 +28,10 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/nuq_test.cc" // NOLINT
#define HWY_TARGET_INCLUDE "compression/nuq_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp

View File

@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// We use ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
@ -27,9 +32,10 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/sfp_test.cc" // NOLINT
#define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
@ -301,7 +307,7 @@ struct TestEncDec {
for (size_t i = 0; i < num; ++i) {
const float out = hwy::F32FromBF16(dec[i]);
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
stats.Notify(in[i], out);
stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
}
const double avg = sum / num;
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",

View File

@ -525,7 +525,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal.
size_t pos_offset = 0; // offset relative to pos
double prefill_start = hwy::platform::Now();
const double prefill_start = hwy::platform::Now();
// Prefill stops before prompt.size() - 1 since the last prompt token is the
// first input token for generation.
@ -547,12 +547,12 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
if (verbosity >= 2) {
// in the future this output should not occur in GenerateImpl but instead
// should be available as observable state for frontend code to handle I/O.
double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
const double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = static_cast<double>(pos_offset) / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]\n";
}
double gen_start = hwy::platform::Now();
const double gen_start = hwy::platform::Now();
HWY_DASSERT(pos_offset == prompt.size() - 1);
@ -590,9 +590,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, const InferenceArgs& args,
}
if (token == EOS_ID) {
if (verbosity >= 2) {
double gen_end = hwy::platform::Now();
const double gen_end = hwy::platform::Now();
const double gen_tok_sec =
(pos_offset - pos_gen_start) / (gen_end - gen_start);
static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
}
break;
@ -689,7 +689,7 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
if (loader.ReadAll(pool)) return c_weights_u8;
// Get weights, compress, and store in cache.
hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
const hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
Compressor compressor(pool);
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
compressor.WriteAll(pool, cache.path.c_str());
@ -721,10 +721,10 @@ HWY_EXPORT(GetCompressedWeightsT);
HWY_EXPORT(Generate2B);
HWY_EXPORT(Generate7B);
KVCache CreateKVCache(size_t size_cache_pos, size_t kSeqLen) {
KVCache CreateKVCache(size_t size_cache_pos, size_t seq_len) {
KVCache kv_cache = {};
kv_cache.key_cache = hwy::AllocateAligned<float>(kSeqLen * size_cache_pos);
kv_cache.value_cache = hwy::AllocateAligned<float>(kSeqLen * size_cache_pos);
kv_cache.key_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
kv_cache.value_cache = hwy::AllocateAligned<float>(seq_len * size_cache_pos);
return kv_cache;
}

56
ops.h
View File

@ -57,6 +57,17 @@ HWY_INLINE constexpr size_t MaxCols() {
return 2048;
}
// Converts between arithmetic types like static_cast, but routes
// unsigned -> floating-point conversions through a same-size signed
// integer first.
// NOTE(review): presumably this avoids slow codegen or compiler warnings
// for direct unsigned->float conversions on some targets — confirm intent
// with the surrounding callers (used for size_t -> float in RMSNorm).
// Participates in overload resolution only when both To and From are
// arithmetic (see enable_if_t below).
template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
// Unsigned source, floating destination: reinterpret the bit pattern as
// the same-width signed type, then convert to the floating type.
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
return static_cast<To>(
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
else
// All other arithmetic combinations: a plain static_cast suffices.
return static_cast<To>(from);
}
template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one
@ -230,7 +241,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
size_t i = 0;
if (size >= 2 * NF) {
for (; i < size - 2 * NF; i += 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
const VF mul0 = hn::LoadU(df, mul + i);
const VF mul1 = hn::LoadU(df, mul + i + NF);
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
@ -341,7 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + weight[j]) * (ss * x[j]);
@ -353,7 +364,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
@ -364,7 +375,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
@ -383,7 +394,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(inout, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -411,7 +423,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -438,7 +451,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -459,14 +473,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
const size_t num_timescales = dim_model / 2;
const float log_timescale_increment =
logf(10000.0f) /
(num_timescales != 0
? static_cast<float>(static_cast<int>(num_timescales) - 1)
: 1.0f);
(num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
for (size_t dim = 0; dim < num_timescales; ++dim) {
const float inv_timescale =
expf(static_cast<int>(dim) * -log_timescale_increment);
x[dim] += sinf(pos * inv_timescale);
x[num_timescales + dim] += cosf(pos * inv_timescale);
expf(StaticCast<float>(dim) * -log_timescale_increment);
x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
}
}
@ -475,11 +487,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
@ -496,11 +508,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
@ -674,18 +686,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) {
if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
continue;
}
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) {
if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = probabilities[i];
indices[j] = static_cast<int>(i);
indices[j] = StaticCast<int>(i);
break;
}
}