[WIP] dev/examples branch merge

This commit is contained in:
austinvhuang 2024-03-06 15:10:48 -05:00
commit 5b9d8a9936
21 changed files with 2872 additions and 161 deletions

1
.bazelrc Normal file
View File

@ -0,0 +1 @@
common --enable_bzlmod

View File

@ -1 +1,2 @@
Language: Cpp
BasedOnStyle: Google

206
.clang-tidy Normal file
View File

@ -0,0 +1,206 @@
FormatStyle: file
Checks: "-*,\
abseil-*,\
-abseil-string-find-startswith,\
-abseil-string-find-str-contains,\
bugprone-*,\
-bugprone-argument-comment,\
-bugprone-assert-side-effect,\
-bugprone-bad-signal-to-kill-thread,\
-bugprone-bool-pointer-implicit-conversion,\
-bugprone-branch-clone,\
-bugprone-copy-constructor-init,\
-bugprone-dangling-handle,\
-bugprone-dynamic-static-initializers,\
-bugprone-easily-swappable-parameters,\
-bugprone-exception-escape,\
-bugprone-fold-init-type,\
-bugprone-forward-declaration-namespace,\
-bugprone-forwarding-reference-overload,\
-bugprone-implicit-widening-of-multiplication-result,\
-bugprone-inaccurate-erase,\
-bugprone-incorrect-roundings,\
-bugprone-infinite-loop,\
-bugprone-integer-division,\
-bugprone-lambda-function-name,\
-bugprone-macro-parentheses,\
-bugprone-macro-repeated-side-effects,\
-bugprone-misplaced-operator-in-strlen-in-alloc,\
-bugprone-misplaced-widening-cast,\
-bugprone-move-forwarding-reference,\
-bugprone-multiple-statement-macro,\
-bugprone-narrowing-conversions,\
-bugprone-no-escape,\
-bugprone-not-null-terminated-result,\
-bugprone-parent-virtual-call,\
-bugprone-posix-return,\
-bugprone-redundant-branch-condition,\
-bugprone-reserved-identifier,\
-bugprone-signal-handler,\
-bugprone-signed-char-misuse,\
-bugprone-sizeof-container,\
-bugprone-sizeof-expression,\
-bugprone-spuriously-wake-up-functions,\
-bugprone-string-constructor,\
-bugprone-string-integer-assignment,\
-bugprone-string-literal-with-embedded-nul,\
-bugprone-stringview-nullptr,\
-bugprone-suspicious-enum-usage,\
-bugprone-suspicious-include,\
-bugprone-suspicious-memory-comparison,\
-bugprone-suspicious-memset-usage,\
-bugprone-suspicious-missing-comma,\
-bugprone-suspicious-semicolon,\
-bugprone-suspicious-string-compare,\
-bugprone-swapped-arguments,\
-bugprone-terminating-continue,\
-bugprone-throw-keyword-missing,\
-bugprone-too-small-loop-variable,\
-bugprone-undefined-memory-manipulation,\
-bugprone-undelegated-constructor,\
-bugprone-unhandled-exception-at-new,\
-bugprone-unhandled-self-assignment,\
-bugprone-unused-raii,\
-bugprone-unused-return-value,\
-bugprone-use-after-move,\
-bugprone-virtual-near-miss,\
cert-*,\
-cert-dcl16-c,\
-cert-dcl21-cpp,\
-cert-dcl37-c,\
-cert-dcl50-cpp,\
-cert-dcl51-cpp,\
-cert-dcl54-cpp,\
-cert-dcl58-cpp,\
-cert-err33-c,\
-cert-msc30-c,\
-cert-msc32-c,\
-cert-msc50-cpp,\
-cert-msc51-cpp,\
-cert-oop54-cpp,\
-cert-str34-c,\
-clang-analyzer-*,\
concurrency-*,\
-concurrency-mt-unsafe,\
cppcoreguidelines-*,\
-cppcoreguidelines-avoid-c-arrays,\
-cppcoreguidelines-avoid-const-or-ref-data-members,\
-cppcoreguidelines-avoid-goto,\
-cppcoreguidelines-avoid-magic-numbers,\
-cppcoreguidelines-avoid-non-const-global-variables,\
-cppcoreguidelines-c-copy-assignment-signature,\
-cppcoreguidelines-explicit-virtual-functions,\
-cppcoreguidelines-init-variables,\
-cppcoreguidelines-interfaces-global-init,\
-cppcoreguidelines-macro-usage,\
-cppcoreguidelines-narrowing-conversions,\
-cppcoreguidelines-no-malloc,\
-cppcoreguidelines-non-private-member-variables-in-classes,\
-cppcoreguidelines-owning-memory,\
-cppcoreguidelines-prefer-member-initializer,\
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,\
-cppcoreguidelines-pro-bounds-constant-array-index,\
-cppcoreguidelines-pro-bounds-pointer-arithmetic,\
-cppcoreguidelines-pro-type-const-cast,\
-cppcoreguidelines-pro-type-member-init,\
-cppcoreguidelines-pro-type-reinterpret-cast,\
-cppcoreguidelines-pro-type-static-cast-downcast,\
-cppcoreguidelines-pro-type-union-access,\
-cppcoreguidelines-pro-type-vararg,\
-cppcoreguidelines-slicing,\
-cppcoreguidelines-special-member-functions,\
-cppcoreguidelines-virtual-class-destructor,\
google-*,\
-google-default-arguments,\
-google-explicit-constructor,\
-google-readability-avoid-underscore-in-googletest-name,\
-google-readability-braces-around-statements,\
-google-readability-casting,\
-google-readability-namespace-comments,\
-google-readability-todo,\
-google-runtime-int,\
-google-upgrade-googletest-case,\
misc-*,\
-misc-misplaced-const,\
-misc-new-delete-overloads,\
-misc-non-private-member-variables-in-classes,\
-misc-no-recursion,\
-misc-redundant-expression,\
-misc-uniqueptr-reset-release,\
-misc-unconventional-assign-operator,\
-misc-unused-parameters,\
-misc-unused-using-decls,\
modernize-*,\
-modernize-avoid-c-arrays,\
-modernize-concat-nested-namespaces,\
-modernize-deprecated-headers,\
-modernize-loop-convert,\
-modernize-macro-to-enum,\
-modernize-make-unique,\
-modernize-pass-by-value,\
-modernize-raw-string-literal,\
-modernize-redundant-void-arg,\
-modernize-return-braced-init-list,\
-modernize-unary-static-assert,\
-modernize-use-auto,\
-modernize-use-bool-literals,\
-modernize-use-default-member-init,\
-modernize-use-emplace,\
-modernize-use-equals-default,\
-modernize-use-equals-delete,\
-modernize-use-nodiscard,\
-modernize-use-nullptr,\
-modernize-use-override,\
-modernize-use-trailing-return-type,\
-modernize-use-transparent-functors,\
-modernize-use-using,\
performance-*,\
-performance-faster-string-find,\
-performance-for-range-copy,\
-performance-inefficient-algorithm,\
-performance-inefficient-string-concatenation,\
-performance-inefficient-vector-operation,\
-performance-move-const-arg,\
-performance-no-automatic-move,\
-performance-noexcept-move-constructor,\
-performance-no-int-to-ptr,\
-performance-trivially-destructible,\
-performance-unnecessary-copy-initialization,\
-performance-unnecessary-value-param,\
portability-*,\
readability-*,\
-readability-avoid-const-params-in-decls,\
-readability-braces-around-statements,\
-readability-const-return-type,\
-readability-container-data-pointer,\
-readability-container-size-empty,\
-readability-convert-member-functions-to-static,\
-readability-else-after-return,\
-readability-function-cognitive-complexity,\
-readability-identifier-length,\
-readability-implicit-bool-conversion,\
-readability-inconsistent-declaration-parameter-name,\
-readability-isolate-declaration,\
-readability-magic-numbers,\
-readability-make-member-function-const,\
-readability-named-parameter,\
-readability-non-const-parameter,\
-readability-qualified-auto,\
-readability-redundant-access-specifiers,\
-readability-redundant-control-flow,\
-readability-redundant-declaration,\
-readability-redundant-member-init,\
-readability-redundant-smartptr-get,\
-readability-redundant-string-cstr,\
-readability-redundant-string-init,\
-readability-simplify-boolean-expr,\
-readability-static-accessed-through-instance,\
-readability-static-definition-in-anonymous-namespace,\
-readability-suspicious-call-argument,\
-readability-uppercase-literal-suffix,\
-readability-use-anyofallof
"

