[WIP] dev/examples branch merge

This commit is contained in:
austinvhuang 2024-03-06 15:10:48 -05:00
commit 5b9d8a9936
21 changed files with 2872 additions and 161 deletions

1
.bazelrc Normal file
View File

@ -0,0 +1 @@
common --enable_bzlmod

View File

@ -1 +1,2 @@
Language: Cpp
BasedOnStyle: Google BasedOnStyle: Google

206
.clang-tidy Normal file
View File

@ -0,0 +1,206 @@
FormatStyle: file
Checks: "-*,\
abseil-*,\
-abseil-string-find-startswith,\
-abseil-string-find-str-contains,\
bugprone-*,\
-bugprone-argument-comment,\
-bugprone-assert-side-effect,\
-bugprone-bad-signal-to-kill-thread,\
-bugprone-bool-pointer-implicit-conversion,\
-bugprone-branch-clone,\
-bugprone-copy-constructor-init,\
-bugprone-dangling-handle,\
-bugprone-dynamic-static-initializers,\
-bugprone-easily-swappable-parameters,\
-bugprone-exception-escape,\
-bugprone-fold-init-type,\
-bugprone-forward-declaration-namespace,\
-bugprone-forwarding-reference-overload,\
-bugprone-implicit-widening-of-multiplication-result,\
-bugprone-inaccurate-erase,\
-bugprone-incorrect-roundings,\
-bugprone-infinite-loop,\
-bugprone-integer-division,\
-bugprone-lambda-function-name,\
-bugprone-macro-parentheses,\
-bugprone-macro-repeated-side-effects,\
-bugprone-misplaced-operator-in-strlen-in-alloc,\
-bugprone-misplaced-widening-cast,\
-bugprone-move-forwarding-reference,\
-bugprone-multiple-statement-macro,\
-bugprone-narrowing-conversions,\
-bugprone-no-escape,\
-bugprone-not-null-terminated-result,\
-bugprone-parent-virtual-call,\
-bugprone-posix-return,\
-bugprone-redundant-branch-condition,\
-bugprone-reserved-identifier,\
-bugprone-signal-handler,\
-bugprone-signed-char-misuse,\
-bugprone-sizeof-container,\
-bugprone-sizeof-expression,\
-bugprone-spuriously-wake-up-functions,\
-bugprone-string-constructor,\
-bugprone-string-integer-assignment,\
-bugprone-string-literal-with-embedded-nul,\
-bugprone-stringview-nullptr,\
-bugprone-suspicious-enum-usage,\
-bugprone-suspicious-include,\
-bugprone-suspicious-memory-comparison,\
-bugprone-suspicious-memset-usage,\
-bugprone-suspicious-missing-comma,\
-bugprone-suspicious-semicolon,\
-bugprone-suspicious-string-compare,\
-bugprone-swapped-arguments,\
-bugprone-terminating-continue,\
-bugprone-throw-keyword-missing,\
-bugprone-too-small-loop-variable,\
-bugprone-undefined-memory-manipulation,\
-bugprone-undelegated-constructor,\
-bugprone-unhandled-exception-at-new,\
-bugprone-unhandled-self-assignment,\
-bugprone-unused-raii,\
-bugprone-unused-return-value,\
-bugprone-use-after-move,\
-bugprone-virtual-near-miss,\
cert-*,\
-cert-dcl16-c,\
-cert-dcl21-cpp,\
-cert-dcl37-c,\
-cert-dcl50-cpp,\
-cert-dcl51-cpp,\
-cert-dcl54-cpp,\
-cert-dcl58-cpp,\
-cert-err33-c,\
-cert-msc30-c,\
-cert-msc32-c,\
-cert-msc50-cpp,\
-cert-msc51-cpp,\
-cert-oop54-cpp,\
-cert-str34-c,\
-clang-analyzer-*,\
concurrency-*,\
-concurrency-mt-unsafe,\
cppcoreguidelines-*,\
-cppcoreguidelines-avoid-c-arrays,\
-cppcoreguidelines-avoid-const-or-ref-data-members,\
-cppcoreguidelines-avoid-goto,\
-cppcoreguidelines-avoid-magic-numbers,\
-cppcoreguidelines-avoid-non-const-global-variables,\
-cppcoreguidelines-c-copy-assignment-signature,\
-cppcoreguidelines-explicit-virtual-functions,\
-cppcoreguidelines-init-variables,\
-cppcoreguidelines-interfaces-global-init,\
-cppcoreguidelines-macro-usage,\
-cppcoreguidelines-narrowing-conversions,\
-cppcoreguidelines-no-malloc,\
-cppcoreguidelines-non-private-member-variables-in-classes,\
-cppcoreguidelines-owning-memory,\
-cppcoreguidelines-prefer-member-initializer,\
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,\
-cppcoreguidelines-pro-bounds-constant-array-index,\
-cppcoreguidelines-pro-bounds-pointer-arithmetic,\
-cppcoreguidelines-pro-type-const-cast,\
-cppcoreguidelines-pro-type-member-init,\
-cppcoreguidelines-pro-type-reinterpret-cast,\
-cppcoreguidelines-pro-type-static-cast-downcast,\
-cppcoreguidelines-pro-type-union-access,\
-cppcoreguidelines-pro-type-vararg,\
-cppcoreguidelines-slicing,\
-cppcoreguidelines-special-member-functions,\
-cppcoreguidelines-virtual-class-destructor,\
google-*,\
-google-default-arguments,\
-google-explicit-constructor,\
-google-readability-avoid-underscore-in-googletest-name,\
-google-readability-braces-around-statements,\
-google-readability-casting,\
-google-readability-namespace-comments,\
-google-readability-todo,\
-google-runtime-int,\
-google-upgrade-googletest-case,\
misc-*,\
-misc-misplaced-const,\
-misc-new-delete-overloads,\
-misc-non-private-member-variables-in-classes,\
-misc-no-recursion,\
-misc-redundant-expression,\
-misc-uniqueptr-reset-release,\
-misc-unconventional-assign-operator,\
-misc-unused-parameters,\
-misc-unused-using-decls,\
modernize-*,\
-modernize-avoid-c-arrays,\
-modernize-concat-nested-namespaces,\
-modernize-deprecated-headers,\
-modernize-loop-convert,\
-modernize-macro-to-enum,\
-modernize-make-unique,\
-modernize-pass-by-value,\
-modernize-raw-string-literal,\
-modernize-redundant-void-arg,\
-modernize-return-braced-init-list,\
-modernize-unary-static-assert,\
-modernize-use-auto,\
-modernize-use-bool-literals,\
-modernize-use-default-member-init,\
-modernize-use-emplace,\
-modernize-use-equals-default,\
-modernize-use-equals-delete,\
-modernize-use-nodiscard,\
-modernize-use-nullptr,\
-modernize-use-override,\
-modernize-use-trailing-return-type,\
-modernize-use-transparent-functors,\
-modernize-use-using,\
performance-*,\
-performance-faster-string-find,\
-performance-for-range-copy,\
-performance-inefficient-algorithm,\
-performance-inefficient-string-concatenation,\
-performance-inefficient-vector-operation,\
-performance-move-const-arg,\
-performance-no-automatic-move,\
-performance-noexcept-move-constructor,\
-performance-no-int-to-ptr,\
-performance-trivially-destructible,\
-performance-unnecessary-copy-initialization,\
-performance-unnecessary-value-param,\
portability-*,\
readability-*,\
-readability-avoid-const-params-in-decls,\
-readability-braces-around-statements,\
-readability-const-return-type,\
-readability-container-data-pointer,\
-readability-container-size-empty,\
-readability-convert-member-functions-to-static,\
-readability-else-after-return,\
-readability-function-cognitive-complexity,\
-readability-identifier-length,\
-readability-implicit-bool-conversion,\
-readability-inconsistent-declaration-parameter-name,\
-readability-isolate-declaration,\
-readability-magic-numbers,\
-readability-make-member-function-const,\
-readability-named-parameter,\
-readability-non-const-parameter,\
-readability-qualified-auto,\
-readability-redundant-access-specifiers,\
-readability-redundant-control-flow,\
-readability-redundant-declaration,\
-readability-redundant-member-init,\
-readability-redundant-smartptr-get,\
-readability-redundant-string-cstr,\
-readability-redundant-string-init,\
-readability-simplify-boolean-expr,\
-readability-static-accessed-through-instance,\
-readability-static-definition-in-anonymous-namespace,\
-readability-suspicious-call-argument,\
-readability-uppercase-literal-suffix,\
-readability-use-anyofallof
"

