mirror of https://github.com/google/gemma.cpp.git
[WIP] dev/examples branch merge
commit 5b9d8a9936
.clang-format

@@ -1 +1,2 @@
Language: Cpp
BasedOnStyle: Google
.clang-tidy

@@ -0,0 +1,206 @@
FormatStyle: file
Checks: "-*,\
abseil-*,\
-abseil-string-find-startswith,\
-abseil-string-find-str-contains,\
bugprone-*,\
-bugprone-argument-comment,\
-bugprone-assert-side-effect,\
-bugprone-bad-signal-to-kill-thread,\
-bugprone-bool-pointer-implicit-conversion,\
-bugprone-branch-clone,\
-bugprone-copy-constructor-init,\
-bugprone-dangling-handle,\
-bugprone-dynamic-static-initializers,\
-bugprone-easily-swappable-parameters,\
-bugprone-exception-escape,\
-bugprone-fold-init-type,\
-bugprone-forward-declaration-namespace,\
-bugprone-forwarding-reference-overload,\
-bugprone-implicit-widening-of-multiplication-result,\
-bugprone-inaccurate-erase,\
-bugprone-incorrect-roundings,\
-bugprone-infinite-loop,\
-bugprone-integer-division,\
-bugprone-lambda-function-name,\
-bugprone-macro-parentheses,\
-bugprone-macro-repeated-side-effects,\
-bugprone-misplaced-operator-in-strlen-in-alloc,\
-bugprone-misplaced-widening-cast,\
-bugprone-move-forwarding-reference,\
-bugprone-multiple-statement-macro,\
-bugprone-narrowing-conversions,\
-bugprone-no-escape,\
-bugprone-not-null-terminated-result,\
-bugprone-parent-virtual-call,\
-bugprone-posix-return,\
-bugprone-redundant-branch-condition,\
-bugprone-reserved-identifier,\
-bugprone-signal-handler,\
-bugprone-signed-char-misuse,\
-bugprone-sizeof-container,\
-bugprone-sizeof-expression,\
-bugprone-spuriously-wake-up-functions,\
-bugprone-string-constructor,\
-bugprone-string-integer-assignment,\
-bugprone-string-literal-with-embedded-nul,\
-bugprone-stringview-nullptr,\
-bugprone-suspicious-enum-usage,\
-bugprone-suspicious-include,\
-bugprone-suspicious-memory-comparison,\
-bugprone-suspicious-memset-usage,\
-bugprone-suspicious-missing-comma,\
-bugprone-suspicious-semicolon,\
-bugprone-suspicious-string-compare,\
-bugprone-swapped-arguments,\
-bugprone-terminating-continue,\
-bugprone-throw-keyword-missing,\
-bugprone-too-small-loop-variable,\
-bugprone-undefined-memory-manipulation,\
-bugprone-undelegated-constructor,\
-bugprone-unhandled-exception-at-new,\
-bugprone-unhandled-self-assignment,\
-bugprone-unused-raii,\
-bugprone-unused-return-value,\
-bugprone-use-after-move,\
-bugprone-virtual-near-miss,\
cert-*,\
-cert-dcl16-c,\
-cert-dcl21-cpp,\
-cert-dcl37-c,\
-cert-dcl50-cpp,\
-cert-dcl51-cpp,\
-cert-dcl54-cpp,\
-cert-dcl58-cpp,\
-cert-err33-c,\
-cert-msc30-c,\
-cert-msc32-c,\
-cert-msc50-cpp,\
-cert-msc51-cpp,\
-cert-oop54-cpp,\
-cert-str34-c,\
-cert-str34-c,\
-cert-str34-c,\
-cert-str34-c,\
-clang-analyzer-*,\
concurrency-*,\
-concurrency-mt-unsafe,\
cppcoreguidelines-*,\
-concurrency-mt-unsafe,\
-cppcoreguidelines-avoid-c-arrays,\
-cppcoreguidelines-avoid-const-or-ref-data-members,\
-cppcoreguidelines-avoid-goto,\
-cppcoreguidelines-avoid-magic-numbers,\
-cppcoreguidelines-avoid-non-const-global-variables,\
-cppcoreguidelines-c-copy-assignment-signature,\
-cppcoreguidelines-explicit-virtual-functions,\
-cppcoreguidelines-init-variables,\
-cppcoreguidelines-interfaces-global-init,\
-cppcoreguidelines-macro-usage,\
-cppcoreguidelines-narrowing-conversions,\
-cppcoreguidelines-no-malloc,\
-cppcoreguidelines-non-private-member-variables-in-classes,\
-cppcoreguidelines-owning-memory,\
-cppcoreguidelines-prefer-member-initializer,\
-cppcoreguidelines-pro-bounds-array-to-pointer-decay,\
-cppcoreguidelines-pro-bounds-constant-array-index,\
-cppcoreguidelines-pro-bounds-pointer-arithmetic,\
-cppcoreguidelines-pro-type-const-cast,\
-cppcoreguidelines-pro-type-member-init,\
-cppcoreguidelines-pro-type-reinterpret-cast,\
-cppcoreguidelines-pro-type-static-cast-downcast,\
-cppcoreguidelines-pro-type-union-access,\
-cppcoreguidelines-pro-type-vararg,\
-cppcoreguidelines-slicing,\
-cppcoreguidelines-special-member-functions,\
-cppcoreguidelines-virtual-class-destructor,\
google-*,\
-google-default-arguments,\
-google-explicit-constructor,\
-google-readability-avoid-underscore-in-googletest-name,\
-google-readability-braces-around-statements,\
-google-readability-casting,\
-google-readability-namespace-comments,\
-google-readability-todo,\
-google-runtime-int,\
-google-upgrade-googletest-case,\
misc-*,\
-misc-misplaced-const,\
-misc-new-delete-overloads,\
-misc-non-private-member-variables-in-classes,\
-misc-no-recursion,\
-misc-redundant-expression,\
-misc-uniqueptr-reset-release,\
-misc-unconventional-assign-operator,\
-misc-unused-parameters,\
-misc-unused-using-decls,\
modernize-*,\
-modernize-avoid-c-arrays,\
-modernize-concat-nested-namespaces,\
-modernize-deprecated-headers,\
-modernize-loop-convert,\
-modernize-macro-to-enum,\
-modernize-make-unique,\
-modernize-pass-by-value,\
-modernize-raw-string-literal,\
-modernize-redundant-void-arg,\
-modernize-return-braced-init-list,\
-modernize-unary-static-assert,\
-modernize-use-auto,\
-modernize-use-bool-literals,\
-modernize-use-default-member-init,\
-modernize-use-emplace,\
-modernize-use-equals-default,\
-modernize-use-equals-delete,\
-modernize-use-nodiscard,\
-modernize-use-nullptr,\
-modernize-use-override,\
-modernize-use-trailing-return-type,\
-modernize-use-transparent-functors,\
-modernize-use-using,\
performance-*,\
-performance-faster-string-find,\
-performance-for-range-copy,\
-performance-inefficient-algorithm,\
-performance-inefficient-string-concatenation,\
-performance-inefficient-vector-operation,\
-performance-move-const-arg,\
-performance-no-automatic-move,\
-performance-noexcept-move-constructor,\
-performance-no-int-to-ptr,\
-performance-trivially-destructible,\
-performance-unnecessary-copy-initialization,\
-performance-unnecessary-value-param,\
portability-*,\
readability-*,\
-readability-avoid-const-params-in-decls,\
-readability-braces-around-statements,\
-readability-const-return-type,\
-readability-container-data-pointer,\
-readability-container-size-empty,\
-readability-convert-member-functions-to-static,\
-readability-else-after-return,\
-readability-function-cognitive-complexity,\
-readability-identifier-length,\
-readability-implicit-bool-conversion,\
-readability-inconsistent-declaration-parameter-name,\
-readability-isolate-declaration,\
-readability-magic-numbers,\
-readability-make-member-function-const,\
-readability-named-parameter,\
-readability-non-const-parameter,\
-readability-qualified-auto,\
-readability-redundant-access-specifiers,\
-readability-redundant-control-flow,\
-readability-redundant-declaration,\
-readability-redundant-member-init,\
-readability-redundant-smartptr-get,\
-readability-redundant-string-cstr,\
-readability-redundant-string-init,\
-readability-simplify-boolean-expr,\
-readability-static-accessed-through-instance,\
-readability-static-definition-in-anonymous-namespace,\
-readability-suspicious-call-argument,\
-readability-uppercase-literal-suffix,\
-readability-use-anyofallof
"
@@ -12,6 +12,7 @@ jobs:
strategy:
fail-fast: false
matrix:
# When adding another, also add to copybara's github_check_runs.
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
build_type: ['Release']
preset: ['make', 'windows']

@@ -43,7 +44,7 @@ jobs:
-D CMAKE_CXX_COMPILER_LAUNCHER=ccache

- name: Build
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }}
run: cmake --build ${{ github.workspace }}/build --preset ${{ matrix.preset }} --config ${{ matrix.build_type }} -j 4

- name: Archive production artifacts
uses: actions/upload-artifact@v4

@@ -54,3 +55,21 @@ jobs:
${{ github.workspace }}/build/${{ matrix.build_type }}/libgemma.lib
${{ github.workspace }}/build/gemma
${{ github.workspace }}/build/libgemma.a

bazel:
runs-on: ubuntu-latest
steps:
- name: Harden Runner
uses: step-security/harden-runner@63c24ba6bd7ba022e95695ff85de572c04a18142 # v2.7.0
with:
egress-policy: audit # cannot be block - runner does git checkout

- uses: actions/checkout@8ade135a41bc03ea155e62e844d188df1ea18608 # v4.0.0

- uses: bazelbuild/setup-bazelisk@b39c379c82683a5f25d34f0d062761f62693e0b2 # v3.0.0

- uses: actions/cache@ab5e6d0c87105b4c9c2047343972218f562e4319 # v4.0.1
with:
path: ~/.cache/bazel
key: bazel-${{ runner.os }}
- run: bazel build -c opt --cxxopt=-std=c++20 //...
BUILD.bazel
@@ -25,21 +25,14 @@ cc_library(
],
deps = [
"//compression:compress",
# copybara:import_next_line:hwy
"//:algo",
# copybara:import_next_line:hwy
"//:dot",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:math",
# copybara:import_next_line:hwy
"//:matvec",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
"//hwy/contrib/sort:vqsort",
"@hwy//:algo",
"@hwy//:dot",
"@hwy//:hwy",
"@hwy//:math",
"@hwy//:matvec",
"@hwy//:profiler",
"@hwy//:thread_pool",
"@hwy//hwy/contrib/sort:vqsort",
],
)

@@ -49,8 +42,7 @@ cc_library(
"util/args.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)

@@ -61,8 +53,7 @@ cc_library(
],
deps = [
":args",
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)

@@ -78,19 +69,13 @@ cc_library(
deps = [
":args",
":transformer_ops",
"//base",
"//compression:compress",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:matvec",
# copybara:import_next_line:hwy
"//:nanobenchmark", # timer
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
":sentencepiece_processor",
"@hwy//:hwy",
"@hwy//:matvec",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@hwy//:thread_pool",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)

@@ -104,13 +89,9 @@ cc_binary(
":args",
":gemma_lib",
"//compression:compress",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:nanobenchmark",
# copybara:import_next_line:hwy
"//:profiler",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:thread_pool",
],
)
CMakeLists.txt

@@ -20,6 +20,7 @@ project(gemma)
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG da250571a45826b21eebbddc1e50d0c1137dee5f)
FetchContent_MakeAvailable(highway)
@@ -127,13 +127,13 @@ working with weights, kv cache and activations (e.g. you might have multiple kv
caches and activations for a single set of weights) more directly rather than
only using a Gemma object.

## Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)
### Use the tokenizer in the Gemma object (or interact with the Tokenizer object directly)

You pretty much only do things with the tokenizer, call `Encode()` to go from
string prompts to token id vectors, or `Decode()` to go from token id vector
outputs from the model back to strings.
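A minimal sketch of that round trip, calling the sentencepiece processor directly rather than through the Gemma object's accessor; the include path and the `tokenizer.spm` filename are assumptions, not taken from this diff:

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "sentencepiece_processor.h"  // assumed include path for sentencepiece

int main() {
  sentencepiece::SentencePieceProcessor sp;
  if (!sp.Load("tokenizer.spm").ok()) return 1;  // tokenizer file from the weights download

  std::vector<int> ids;
  if (!sp.Encode("Write a haiku about Gemma.", &ids).ok()) return 1;  // prompt -> token ids

  std::string text;
  if (!sp.Decode(ids, &text).ok()) return 1;  // token ids -> string (round trip here)
  std::cout << text << "\n";
  return 0;
}
```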
## The main entrypoint for generation is `GenerateGemma()`
### The main entrypoint for generation is `GenerateGemma()`

Calling into `GenerateGemma` with a tokenized prompt will 1) mutate the
activation values in `model` and 2) invoke StreamFunc - a lambda callback for

@@ -150,7 +150,7 @@ constrained decoding type of use cases where you want to force the generation
to fit a grammar. If you're not doing this, you can send an empty lambda as a
no-op which is what `run.cc` does.
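To make the callback shapes concrete, here is a standalone sketch; the `bool(int, float)` and `bool(int)` signatures are inferred from this description rather than copied from the headers, so treat them as assumptions:

```cpp
#include <functional>
#include <iostream>

// Assumed callback shapes: the stream callback sees each generated token id
// (plus a probability) and returns true to keep generating; the accept
// callback can veto candidate tokens for constrained decoding.
using StreamFunc = std::function<bool(int token, float probability)>;
using AcceptFunc = std::function<bool(int token)>;

int main() {
  StreamFunc stream_token = [](int token, float /*probability*/) {
    std::cout << "token id " << token << "\n";  // real code would Decode() it
    return true;                                // keep generating
  };
  AcceptFunc accept_token = [](int /*token*/) { return true; };  // no-op accept

  // Exercised standalone here; in real use both are passed to GenerateGemma()
  // together with the tokenized prompt.
  if (accept_token(42)) stream_token(42, 0.5f);
  return 0;
}
```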
## If you want to invoke the neural network forward function directly call the `Transformer()` function
### If you want to invoke the neural network forward function directly call the `Transformer()` function

For high-level applications, you might only call `GenerateGemma()` and never
interact directly with the neural network, but if you're doing something a bit

@@ -158,11 +158,20 @@ more custom you can call transformer which performs a single inference
operation on a single token and mutates the Activations and the KVCache through
the neural network computation.

## For low level operations, defining new architectures, call `ops.h` functions directly
### For low level operations, defining new architectures, call `ops.h` functions directly

You use `ops.h` if you're writing other NN architectures or modifying the
inference path of the Gemma model.

## Building with Bazel

The sentencepiece library we depend on requires some additional work to build
with the Bazel build system. First, it does not export its BUILD file, so we
provide `bazel/sentencepiece.bazel`. Second, it ships with a vendored subset of
the Abseil library. `bazel/com_google_sentencepiece.patch` changes the code to
support Abseil as a standalone dependency without third_party/ prefixes, similar
to the transforms we apply to Gemma via Copybara.

## Discord

We're also trying out a discord server for discussion here -
MODULE.bazel

@@ -3,12 +3,57 @@ module(
version = "0.1.0",
)

bazel_dep(
name = "rules_license",
version = "0.0.7",
bazel_dep(name = "rules_license", version = "0.0.7")
bazel_dep(name = "googletest", version = "1.14.0")

# Copied from Highway because Bazel does not load them transitively
bazel_dep(name = "bazel_skylib", version = "1.4.1")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "platforms", version = "0.0.7")

http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")

http_archive(
name = "hwy",
urls = ["https://github.com/google/highway/archive/refs/tags/1.1.0.zip"],
integrity = "sha256-zkJX2SwL4wQ0nHMsURW7MDLEf43vFSnqhSUsUM6eQmY=",
strip_prefix = "highway-1.1.0",
)

bazel_dep(
http_archive(
name = "com_google_sentencepiece",
version = "0.1.96",
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//bazel:sentencepiece.bazel",
patches = ["@//bazel:com_google_sentencepiece.patch"],
patch_args = ["-p1"],
)

# For sentencepiece
http_archive(
name = "darts_clone",
build_file_content = """
licenses(["notice"])
exports_files(["LICENSE"])
package(default_visibility = ["//visibility:public"])
cc_library(
name = "darts_clone",
hdrs = [
"include/darts.h",
],
)
""",
sha256 = "c97f55d05c98da6fcaf7f9ecc6a6dc6bc5b18b8564465f77abff8879d446491c",
strip_prefix = "darts-clone-e40ce4627526985a7767444b6ed6893ab6ff8983",
urls = [
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
],
)
# ABSL on 2023-10-18
http_archive(
name = "com_google_absl",
sha256 = "f841f78243f179326f2a80b719f2887c38fe226d288ecdc46e2aa091e6aa43bc",
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
urls = ["https://github.com/abseil/abseil-cpp/archive//9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
)
README.md

@@ -65,15 +65,26 @@ winget install --id Kitware.CMake
winget install --id Microsoft.VisualStudio.2022.BuildTools --force --override "--passive --wait --add Microsoft.VisualStudio.Workload.VCTools;installRecommended --add Microsoft.VisualStudio.Component.VC.Llvm.Clang --add Microsoft.VisualStudio.Component.VC.Llvm.ClangToolset"
```

### Step 1: Obtain model weights and tokenizer from Kaggle
### Step 1: Obtain model weights and tokenizer from Kaggle or Hugging Face Hub

Visit [the Gemma model page on
Kaggle](https://www.kaggle.com/models/google/gemma) and select `Model Variations
Kaggle](https://www.kaggle.com/models/google/gemma/frameworks/gemmaCpp) and select `Model Variations
|> Gemma C++`. On this tab, the `Variation` dropdown includes the options below.
Note bfloat16 weights are higher fidelity, while 8-bit switched floating point
weights enable faster inference. In general, we recommend starting with the
`-sfp` checkpoints.

Alternatively, visit the [gemma.cpp](https://huggingface.co/models?other=gemma.cpp)
models on the Hugging Face Hub. First go to the model repository of the model of interest
(see recommendations below). Then, click the `Files and versions` tab and download the
model and tokenizer files. For programmatic downloading, if you have `huggingface_hub`
installed, you can also download by running:

```
huggingface-cli login # Just the first time
huggingface-cli download google/gemma-2b-sfp-cpp --local-dir build/
```

2B instruction-tuned (`it`) and pre-trained (`pt`) models:

| Model name | Description |

@@ -98,6 +109,8 @@ weights enable faster inference. In general, we recommend starting with the

### Step 2: Extract Files

If you downloaded the models from Hugging Face, skip to step 3.

After filling out the consent form, the download should proceed to retrieve a
tar archive file `archive.tar.gz`. Extract files from `archive.tar.gz` (this can
take a few minutes):

@@ -175,6 +188,14 @@ cmake --build --preset windows -j [number of parallel threads to use]

If the build is successful, you should now have a `gemma.exe` executable in the `build/` directory.

#### Bazel

```sh
bazel build -c opt --cxxopt=-std=c++20 :gemma
```

If the build is successful, you should now have a `gemma` executable in the `bazel-bin/` directory.

### Step 4: Run

You can now run `gemma` from inside the `build/` directory.
WORKSPACE

@@ -1,24 +1,4 @@
workspace(name = "gemma")

load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@bazel_tools//tools/build_defs/repo:utils.bzl", "maybe")

maybe(
http_archive,
name = "rules_license",
sha256 = "4531deccb913639c30e5c7512a054d5d875698daeb75d8cf90f284375fe7c360",
urls = [
"https://github.com/bazelbuild/rules_license/releases/download/0.0.7/rules_license-0.0.7.tar.gz",
],
)

maybe(
http_archive,
name = "com_google_sentencepiece",
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//third_party:sentencepiece.bazel",
patches = ["@//third_party:com_google_sentencepiece.patch"],
patch_args = ["-p1"],
)
# This file marks the root of the Bazel workspace.
# See MODULE.bazel for external dependencies setup.
@@ -0,0 +1,4 @@
package(
default_applicable_licenses = ["//:license"],
default_visibility = ["//:__subpackages__"],
)
(File diff suppressed because it is too large.)
@@ -0,0 +1,97 @@
package(
default_visibility = ["//visibility:public"],
features = [
"layering_check",
"parse_headers",
],
)

licenses(["notice"]) # Apache 2, BSD, MIT

proto_library(
name = "sentencepiece_proto",
srcs = ["src/sentencepiece.proto"],
)

cc_proto_library(
name = "sentencepiece_cc_proto",
deps = [":sentencepiece_proto"],
)

proto_library(
name = "sentencepiece_model_proto",
srcs = ["src/sentencepiece_model.proto"],
)

cc_proto_library(
name = "sentencepiece_model_cc_proto",
deps = [":sentencepiece_model_proto"],
)

genrule(
name = "config_h",
srcs = ["config.h.in"],
outs = ["config.h"],
cmd = "cp $< $@",
)

cc_library(
name = "common",
hdrs = [
"config.h",
"src/common.h",
],
deps = [
"@com_google_absl//absl/base",
],
)

cc_library(
name = "sentencepiece_processor",
srcs = [
"src/bpe_model.cc",
"src/char_model.cc",
"src/error.cc",
"src/filesystem.cc",
"src/model_factory.cc",
"src/model_interface.cc",
"src/normalizer.cc",
"src/sentencepiece_processor.cc",
"src/unigram_model.cc",
"src/util.cc",
"src/word_model.cc",
],
hdrs = [
"src/bpe_model.h",
"src/char_model.h",
"src/filesystem.h",
"src/freelist.h",
"src/model_factory.h",
"src/model_interface.h",
"src/normalizer.h",
"src/sentencepiece_processor.h",
"src/trainer_interface.h",
"src/unigram_model.h",
"src/util.h",
"src/word_model.h",
],
defines = ["_USE_TF_STRING_VIEW"],
includes = [
".",
"src",
],
linkstatic = 1,
deps =
[
":common",
":sentencepiece_cc_proto",
":sentencepiece_model_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@darts_clone",
],
)
@@ -1,10 +1,10 @@
# Weight compression, I/O and analysis

package(
default_applicable_licenses = ["//third_party/gemma_cpp:license"],
default_applicable_licenses = ["//:license"],
default_visibility = [
"//learning/gemini/prod/contrib/gemini_cpp:__subpackages__",
"//third_party/gemma_cpp:__subpackages__",
"//:__subpackages__",
],
)

@@ -17,10 +17,8 @@ cc_library(
"blob_store.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)

@@ -34,8 +32,7 @@ cc_library(
"stats.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)

@@ -48,8 +45,7 @@ cc_library(
"sfp-inl.h",
],
deps = [
# copybara:import_next_line:hwy
"//:hwy",
"@hwy//:hwy",
],
)

@@ -65,15 +61,11 @@ cc_test(
deps = [
":sfp",
":stats",
"//testing/base/public:gunit_main_no_google3",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:hwy_test_util",
# copybara:import_next_line:hwy
"//:nanobenchmark",
# copybara:import_next_line:hwy
"//:thread_pool",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)

@@ -87,9 +79,8 @@ cc_library(
],
deps = [
":sfp",
# copybara:import_next_line:hwy
"//:hwy",
"//third_party/highway/hwy/contrib/sort:vqsort",
"@hwy//:hwy",
"@hwy//hwy/contrib/sort:vqsort",
],
)

@@ -106,13 +97,10 @@ cc_test(
":nuq",
":sfp",
":stats",
"//testing/base/public:gunit_main_no_google3",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:hwy_test_util",
# copybara:import_next_line:hwy
"//:nanobenchmark",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
],
)

@@ -131,12 +119,9 @@ cc_library(
":nuq",
":sfp",
":stats",
# copybara:import_next_line:hwy
"//:dot",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:thread_pool",
"@hwy//:dot",
"@hwy//:hwy",
"@hwy//:thread_pool",
],
)

@@ -150,12 +135,9 @@ cc_library(
":nuq",
":sfp",
":stats",
# copybara:import_next_line:hwy
"//:hwy",
# copybara:import_next_line:hwy
"//:nanobenchmark", # timer
# copybara:import_next_line:hwy
"//:thread_pool",
"//third_party/highway/hwy/contrib/sort:vqsort",
"@hwy//:hwy",
"@hwy//:nanobenchmark", # timer
"@hwy//:thread_pool",
"@hwy//hwy/contrib/sort:vqsort",
],
)
@@ -149,7 +149,7 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}

size_t remaining = num - i;
const size_t remaining = num - i;
if (remaining != 0) {
const VF in0 = hn::LoadN(df, in + i, remaining);
const size_t remaining1 = remaining - HWY_MIN(remaining, N / 2);

@@ -195,7 +195,7 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}

size_t remaining = num - i;
const size_t remaining = num - i;
if (remaining != 0) {
const VBF in16 = hn::LoadN(dbf, in + in_ofs + i, remaining);
const VF in0 = hn::PromoteLowerTo(df, in16);

@@ -287,7 +287,7 @@ struct CompressTraits<NuqStream> {

if (COMPRESS_STATS) {
for (size_t i = 0; i < num; ++i) {
tls.stats.NotifyIn(in[i] * 100 + 500);
tls.stats.NotifyIn(static_cast<int>(lroundf(in[i] * 100.0f + 500.0f)));
}

const hn::Repartition<hwy::bfloat16_t, DF> dbf;

@@ -358,7 +358,7 @@ HWY_NOINLINE void Compress(const float* in, size_t num,
});

const double t1 = hwy::platform::Now();
const double mb = num * sizeof(in[0]) * 1E-6;
const double mb = static_cast<double>(num) * sizeof(in[0]) * 1E-6;
const double mbps = mb / (t1 - t0);
fprintf(stderr, "Compress %.1f MB/s\n", mbps);
@@ -68,15 +68,15 @@ class DistortionStats {

double GeomeanValueDivL1() const {
if (num_rel_ == 0) return 0.0;
return exp(sum_log_rel_ / num_rel_);
return exp(sum_log_rel_ / static_cast<double>(num_rel_));
}

double PNorm() const {
// p-norms are a compromise between max-norm (penalizes the largest error
// without dilution, but does not notice any other errors) and L1 (all
// errors contribute, but large errors are diluted by smaller ones).
const double norm3 = pow(sum_pow3_ / n_, 1.0 / 3);
const double norm6 = pow(sum_pow6_ / n_, 1.0 / 6);
const double norm3 = pow(sum_pow3_ / static_cast<double>(n_), 1.0 / 3);
const double norm6 = pow(sum_pow6_ / static_cast<double>(n_), 1.0 / 6);
return 0.5 * (norm3 + norm6);
}
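Assuming `sum_log_rel_` accumulates the log of per-element ratios r_i over `num_rel_` values, and `sum_pow3_` / `sum_pow6_` the 3rd and 6th powers of the per-element absolute error e_i over `n_` values, the two accessors above compute:

$$\mathrm{GeomeanValueDivL1} = \exp\!\Big(\tfrac{1}{n_{\mathrm{rel}}}\sum_i \log r_i\Big), \qquad \mathrm{PNorm} = \tfrac{1}{2}\Big[\Big(\tfrac{1}{n}\sum_i e_i^{3}\Big)^{1/3} + \Big(\tfrac{1}{n}\sum_i e_i^{6}\Big)^{1/6}\Big]$$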
@@ -87,7 +87,7 @@ class NuqClustering {

inv_len_[0] = 0.0f; // unused
for (size_t i = 0; i <= kGroupSize; ++i) {
inv_len_[i] = 1.0f / i;
inv_len_[i] = 1.0f / static_cast<float>(i);
}
}

@@ -229,7 +229,7 @@ class NuqClustering {
const float sum = cc.SumOfSorted(start, last);
const int size = static_cast<int>(last) - static_cast<int>(start) + 1;
HWY_DASSERT(0 < size && size <= kGroupSize);
centers[k] = sum / size;
centers[k] = sum / static_cast<float>(size);

// We know the range inside sorted_and_i[]; translate to original indices,
// which are stored inside each of the sorted_and_i mantissas.

@@ -470,9 +470,7 @@ class NuqCodec {
static HWY_INLINE size_t Enc(DF df, const float* const in, const size_t num,
ClusterBuf& buf, const size_t out_capacity,
NuqStream* const out, const size_t out_ofs) {
const hn::Repartition<uint8_t, DF> d8;
const hn::Repartition<uint16_t, DF> d16;
using V8 = hn::Vec<decltype(d8)>;
using V16 = hn::Vec<decltype(d16)>;

const size_t N16 = hn::Lanes(d16);
@@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// SFP uses ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif

#include <stddef.h>
#include <stdint.h>
#include <stdio.h>

@@ -23,9 +28,10 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"

// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/nuq_test.cc" // NOLINT
#define HWY_TARGET_INCLUDE "compression/nuq_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Other headers that include Highway must come after foreach_target.h
// copybara:import_next_line:gemma_cpp
@@ -13,6 +13,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// We use ConcatEven/Odd which are not supported. Use HWY_EMU128 instead.
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif

// copybara:import_next_line:gemma_cpp
#include "compression/sfp.h"

@@ -27,9 +32,10 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"

// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"third_party/gemma_cpp/compression/sfp_test.cc" // NOLINT
#define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
// Any highway.h must come after foreach_target.h
// copybara:import_next_line:gemma_cpp

@@ -301,7 +307,7 @@ struct TestEncDec {
for (size_t i = 0; i < num; ++i) {
const float out = hwy::F32FromBF16(dec[i]);
sum += hwy::ConvertScalarTo<double>(hwy::ScalarAbs(in[i]));
stats.Notify(in[i], out);
stats.Notify(hwy::ConvertScalarTo<float>(in[i]), out);
}
const double avg = sum / num;
fprintf(stderr, "Avg magnitude %.3E, p-norm %.3E snr %.2f @%zu = %.4E\n",
gemma.cc

@@ -527,7 +527,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal.
size_t pos_offset = 0; // offset relative to pos
double prefill_start = hwy::platform::Now();
const double prefill_start = hwy::platform::Now();

// Prefill stops before prompt.size() - 1 since the last prompt token is the
// first input token for generation.

@@ -549,12 +549,13 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
if (verbosity >= 2) {
// in the future this output should not occur in GenerateImpl but instead
// should be available as observable state for frontend code to handle I/O.
double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = pos_offset / (prefill_end - prefill_start);
const double prefill_end = hwy::platform::Now();
const double prefill_tok_sec =
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
}

double gen_start = hwy::platform::Now();
const double gen_start = hwy::platform::Now();

HWY_DASSERT(pos_offset == prompt.size() - 1);

@@ -592,10 +593,11 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
}
if (token == EOS_ID) {
if (verbosity >= 2) {
double gen_end = hwy::platform::Now();
const double gen_end = hwy::platform::Now();
const double gen_tok_sec =
(pos_offset - pos_gen_start) / (gen_end - gen_start);
std::cout << "[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
static_cast<double>(pos_offset - pos_gen_start) /
(gen_end - gen_start);
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
}
break;
}

@@ -693,7 +695,8 @@ hwy::AlignedFreeUniquePtr<uint8_t[]> GetCompressedWeights(
if (loader.ReadAll(pool)) return c_weights_u8;

// Get weights, compress, and store in cache.
hwy::AlignedUniquePtr<Weights<TConfig>> weights = LoadWeights<TConfig>(model);
const hwy::AlignedUniquePtr<Weights<TConfig>> weights =
LoadWeights<TConfig>(model);
Compressor compressor(pool);
ForEachTensor<TConfig>(weights.get(), *c_weights, compressor);
compressor.WriteAll(pool, cache.path.c_str());
ops.h

@@ -57,6 +57,17 @@ HWY_INLINE constexpr size_t MaxCols() {
return 2048;
}

template <typename To, typename From>
HWY_INLINE constexpr std::enable_if_t<
std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
return static_cast<To>(
static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
else
return static_cast<To>(from);
}

template <size_t kOuter>
HWY_INLINE constexpr size_t RowsPerStrip() {
// Aim for 128 work items to reduce pool overhead. Must be at least one
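To show how the new `StaticCast` helper is meant to be called (its call sites appear in the hunks below), here is a standalone sketch. The definition is adapted from the hunk above; only `main()` is new, and the stated rationale is an assumption rather than something this diff spells out:

```cpp
#include <cstddef>
#include <cstdio>
#include <type_traits>

#include "hwy/base.h"  // hwy::SignedFromSize, as used by the hunk above

// Adapted from the ops.h hunk: unsigned -> floating-point conversions go
// through the same-width signed type first (presumably to avoid
// sign-conversion warnings and slower unsigned->float codegen on some
// targets); every other arithmetic conversion is a plain static_cast.
template <typename To, typename From>
constexpr std::enable_if_t<
    std::is_arithmetic_v<To> && std::is_arithmetic_v<From>, To>
StaticCast(From from) noexcept {
  if constexpr (std::is_unsigned_v<From> && std::is_floating_point_v<To>)
    return static_cast<To>(
        static_cast<hwy::SignedFromSize<sizeof(From)>>(from));
  else
    return static_cast<To>(from);
}

int main() {
  const size_t size = 2048;
  // size_t -> float takes the signed detour; size_t -> int is a plain cast.
  std::printf("%f %d\n", StaticCast<float>(size), StaticCast<int>(size));
  return 0;
}
```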
@@ -230,7 +241,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void GeluMulToBF16(

size_t i = 0;
if (size >= 2 * NF) {
for (; i < size - 2 * NF; i += 2 * NF) {
for (; i <= size - 2 * NF; i += 2 * NF) {
const VF mul0 = hn::LoadU(df, mul + i);
const VF mul1 = hn::LoadU(df, mul + i + NF);
const VF g0 = hn::Mul(mul0, Gelu(df, hn::LoadU(df, gelu_in + i)));

@@ -341,7 +352,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + weight[j]) * (ss * x[j]);
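In equation form, the RMSNorm variants in the surrounding hunks compute, for input x of length n, weights w, and eps = 1e-6 (note the 1 + w_j centering called out in the comment):

$$\mathrm{out}_j = (1 + w_j)\,\frac{x_j}{\sqrt{\tfrac{1}{n}\sum_{i} x_i^{2} + \epsilon}}$$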
@@ -353,7 +364,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(
float* HWY_RESTRICT out, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(x, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
out[j] = (1.0f + hwy::F32FromBF16(weight[j])) * (ss * x[j]);

@@ -364,7 +375,7 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(
const float* HWY_RESTRICT weight, float* HWY_RESTRICT inout, size_t size) {
constexpr float eps = 1e-6f;
float ss = SquaredL2(inout, size);
ss = 1.0f / sqrtf(ss / static_cast<int>(size) + eps);
ss = 1.0f / sqrtf(ss / StaticCast<float>(size) + eps);
for (size_t j = 0; j < size; j++) {
// Note 1.0f centering here
inout[j] = (1.0f + weight[j]) * (ss * inout[j]);

@@ -383,7 +394,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormInplace(

constexpr float eps = 1e-6f;
const float ss = SquaredL2(inout, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));

HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {

@@ -411,7 +423,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(

constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / static_cast<int>(size) + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));

HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {

@@ -438,7 +451,8 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNorm(

constexpr float eps = 1e-6f;
const float ss = SquaredL2(x, size);
const VF vss = hn::Set(df32, 1.0f / sqrtf(ss / size + eps));
const VF vss =
hn::Set(df32, 1.0f / sqrtf(ss / StaticCast<float>(size) + eps));

HWY_DASSERT(size % (2 * MaxLanes(df32)) == 0);
for (size_t i = 0; i < size; i += 2 * N32) {

@@ -459,14 +473,12 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddAbsolutePositionalEmbeddings(
const size_t num_timescales = dim_model / 2;
const float log_timescale_increment =
logf(10000.0f) /
(num_timescales != 0
? static_cast<float>(static_cast<int>(num_timescales) - 1)
: 1.0f);
(num_timescales != 0 ? StaticCast<float>(num_timescales - 1) : 1.0f);
for (size_t dim = 0; dim < num_timescales; ++dim) {
const float inv_timescale =
expf(static_cast<int>(dim) * -log_timescale_increment);
x[dim] += sinf(pos * inv_timescale);
x[num_timescales + dim] += cosf(pos * inv_timescale);
expf(StaticCast<float>(dim) * -log_timescale_increment);
x[dim] += sinf(StaticCast<float>(pos) * inv_timescale);
x[num_timescales + dim] += cosf(StaticCast<float>(pos) * inv_timescale);
}
}
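Written out, `AddAbsolutePositionalEmbeddings` adds the standard sinusoidal embedding; with T = `num_timescales` = `dim_model / 2` and position p, the loop above amounts to:

$$x_d \mathrel{+}= \sin\!\big(p \cdot 10000^{-d/(T-1)}\big), \qquad x_{T+d} \mathrel{+}= \cos\!\big(p \cdot 10000^{-d/(T-1)}\big), \qquad d = 0,\dots,T-1$$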
@@ -475,11 +487,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void Rope(float* HWY_RESTRICT x,
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];

@@ -496,11 +508,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void RopeAndMulBy(const float mul,
HWY_DASSERT(dim_qkv % 2 == 0);
const size_t half_dim_qkv = dim_qkv / 2;
for (size_t dim = 0; dim < half_dim_qkv; ++dim) {
const float freq_exponents = static_cast<float>(2 * static_cast<int>(dim)) /
static_cast<float>(dim_qkv);
const float freq_exponents =
StaticCast<float>(2 * dim) / StaticCast<float>(dim_qkv);
// Replacing with expf(ln(1E4) * freq_exponents) changes results noticeably.
const float timescale = powf(10000.0f, freq_exponents);
const float theta = pos / timescale;
const float theta = StaticCast<float>(pos) / timescale;
const float cos_val = cosf(theta);
const float sin_val = sinf(theta);
const float x0 = x[dim];

@@ -674,18 +686,18 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED int SampleTopK(
std::array<float, k> top_k{}; // sorted from highest [0], to lowest [k-1]
std::array<int, k> indices{};
for (size_t i = 0; i < vocab_size; ++i) {
if (probabilities[i] < top_k[k - 1] && accept_token(static_cast<int>(i))) {
if (probabilities[i] < top_k[k - 1] && accept_token(StaticCast<int>(i))) {
continue;
}
for (size_t j = 0; j < k; ++j) {
if (probabilities[i] > top_k[j] && accept_token(static_cast<int>(i))) {
if (probabilities[i] > top_k[j] && accept_token(StaticCast<int>(i))) {
// shift elements by 1, insert the new value, move on to next value
for (size_t idx = k - 1; idx > j; --idx) {
top_k[idx] = top_k[idx - 1];
indices[idx] = indices[idx - 1];
}
top_k[j] = probabilities[i];
indices[j] = static_cast<int>(i);
indices[j] = StaticCast<int>(i);
break;
}
}