View File

@ -12,6 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
# When adding another, also add to copybara's github_check_runs.
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
build_type: ['Release']
preset: ['make', 'windows']
@ -43,7 +44,7 @@ jobs:
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache
- name: Build
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }}
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4
- name: Archive production artifacts
uses: actions/upload-artifact@v4
@ -54,3 +55,21 @@ jobs:
${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib
${{ github.workspace }}/build/gemma
${{ github.workspace }}/build/libgemma.a
bazel:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0
with:
egress-policy: audit # cannot be block - runner does git checkout
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.0.0
- uses: bazelbuild/setup-bazelisk@b39c379c82683a5f25d34f0d062761f62693e0b2 # v3.0.0
- uses: actions/cache@ab5e6d0c87105b4c9c2047343972218f562e4319 # v4.0.1
with:
path: ~/.cache/bazel
key: bazel-${{ runner.os }}
- run: bazel build -c opt --cxxopt=-std=c++20 //...

View File

@ -25,21 +25,14 @@ cc_library(
],
deps = [
"//compression:compress",
# copybara:import_next_line:hwy
"//:algo",
# copybara:import_next_line:hwy
"//:dot",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:math",
# copybara:import_next_line:hwy
"//:matvec",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
"//hwy/contrib/sort:vqsort",
"@hwy//:algo",
"@hwy//:dot",
"@hwy//:hwy",
"@hwy//:math",
"@hwy//:matvec",
"@hwy//:profiler",
"@hwy//:thread_pool",
"@hwy//hwy/contrib/sort:vqsort",
],
)
@ -49,8 +42,7 @@ cc_library(
"util/args.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)
@ -61,8 +53,7 @@ cc_library(
],
deps = [
":args",
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)
@ -78,19 +69,13 @@ cc_library(
deps = [
":args",
":transformer_ops",
"//base",
"//compression:compress",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:matvec",
# copybara:import_next_line:hwy
"//:nanobenchmark", # timer
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
":sentencepiece_processor",
"@hwy//:hwy",
"@hwy//:matvec",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@hwy//:thread_pool",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)
@ -104,13 +89,9 @@ cc_binary(
":args",
":gemma_lib",
"//compression:compress",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:nanobenchmark",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:thread_pool",
],
)

View File

@ -20,6 +20,7 @@ project(gemma)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)

View File

@ -127,13 +127,13 @@ working with weights, kv cache and activations (e.g. you might have multiple kv
caches and activations for a single set of weights) more directly rather than
only using a Gemma object.
## Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)
### Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)
You pretty much only do things with the tokenizer, call `Encode()` to go from
string prompts to token id vectors, or `Decode()` to go from token id vector
outputs from the model back to strings.
## The main entrypoint for generation is `GenerateGemma()`
### The main entrypoint for generation is `GenerateGemma()`
Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
activation values in `model` and 2) invoke StreamFunc - a lambda callback for
@ -150,7 +150,7 @@ constrained decoding type of use cases where you want to force the generation
to fit a grammar. If you're not doing this, you can send an empty lambda as a
no-op which is what `run.cc` does.
## If you want to invoke the neural network forward function directly call the `Transformer()` function
### If you want to invoke the neural network forward function directly call the `Transformer()` function
For high-level applications, you might only call `GenerateGemma()` and never
interact directly with the neural network, but if you're doing something a bit
@ -158,11 +158,20 @@ more custom you can call transformer which performs a single inference
operation on a single token and mutates the Activations and the KVCache through
the neural network computation.
## For low level operations, defining new architectures, call `ops.h` functions directly
### For low level operations, defining new architectures, call `ops.h` functions directly
You use `ops.h` if you're writing other NN architectures or modifying the
inference path of the Gemma model.
## Building with Bazel
The sentencepiece library we depend on requires some additional work to build
with the Bazel build system. First, it does not export its BUILD file, so we
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
support Abseil as a standalone dependency without third_party/ prefixes, similar
to the transforms we apply to Gemma via Copybara.
## Discord
We're also trying out a discord server for discussion here -