View File

@ -12,6 +12,7 @@ jobs:
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
# When adding another, also add to copybara's github_check_runs.
os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
build_type: ['Release'] build_type: ['Release']
preset: ['make', 'windows'] preset: ['make', 'windows']
@ -43,7 +44,7 @@ jobs:
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache -D CMAKE_CXX_COMPILER_LAUNCHER=ccache
- name: Build - name: Build
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4
- name: Archive production artifacts - name: Archive production artifacts
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v4
@ -54,3 +55,21 @@ jobs:
${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib ${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib
${{ github.workspace }}/build/gemma ${{ github.workspace }}/build/gemma
${{ github.workspace }}/build/libgemma.a ${{ github.workspace }}/build/libgemma.a
bazel:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0
with:
egress-policy: audit # cannot be block - runner does git checkout
- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.0.0
- uses: bazelbuild/setup-bazelisk@b39c379c82683a5f25d34f0d062761f62693e0b2 # v3.0.0
- uses: actions/cache@ab5e6d0c87105b4c9c2047343972218f562e4319 # v4.0.1
with:
path: ~/.cache/bazel
key: bazel-${{ runner.os }}
- run: bazel build -c opt --cxxopt=-std=c++20 //...

View File

@ -25,21 +25,14 @@ cc_library(
], ],
deps = [ deps = [
"//compression:compress", "//compression:compress",
# copybara:import_next_line:hwy "@hwy//:algo",
"//:algo", "@hwy//:dot",
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:dot", "@hwy//:math",
# copybara:import_next_line:hwy "@hwy//:matvec",
"//:hwy", "@hwy//:profiler",
# copybara:import_next_line:hwy "@hwy//:thread_pool",
"//:math", "@hwy//hwy/contrib/sort:vqsort",
# copybara:import_next_line:hwy
"//:matvec",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
"//hwy/contrib/sort:vqsort",
], ],
) )
@ -49,8 +42,7 @@ cc_library(
"util/args.h", "util/args.h",
], ],
deps = [ deps = [
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy",
], ],
) )
@ -61,8 +53,7 @@ cc_library(
], ],
deps = [ deps = [
":args", ":args",
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy",
], ],
) )
@ -78,19 +69,13 @@ cc_library(
deps = [ deps = [
":args", ":args",
":transformer_ops", ":transformer_ops",
"//base",
"//compression:compress", "//compression:compress",
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy", "@hwy//:matvec",
# copybara:import_next_line:hwy "@hwy//:nanobenchmark", # timer
"//:matvec", "@hwy//:profiler",
# copybara:import_next_line:hwy "@hwy//:thread_pool",
"//:nanobenchmark", # timer "@com_google_sentencepiece//:sentencepiece_processor",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
":sentencepiece_processor",
], ],
) )
@ -104,13 +89,9 @@ cc_binary(
":args", ":args",
":gemma_lib", ":gemma_lib",
"//compression:compress", "//compression:compress",
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy", "@hwy//:nanobenchmark",
# copybara:import_next_line:hwy "@hwy//:profiler",
"//:nanobenchmark", "@hwy//:thread_pool",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
], ],
) )

View File

@ -20,6 +20,7 @@ project(gemma)
set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f) FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway) FetchContent_MakeAvailable(highway)

View File

@ -127,13 +127,13 @@ working with weights, kv cache and activations (e.g. you might have multiple kv
caches and activations for a single set of weights) more directly rather than caches and activations for a single set of weights) more directly rather than
only using a Gemma object. only using a Gemma object.
## Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly) ### Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)
You pretty much only do things with the tokenizer, call `Encode()` to go from You pretty much only do things with the tokenizer, call `Encode()` to go from
string prompts to token id vectors, or `Decode()` to go from token id vector string prompts to token id vectors, or `Decode()` to go from token id vector
outputs from the model back to strings. outputs from the model back to strings.
## The main entrypoint for generation is `GenerateGemma()` ### The main entrypoint for generation is `GenerateGemma()`
Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
activation values in `model` and 2) invoke StreamFunc - a lambda callback for activation values in `model` and 2) invoke StreamFunc - a lambda callback for
@ -150,7 +150,7 @@ constrained decoding type of use cases where you want to force the generation
to fit a grammar. If you're not doing this, you can send an empty lambda as a to fit a grammar. If you're not doing this, you can send an empty lambda as a
no-op which is what `run.cc` does. no-op which is what `run.cc` does.
## If you want to invoke the neural network forward function directly call the `Transformer()` function ### If you want to invoke the neural network forward function directly call the `Transformer()` function
For high-level applications, you might only call `GenerateGemma()` and never For high-level applications, you might only call `GenerateGemma()` and never
interact directly with the neural network, but if you're doing something a bit interact directly with the neural network, but if you're doing something a bit
@ -158,11 +158,20 @@ more custom you can call transformer which performs a single inference
operation on a single token and mutates the Activations and the KVCache through operation on a single token and mutates the Activations and the KVCache through
the neural network computation. the neural network computation.
## For low level operations, defining new architectures, call `ops.h` functions directly ### For low level operations, defining new architectures, call `ops.h` functions directly
You use `ops.h` if you're writing other NN architectures or modifying the You use `ops.h` if you're writing other NN architectures or modifying the
inference path of the Gemma model. inference path of the Gemma model.
## Building with Bazel
The sentencepiece library we depend on requires some additional work to build
with the Bazel build system. First, it does not export its BUILD file, so we
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
support Abseil as a standalone dependency without third_party/ prefixes, similar
to the transforms we apply to Gemma via Copybara.
## Discord ## Discord
We're also trying out a discord server for discussion here - We're also trying out a discord server for discussion here -