View File

@ -3,12 +3,57 @@ module(
version = "0.1.0",
)
bazel_dep(
name = "rules_license",
version = "0.0.7",
bazel_dep(name = "rules_license", version = "0.0.7")
bazel_dep(name = "googletest", version = "1.14.0")
# Copied from Highway because Bazel does not load them transitively
bazel_dep(name = "bazel_skylib", version = "1.4.1")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "platforms", version = "0.0.7")
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "hwy",
urls = ["https://github.com/google/highway/archive/refs/tags/1.1.0.zip"],
integrity = "sha256-zkJX2SwL4wQ0nHMsURW7MDLEf43vFSnqhSUsUM6eQmY=",
strip_prefix = "highway-1.1.0",
)
bazel_dep(
http_archive(
name = "com_google_sentencepiece",
version = "0.1.96",
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//bazel:sentencepiece.bazel",
patches = ["@//bazel:com_google_sentencepiece.patch"],
patch_args = ["-p1"],
)
# For sentencepiece
http_archive(
name = "darts_clone",
build_file_content = """
licenses(["notice"])
exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "darts_clone",
hdrs = [
"include/darts.h",
],
)
""",
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
urls = [
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
],
)
# ABSL on 2023-10-18
http_archive(
name = "com_google_absl",
sha256 = "f841f78243f179326f2a80b719f2887c38fe226d288ecdc46e2aa091e6aa43bc",
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
urls = ["https://github.com/abseil/abseil-cpp/archive/9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
)

View File

@ -65,15 +65,26 @@ winget install --id Kitware.CMake
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
```
### Step 1: Obtain model weights and tokenizer from Kaggle
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
Visit [the Gemma model page on
Kaggle](https://www.kaggle.com/models/google/gemma) and select `Model Variations
Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp) and select `Model Variations
|> Gemma C++`. On this tab, the `Variation` dropdown includes the options below.
Note bfloat16 weights are higher fidelity, while 8-bit switched floating point
weights enable faster inference. In general, we recommend starting with the
`-sfp` checkpoints.
Alternatively, visit the [gemma.cpp](https://huggingface.co/models?other=gemma.cpp)
models on the Hugging Face Hub. First go to the model repository of the model of interest
(see recommendations below). Then, click the `Files and versions` tab and download the
model and tokenizer files. For programmatic downloading, if you have `huggingface_hub`
installed, you can also download by running:
```
huggingface-cli login # Just the first time
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
```
2B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description |
@ -98,6 +109,8 @@ weights enable faster inference. In general, we recommend starting with the
### Step 2: Extract Files
If you downloaded the models from Hugging Face, skip to step 3.
After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
take a few minutes):
@ -175,6 +188,14 @@ cmake --build --preset windows -j [number of parallel threads to use]
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
#### Bazel
```sh
bazel build -c opt --cxxopt=-std=c++20 :gemma
```
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
### Step 4: Run
You can now run `gemma` from inside the `build/` directory.

View File

@ -1,24 +1,4 @@
workspace(name = "gemma")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")
maybe(
http_archive,
name = "rules_license",
sha256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375fe7c360",
urls = [
"https://github.com/bazelbuild/rules_license/releases/download/0.0.7/rules_license-0.0.7.tar.gz",
],
)
maybe(
http_archive,
name = "com_google_sentencepiece",
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//third_party:sentencepiece.bazel",
patches = ["@//third_party:com_google_sentencepiece.patch"],
patch_args = ["-p1"],
)
# This file marks the root of the Bazel workspace.
# See MODULE.bazel for external dependencies setup.

4
bazel/BUILD Normal file
View File

@ -0,0 +1,4 @@
package(
default_applicable_licenses = ["//:license"],
default_visibility = ["//:__subpackages__"],
)

File diff suppressed because it is too large Load Diff

97
bazel/sentencepiece.bazel Normal file
View File

@ -0,0 +1,97 @@
package(
default_visibility = ["//visibility:public"],
features = [
"layering_check",
"parse_headers",
],
)
licenses(["notice"]) # Apache 2, BSD, MIT
proto_library(
name = "sentencepiece_proto",
srcs = ["src/sentencepiece.proto"],
)
cc_proto_library(
name = "sentencepiece_cc_proto",
deps = [":sentencepiece_proto"],
)
proto_library(
name = "sentencepiece_model_proto",
srcs = ["src/sentencepiece_model.proto"],
)
cc_proto_library(
name = "sentencepiece_model_cc_proto",
deps = [":sentencepiece_model_proto"],
)
genrule(
name = "config_h",
srcs = ["config.h.in"],
outs = ["config.h"],
cmd = "cp $< $@",
)
cc_library(
name = "common",
hdrs = [
"config.h",
"src/common.h",
],
deps = [
"@com_google_absl//absl/base",
],
)
cc_library(
name = "sentencepiece_processor",
srcs = [
"src/bpe_model.cc",
"src/char_model.cc",
"src/error.cc",
"src/filesystem.cc",
"src/model_factory.cc",
"src/model_interface.cc",
"src/normalizer.cc",
"src/sentencepiece_processor.cc",
"src/unigram_model.cc",
"src/util.cc",
"src/word_model.cc",
],
hdrs = [
"src/bpe_model.h",
"src/char_model.h",
"src/filesystem.h",
"src/freelist.h",
"src/model_factory.h",
"src/model_interface.h",
"src/normalizer.h",
"src/sentencepiece_processor.h",
"src/trainer_interface.h",
"src/unigram_model.h",
"src/util.h",
"src/word_model.h",
],
defines = ["_USE_TF_STRING_VIEW"],
includes = [
".",
"src",
],
linkstatic = 1,
deps =
[
":common",
":sentencepiece_cc_proto",
":sentencepiece_model_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@darts_clone",
],
)

View File

@ -1,10 +1,10 @@
# Weight compression, I/O and analysis
package(
default_applicable_licenses = ["//third_party/gemma_cpp:license"],
default_applicable_licenses = ["//:license"],
default_visibility = [
"//learning/gemini/prod/contrib/gemini_cpp:__subpackages__",
"//third_party/gemma_cpp:__subpackages__",
"//:__subpackages__",
],
)
@ -17,10 +17,8 @@ cc_library(
"blob_store.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)
@ -34,8 +32,7 @@ cc_library(
"stats.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)
@ -48,8 +45,7 @@ cc_library(
"sfp-inl.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)
@ -65,15 +61,11 @@ cc_test(
deps = [
":sfp",
":stats",
"//testing/base/public:gunit_main_no_google3",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:hwy_test_util",
# copybara:import_next_line:hwy
"//:nanobenchmark",
# copybara:import_next_line:hwy
"//:thread_pool",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)
@ -87,9 +79,8 @@ cc_library(
],
deps = [
":sfp",
# copybara:import_next_line:hwy
"//:hwy",
"//third_party/highway/hwy/contrib/sort:vqsort",
"@hwy//:hwy",
"@hwy//hwy/contrib/sort:vqsort",
],
)
@ -106,13 +97,10 @@ cc_test(
":nuq",
":sfp",
":stats",
"//testing/base/public:gunit_main_no_google3",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:hwy_test_util",
# copybara:import_next_line:hwy
"//:nanobenchmark",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
],
)
@ -131,12 +119,9 @@ cc_library(
":nuq",
":sfp",
":stats",
# copybara:import_next_line:hwy
"//:dot",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:dot",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)
@ -150,12 +135,9 @@ cc_library(
":nuq",
":sfp",
":stats",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:nanobenchmark", # timer
# copybara:import_next_line:hwy
"//:thread_pool",
"//third_party/highway/hwy/contrib/sort:vqsort",
"@hwy//:hwy",
"@hwy//:nanobenchmark", # timer
"@hwy//:thread_pool",
"@hwy//hwy/contrib/sort:vqsort",
],
)

View File

@ -149,7 +149,7 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}
size_t remaining = num - i;
const size_t remaining = num - i;
if (remaining != 0) {
const VF in0 = hn::LoadN(df, in + i, remaining);
const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2);
@ -195,7 +195,7 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}
size_t remaining = num - i;
const size_t remaining = num - i;
if (remaining != 0) {
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
const VF in0 = hn::PromoteLowerTo(df, in16);
@ -287,7 +287,7 @@ struct CompressTraits<NuqStream> {
if (COMPRESS_STATS) {
for (size_t i = 0; i < num; ++i) {
tls.stats.NotifyIn(in[i] * 100 + 500);
tls.stats.NotifyIn(static_cast<int>(lroundf(in[i] * 100.0f + 500.0f)));
}
const hn::Repartition<hwy::bfloat16_t, DF> dbf;
@ -358,7 +358,7 @@ HWY_NOINLINE void Compress(const float* in, size_t num,
});
const double t1 = hwy::platform::Now();
const double mb = num * sizeof(in[0]) * 1E-6;
const double mb = static_cast<double>(num) * sizeof(in[0]) * 1E-6;
const double mbps = mb / (t1 - t0);
fprintf(stderr, "Compress %.1f MB/s\n", mbps);

View File

@ -68,15 +68,15 @@ class DistortionStats {
double GeomeanValueDivL1() const {
if (num_rel_ == 0) return 0.0;
return exp(sum_log_rel_ / num_rel_);
return exp(sum_log_rel_ / static_cast<double>(num_rel_));
}
double PNorm() const {
// p-norms are a compromise between max-norm (penalizes the largest error
// without dilution, but does not notice any other errors) and L1 (all
// errors contribute, but large errors are diluted by smaller ones).
const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3);
const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6);
const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
return 0.5 * (norm3 + norm6);
}

View File

@ -87,7 +87,7 @@ class NuqClustering {
inv_len_[0] = 0.0f; // unused
for (size_t i = 0; i <= kGroupSize; ++i) {
inv_len_[i] = 1.0f / i;
inv_len_[i] = 1.0f / static_cast<float>(i);
}
}
@ -229,7 +229,7 @@ class NuqClustering {
const float sum = cc.SumOfSorted(start, last);
const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
HWY_DASSERT(0 < size && size <= kGroupSize);
centers[k] = sum / size;
centers[k] = sum / static_cast<float>(size);
// We know the range inside sorted_and_i[]; translate to original indices,
// which are stored inside each of the sorted_and_i mantissas.
@ -470,9 +470,7 @@ class NuqCodec {
static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num,
ClusterBuf& buf, const size_t out_capacity,
NuqStream* const out, const size_t out_ofs) {
const hn::Repartition<uint8_t, DF> d8;
const hn::Repartition<uint16_t, DF> d16;
using V8 = hn::Vec<decltype(d8)>;
using V16 = hn::Vec<decltype(d16)>;
const size_t N16 = hn::Lanes(d16);

View File

@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// SFP uses ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
@ -23,9 +28,10 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/nuq_test.cc" // NOLINT
#define HWY_TARGET_INCLUDE "compression/nuq_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp

View File

@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
// We use ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"
@ -27,9 +32,10 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/sfp_test.cc" // NOLINT
#define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
@ -301,7 +307,7 @@ struct TestEncDec {
for (size_t i = 0; i < num; ++i) {
const float out = hwy::F32FromBF16(dec[i]);
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
stats.Notify(in[i], out);
stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
}
const double avg = sum / num;
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",

View File

@ -527,7 +527,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal.
size_t pos_offset = 0; // offset relative to pos
double prefill_start = hwy::platform::Now();
const double prefill_start = hwy::platform::Now();
// Prefill stops before prompt.size() - 1 since the last prompt token is the
// first input token for generation.
@ -549,12 +549,13 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
if (verbosity >= 2) {
// in the future this output should not occur in GenerateImpl but instead
// should be available as observable state for frontend code to handle I/O.
double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
const double prefill_end = hwy::platform::Now();
const double prefill_tok_sec =
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
}
double gen_start = hwy::platform::Now();
const double gen_start = hwy::platform::Now();
HWY_DASSERT(pos_offset == prompt.size() - 1);
@ -592,10 +593,11 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
}
if (token == EOS_ID) {
if (verbosity >= 2) {
double gen_end = hwy::platform::Now();
const double gen_end = hwy::platform::Now();
const double gen_tok_sec =
(pos_offset - pos_gen_start) / (gen_end - gen_start);
std::cout << "[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
static_cast<double>(pos_offset - pos_gen_start) /
(gen_end - gen_start);
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
}
break;
}
@ -693,7 +695,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
if (loader.ReadAll(pool)) return c_weights_u8;
// Get weights, compress, and store in cache.
hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
LoadWeights<TConfig>(model);
Compressor compressor(pool);
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
compressor.WriteAll(pool, cache.path.c_str());

56
ops.h
View File

@ -57,6 +57,17 @@ HWY_INLINE constexpr size_t MaxCols() {
return 2048;
}
template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
return static_cast<To>(
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
else
return static_cast<To>(from);
}
template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one
@ -230,7 +241,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
size_t i = 0;
if (size >= 2 * NF) {
for (; i < size - 2 * NF; i += 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
const VF mul0 = hn::LoadU(df, mul + i);
const VF mul1 = hn::LoadU(df, mul + i + NF);
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
@ -341,7 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + weight[j]) * (ss * x[j]);
@ -353,7 +364,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
@ -364,7 +375,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
@ -383,7 +394,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(inout, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -411,7 +423,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -438,7 +451,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {
@ -459,14 +473,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
const size_t num_timescales = dim_model / 2;
const float log_timescale_increment =
logf(10000.0f) /
(num_timescales != 0
? static_cast<float>(static_cast<int>(num_timescales) - 1)
: 1.0f);
(num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
for (size_t dim = 0; dim < num_timescales; ++dim) {
const float inv_timescale =
expf(static_cast<int>(dim) * -log_timescale_increment);
x[dim] += sinf(pos * inv_timescale);
x[num_timescales + dim] += cosf(pos * inv_timescale);
expf(StaticCast<float>(dim) * -log_timescale_increment);
x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
}
}
@ -475,11 +487,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
@ -496,11 +508,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];
@ -674,18 +686,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) {
if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
continue;
}
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) {
if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = probabilities[i];
indices[j] = static_cast<int>(i);
indices[j] = StaticCast<int>(i);
break;
}
}