View File

@ -3,12 +3,57 @@ module(
version = "0.1.0", version = "0.1.0",
) )
bazel_dep( bazel_dep(name = "rules_license", version = "0.0.7")
name = "rules_license", bazel_dep(name = "googletest", version = "1.14.0")
version = "0.0.7",
# Copied from Highway because Bazel does not load them transitively
bazel_dep(name = "bazel_skylib", version = "1.4.1")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "platforms", version = "0.0.7")
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "hwy",
urls = ["https://github.com/google/highway/archive/refs/tags/1.1.0.zip"],
integrity = "sha256-zkJX2SwL4wQ0nHMsURW7MDLEf43vFSnqhSUsUM6eQmY=",
strip_prefix = "highway-1.1.0",
) )
bazel_dep( http_archive(
name = "com_google_sentencepiece", name = "com_google_sentencepiece",
version = "0.1.96", sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//bazel:sentencepiece.bazel",
patches = ["@//bazel:com_google_sentencepiece.patch"],
patch_args = ["-p1"],
)
# For sentencepiece
http_archive(
name = "darts_clone",
build_file_content = """
licenses(["notice"])
exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "darts_clone",
hdrs = [
"include/darts.h",
],
)
""",
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
urls = [
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
],
)
# ABSL on 2023-10-18
http_archive(
name = "com_google_absl",
sha256 = "f841f78243f179326f2a80b719f2887c38fe226d288ecdc46e2aa091e6aa43bc",
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
    urls = ["https://github.com/abseil/abseil-cpp/archive/9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
) )

View File

@ -65,15 +65,26 @@ winget install --id Kitware.CMake
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset" winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
``` ```
### Step 1: Obtain model weights and tokenizer from Kaggle ### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub
Visit [the Gemma model page on Visit [the Gemma model page on
Kaggle](https://www.kaggle.com/models/google/gemma) and select `Model Variations Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp) and select `Model Variations
|> Gemma C++`. On this tab, the `Variation` dropdown includes the options below. |> Gemma C++`. On this tab, the `Variation` dropdown includes the options below.
Note bfloat16 weights are higher fidelity, while 8-bit switched floating point Note bfloat16 weights are higher fidelity, while 8-bit switched floating point
weights enable faster inference. In general, we recommend starting with the weights enable faster inference. In general, we recommend starting with the
`-sfp` checkpoints. `-sfp` checkpoints.
Alternatively, visit the [gemma.cpp](https://huggingface.co/models?other=gemma.cpp)
models on the Hugging Face Hub. First go to the model repository of the model of interest
(see recommendations below). Then, click the `Files and versions` tab and download the
model and tokenizer files. For programmatic downloading, if you have `huggingface_hub`
installed, you can also download by running:
```
huggingface-cli login # Just the first time
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
```
2B instruction-tuned (`it`) and pre-trained (`pt`) models: 2B instruction-tuned (`it`) and pre-trained (`pt`) models:
| Model name | Description | | Model name | Description |
@ -98,6 +109,8 @@ weights enable faster inference. In general, we recommend starting with the
### Step 2: Extract Files ### Step 2: Extract Files
If you downloaded the models from Hugging Face, skip to step 3.
After filling out the consent form, the download should proceed to retrieve a After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
take a few minutes): take a few minutes):
@ -175,6 +188,14 @@ cmake --build --preset windows -j [number of parallel threads to use]
If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory. If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.
#### Bazel
```sh
bazel build -c opt --cxxopt=-std=c++20 :gemma
```
If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.
### Step 4: Run ### Step 4: Run
You can now run `gemma` from inside the `build/` directory. You can now run `gemma` from inside the `build/` directory.

View File

@ -1,24 +1,4 @@
workspace(name = "gemma") workspace(name = "gemma")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") # This file marks the root of the Bazel workspace.
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe") # See MODULE.bazel for external dependencies setup.
maybe(
http_archive,
name = "rules_license",
sha256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375fe7c360",
urls = [
"https://github.com/bazelbuild/rules_license/releases/download/0.0.7/rules_license-0.0.7.tar.gz",
],
)
maybe(
http_archive,
name = "com_google_sentencepiece",
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//third_party:sentencepiece.bazel",
patches = ["@//third_party:com_google_sentencepiece.patch"],
patch_args = ["-p1"],
)

4
bazel/BUILD Normal file
View File

@ -0,0 +1,4 @@
package(
default_applicable_licenses = ["//:license"],
default_visibility = ["//:__subpackages__"],
)

File diff suppressed because it is too large Load Diff

97
bazel/sentencepiece.bazel Normal file
View File

@ -0,0 +1,97 @@
package(
default_visibility = ["//visibility:public"],
features = [
"layering_check",
"parse_headers",
],
)
licenses(["notice"]) # Apache 2, BSD, MIT
proto_library(
name = "sentencepiece_proto",
srcs = ["src/sentencepiece.proto"],
)
cc_proto_library(
name = "sentencepiece_cc_proto",
deps = [":sentencepiece_proto"],
)
proto_library(
name = "sentencepiece_model_proto",
srcs = ["src/sentencepiece_model.proto"],
)
cc_proto_library(
name = "sentencepiece_model_cc_proto",
deps = [":sentencepiece_model_proto"],
)
genrule(
name = "config_h",
srcs = ["config.h.in"],
outs = ["config.h"],
cmd = "cp $< $@",
)
cc_library(
name = "common",
hdrs = [
"config.h",
"src/common.h",
],
deps = [
"@com_google_absl//absl/base",
],
)
cc_library(
name = "sentencepiece_processor",
srcs = [
"src/bpe_model.cc",
"src/char_model.cc",
"src/error.cc",
"src/filesystem.cc",
"src/model_factory.cc",
"src/model_interface.cc",
"src/normalizer.cc",
"src/sentencepiece_processor.cc",
"src/unigram_model.cc",
"src/util.cc",
"src/word_model.cc",
],
hdrs = [
"src/bpe_model.h",
"src/char_model.h",
"src/filesystem.h",
"src/freelist.h",
"src/model_factory.h",
"src/model_interface.h",
"src/normalizer.h",
"src/sentencepiece_processor.h",
"src/trainer_interface.h",
"src/unigram_model.h",
"src/util.h",
"src/word_model.h",
],
defines = ["_USE_TF_STRING_VIEW"],
includes = [
".",
"src",
],
linkstatic = 1,
deps =
[
":common",
":sentencepiece_cc_proto",
":sentencepiece_model_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@darts_clone",
],
)

View File

@ -1,10 +1,10 @@
# Weight compression, I/O and analysis # Weight compression, I/O and analysis
package( package(
default_applicable_licenses = ["//third_party/gemma_cpp:license"], default_applicable_licenses = ["//:license"],
default_visibility = [ default_visibility = [
"//learning/gemini/prod/contrib/gemini_cpp:__subpackages__", "//learning/gemini/prod/contrib/gemini_cpp:__subpackages__",
"//third_party/gemma_cpp:__subpackages__", "//:__subpackages__",
], ],
) )
@ -17,10 +17,8 @@ cc_library(
"blob_store.h", "blob_store.h",
], ],
deps = [ deps = [
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy", "@hwy//:thread_pool",
# copybara:import_next_line:hwy
"//:thread_pool",
], ],
) )
@ -34,8 +32,7 @@ cc_library(
"stats.h", "stats.h",
], ],
deps = [ deps = [
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy",
], ],
) )
@ -48,8 +45,7 @@ cc_library(
"sfp-inl.h", "sfp-inl.h",
], ],
deps = [ deps = [
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy",
], ],
) )
@ -65,15 +61,11 @@ cc_test(
deps = [ deps = [
":sfp", ":sfp",
":stats", ":stats",
"//testing/base/public:gunit_main_no_google3", "@googletest//:gtest_main",
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy", "@hwy//:hwy_test_util",
# copybara:import_next_line:hwy "@hwy//:nanobenchmark",
"//:hwy_test_util", "@hwy//:thread_pool",
# copybara:import_next_line:hwy
"//:nanobenchmark",
# copybara:import_next_line:hwy
"//:thread_pool",
], ],
) )
@ -87,9 +79,8 @@ cc_library(
], ],
deps = [ deps = [
":sfp", ":sfp",
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy", "@hwy//hwy/contrib/sort:vqsort",
"//third_party/highway/hwy/contrib/sort:vqsort",
], ],
) )
@ -106,13 +97,10 @@ cc_test(
":nuq", ":nuq",
":sfp", ":sfp",
":stats", ":stats",
"//testing/base/public:gunit_main_no_google3", "@googletest//:gtest_main",
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy", "@hwy//:hwy_test_util",
# copybara:import_next_line:hwy "@hwy//:nanobenchmark",
"//:hwy_test_util",
# copybara:import_next_line:hwy
"//:nanobenchmark",
], ],
) )
@ -131,12 +119,9 @@ cc_library(
":nuq", ":nuq",
":sfp", ":sfp",
":stats", ":stats",
# copybara:import_next_line:hwy "@hwy//:dot",
"//:dot", "@hwy//:hwy",
# copybara:import_next_line:hwy "@hwy//:thread_pool",
"//:hwy",
# copybara:import_next_line:hwy
"//:thread_pool",
], ],
) )
@ -150,12 +135,9 @@ cc_library(
":nuq", ":nuq",
":sfp", ":sfp",
":stats", ":stats",
# copybara:import_next_line:hwy "@hwy//:hwy",
"//:hwy", "@hwy//:nanobenchmark", # timer
# copybara:import_next_line:hwy "@hwy//:thread_pool",
"//:nanobenchmark", # timer "@hwy//hwy/contrib/sort:vqsort",
# copybara:import_next_line:hwy
"//:thread_pool",
"//third_party/highway/hwy/contrib/sort:vqsort",
], ],
) )

View File

@ -149,7 +149,7 @@ struct CompressTraits<hwy::bfloat16_t> {
} }
} }
size_t remaining = num - i; const size_t remaining = num - i;
if (remaining != 0) { if (remaining != 0) {
const VF in0 = hn::LoadN(df, in + i, remaining); const VF in0 = hn::LoadN(df, in + i, remaining);
const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2); const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2);
@ -195,7 +195,7 @@ struct CompressTraits<hwy::bfloat16_t> {
} }
} }
size_t remaining = num - i; const size_t remaining = num - i;
if (remaining != 0) { if (remaining != 0) {
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining); const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
const VF in0 = hn::PromoteLowerTo(df, in16); const VF in0 = hn::PromoteLowerTo(df, in16);
@ -287,7 +287,7 @@ struct CompressTraits<NuqStream> {
if (COMPRESS_STATS) { if (COMPRESS_STATS) {
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
tls.stats.NotifyIn(in[i] * 100 + 500); tls.stats.NotifyIn(static_cast<int>(lroundf(in[i] * 100.0f + 500.0f)));
} }
const hn::Repartition<hwy::bfloat16_t, DF> dbf; const hn::Repartition<hwy::bfloat16_t, DF> dbf;
@ -358,7 +358,7 @@ HWY_NOINLINE void Compress(const float* in, size_t num,
}); });
const double t1 = hwy::platform::Now(); const double t1 = hwy::platform::Now();
const double mb = num * sizeof(in[0]) * 1E-6; const double mb = static_cast<double>(num) * sizeof(in[0]) * 1E-6;
const double mbps = mb / (t1 - t0); const double mbps = mb / (t1 - t0);
fprintf(stderr, "Compress %.1f MB/s\n", mbps); fprintf(stderr, "Compress %.1f MB/s\n", mbps);

View File

@ -68,15 +68,15 @@ class DistortionStats {
double GeomeanValueDivL1() const { double GeomeanValueDivL1() const {
if (num_rel_ == 0) return 0.0; if (num_rel_ == 0) return 0.0;
return exp(sum_log_rel_ / num_rel_); return exp(sum_log_rel_ / static_cast<double>(num_rel_));
} }
double PNorm() const { double PNorm() const {
// p-norms are a compromise between max-norm (penalizes the largest error // p-norms are a compromise between max-norm (penalizes the largest error
// without dilution, but does not notice any other errors) and L1 (all // without dilution, but does not notice any other errors) and L1 (all
// errors contribute, but large errors are diluted by smaller ones). // errors contribute, but large errors are diluted by smaller ones).
const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3); const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6); const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
return 0.5 * (norm3 + norm6); return 0.5 * (norm3 + norm6);
} }

View File

@ -87,7 +87,7 @@ class NuqClustering {
inv_len_[0] = 0.0f; // unused inv_len_[0] = 0.0f; // unused
for (size_t i = 0; i <= kGroupSize; ++i) { for (size_t i = 0; i <= kGroupSize; ++i) {
inv_len_[i] = 1.0f / i; inv_len_[i] = 1.0f / static_cast<float>(i);
} }
} }
@ -229,7 +229,7 @@ class NuqClustering {
const float sum = cc.SumOfSorted(start, last); const float sum = cc.SumOfSorted(start, last);
const int size = static_cast<int>(last) - static_cast<int>(start) + 1; const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
HWY_DASSERT(0 < size && size <= kGroupSize); HWY_DASSERT(0 < size && size <= kGroupSize);
centers[k] = sum / size; centers[k] = sum / static_cast<float>(size);
// We know the range inside sorted_and_i[]; translate to original indices, // We know the range inside sorted_and_i[]; translate to original indices,
// which are stored inside each of the sorted_and_i mantissas. // which are stored inside each of the sorted_and_i mantissas.
@ -470,9 +470,7 @@ class NuqCodec {
static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num, static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num,
ClusterBuf& buf, const size_t out_capacity, ClusterBuf& buf, const size_t out_capacity,
NuqStream* const out, const size_t out_ofs) { NuqStream* const out, const size_t out_ofs) {
const hn::Repartition<uint8_t, DF> d8;
const hn::Repartition<uint16_t, DF> d16; const hn::Repartition<uint16_t, DF> d16;
using V8 = hn::Vec<decltype(d8)>;
using V16 = hn::Vec<decltype(d16)>; using V16 = hn::Vec<decltype(d16)>;
const size_t N16 = hn::Lanes(d16); const size_t N16 = hn::Lanes(d16);

View File

@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// SFP uses ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h> #include <stddef.h>
#include <stdint.h> #include <stdint.h>
#include <stdio.h> #include <stdio.h>
@ -23,9 +28,10 @@
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE "compression/nuq_test.cc" // NOLINT
"third_party/gemma_cpp/compression/nuq_test.cc" // NOLINT // clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h // Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp

View File

@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
// We use ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
#include "compression/sfp.h" #include "compression/sfp.h"
@ -27,9 +32,10 @@
#include "hwy/aligned_allocator.h" #include "hwy/aligned_allocator.h"
#include "hwy/base.h" #include "hwy/base.h"
// clang-format off
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \ #define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT
"third_party/gemma_cpp/compression/sfp_test.cc" // NOLINT // clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h // Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp // copybara:import_next_line:gemma_cpp
@ -301,7 +307,7 @@ struct TestEncDec {
for (size_t i = 0; i < num; ++i) { for (size_t i = 0; i < num; ++i) {
const float out = hwy::F32FromBF16(dec[i]); const float out = hwy::F32FromBF16(dec[i]);
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i])); sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
stats.Notify(in[i], out); stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
} }
const double avg = sum / num; const double avg = sum / num;
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n", fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",

View File

@ -527,7 +527,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal. // always equal.
size_t pos_offset = 0; // offset relative to pos size_t pos_offset = 0; // offset relative to pos
double prefill_start = hwy::platform::Now(); const double prefill_start = hwy::platform::Now();
// Prefill stops before prompt.size() - 1 since the last prompt token is the // Prefill stops before prompt.size() - 1 since the last prompt token is the
// first input token for generation. // first input token for generation.
@ -549,12 +549,13 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
if (verbosity >= 2) { if (verbosity >= 2) {
// in the future this output should not occur in GenerateImpl but instead // in the future this output should not occur in GenerateImpl but instead
// should be available as observable state for frontend code to handle I/O. // should be available as observable state for frontend code to handle I/O.
double prefill_end = hwy::platform::Now(); const double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start); const double prefill_tok_sec =
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]"; std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
} }
double gen_start = hwy::platform::Now(); const double gen_start = hwy::platform::Now();
HWY_DASSERT(pos_offset == prompt.size() - 1); HWY_DASSERT(pos_offset == prompt.size() - 1);
@ -592,10 +593,11 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
} }
if (token == EOS_ID) { if (token == EOS_ID) {
if (verbosity >= 2) { if (verbosity >= 2) {
double gen_end = hwy::platform::Now(); const double gen_end = hwy::platform::Now();
const double gen_tok_sec = const double gen_tok_sec =
(pos_offset - pos_gen_start) / (gen_end - gen_start); static_cast<double>(pos_offset - pos_gen_start) /
std::cout << "[ Generation tokens / sec = " << gen_tok_sec << " ]\n"; (gen_end - gen_start);
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
} }
break; break;
} }
@ -693,7 +695,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
if (loader.ReadAll(pool)) return c_weights_u8; if (loader.ReadAll(pool)) return c_weights_u8;
// Get weights, compress, and store in cache. // Get weights, compress, and store in cache.
hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model); const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
LoadWeights<TConfig>(model);
Compressor compressor(pool); Compressor compressor(pool);
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor); ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
compressor.WriteAll(pool, cache.path.c_str()); compressor.WriteAll(pool, cache.path.c_str());

56
ops.h
View File

@ -57,6 +57,17 @@ HWY_INLINE constexpr size_t MaxCols() {
return 2048; return 2048;
} }
template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
return static_cast<To>(
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
else
return static_cast<To>(from);
}
template <size_t kOuter> template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() { HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one // Aim for 128 work items to reduce pool overhead. Must be at least one
@ -230,7 +241,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(
size_t i = 0; size_t i = 0;
if (size >= 2 * NF) { if (size >= 2 * NF) {
for (; i < size - 2 * NF; i += 2 * NF) { for (; i <= size - 2 * NF; i += 2 * NF) {
const VF mul0 = hn::LoadU(df, mul + i); const VF mul0 = hn::LoadU(df, mul + i);
const VF mul1 = hn::LoadU(df, mul + i + NF); const VF mul1 = hn::LoadU(df, mul + i + NF);
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i))); const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));
@ -341,7 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) { float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size); float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps); ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) { for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here // Note 1.0f centering here
out[j] = (1.0f + weight[j]) * (ss * x[j]); out[j] = (1.0f + weight[j]) * (ss * x[j]);
@ -353,7 +364,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) { float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size); float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps); ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) { for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here // Note 1.0f centering here
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]); out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);
@ -364,7 +375,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) { const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
float ss = SquaredL2(inout, size); float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps); ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) { for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here // Note 1.0f centering here
inout[j] = (1.0f + weight[j]) * (ss * inout[j]); inout[j] = (1.0f + weight[j]) * (ss * inout[j]);
@ -383,7 +394,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
const float ss = SquaredL2(inout, size); const float ss = SquaredL2(inout, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps)); const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) { for (size_t i = 0; i < size; i += 2 * N32) {
@ -411,7 +423,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size); const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps)); const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) { for (size_t i = 0; i < size; i += 2 * N32) {
@ -438,7 +451,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
constexpr float eps = 1e-6f; constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size); const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps)); const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));
HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0); HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) { for (size_t i = 0; i < size; i += 2 * N32) {
@ -459,14 +473,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
const size_t num_timescales = dim_model / 2; const size_t num_timescales = dim_model / 2;
const float log_timescale_increment = const float log_timescale_increment =
logf(10000.0f) / logf(10000.0f) /
(num_timescales != 0 (num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
? static_cast<float>(static_cast<int>(num_timescales) - 1)
: 1.0f);
for (size_t dim = 0; dim < num_timescales; ++dim) { for (size_t dim = 0; dim < num_timescales; ++dim) {
const float inv_timescale = const float inv_timescale =
expf(static_cast<int>(dim) * -log_timescale_increment); expf(StaticCast<float>(dim) * -log_timescale_increment);
x[dim] += sinf(pos * inv_timescale); x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
x[num_timescales + dim] += cosf(pos * inv_timescale); x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
} }
} }
@ -475,11 +487,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) { for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) / const float freq_exponents =
static_cast<float>(dim_qkv); StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents); const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale; const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta); const float cos_val = cosf(theta);
const float sin_val = sinf(theta); const float sin_val = sinf(theta);
const float x0 = x[dim]; const float x0 = x[dim];
@ -496,11 +508,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
HWY_DASSERT(dim_qkv % 2 == 0); HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2; const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) { for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) / const float freq_exponents =
static_cast<float>(dim_qkv); StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably. // Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents); const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale; const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta); const float cos_val = cosf(theta);
const float sin_val = sinf(theta); const float sin_val = sinf(theta);
const float x0 = x[dim]; const float x0 = x[dim];
@ -674,18 +686,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1] std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
std::array<int, k> indices{}; std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) { for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) { if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
continue; continue;
} }
for (size_t j = 0; j < k; ++j) { for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) { if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
// shift elements by 1, insert the new value, move on to next value // shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) { for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1]; top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1]; indices[idx] = indices[idx - 1];
} }
top_k[j] = probabilities[i]; top_k[j] = probabilities[i];
indices[j] = static_cast<int>(i); indices[j] = StaticCast<int>(i);
break; break;
} }
} }