mirror of https://github.com/google/gemma.cpp.git
Minor cleanup, Windows+Bazel build fixes
add app.h comment compress-inl: remove unused typedef gemma-inl: add missing HWY_ATTR and cast separate sum-inl.h and basics.h headers replace more hwy::bfloat16_t with BF16 update include pragmas update dot_test thresholds update Highway version in Bazel for HWY_RCAST_ALIGNED fix PiperOrigin-RevId: 684464326
This commit is contained in:
parent
85958f5fd3
commit
6ab3ff5bde
174
BUILD.bazel
174
BUILD.bazel
|
|
@ -20,11 +20,19 @@ licenses(["notice"])
|
||||||
|
|
||||||
exports_files(["LICENSE"])
|
exports_files(["LICENSE"])
|
||||||
|
|
||||||
|
cc_library(
|
||||||
|
name = "basics",
|
||||||
|
hdrs = ["util/basics.h"],
|
||||||
|
deps = [
|
||||||
|
"@highway//:hwy",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
cc_library(
|
cc_library(
|
||||||
name = "allocator",
|
name = "allocator",
|
||||||
hdrs = ["util/allocator.h"],
|
hdrs = ["util/allocator.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -32,9 +40,9 @@ cc_library(
|
||||||
name = "test_util",
|
name = "test_util",
|
||||||
hdrs = ["util/test_util.h"],
|
hdrs = ["util/test_util.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:stats",
|
"@highway//:stats",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -42,9 +50,9 @@ cc_library(
|
||||||
name = "threading",
|
name = "threading",
|
||||||
hdrs = ["util/threading.h"],
|
hdrs = ["util/threading.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
"@hwy//:topology",
|
"@highway//:topology",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -54,8 +62,8 @@ cc_test(
|
||||||
deps = [
|
deps = [
|
||||||
":threading",
|
":threading",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -66,6 +74,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
textual_hdrs = [
|
textual_hdrs = [
|
||||||
"ops/dot-inl.h",
|
"ops/dot-inl.h",
|
||||||
|
"ops/sum-inl.h",
|
||||||
"ops/fp_arith-inl.h",
|
"ops/fp_arith-inl.h",
|
||||||
"ops/matmul-inl.h",
|
"ops/matmul-inl.h",
|
||||||
"ops/matvec-inl.h",
|
"ops/matvec-inl.h",
|
||||||
|
|
@ -73,14 +82,15 @@ cc_library(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":allocator",
|
":allocator",
|
||||||
|
":basics",
|
||||||
":threading",
|
":threading",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:algo",
|
"@highway//:algo",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:math",
|
"@highway//:math",
|
||||||
"@hwy//:matvec",
|
"@highway//:matvec",
|
||||||
"@hwy//:profiler",
|
"@highway//:profiler",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -102,12 +112,12 @@ cc_test(
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:test_util",
|
"//compression:test_util",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark", #buildcleaner: keep
|
"@highway//:nanobenchmark", #buildcleaner: keep
|
||||||
"@hwy//:profiler",
|
"@highway//:profiler",
|
||||||
"@hwy//:stats",
|
"@highway//:stats",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -127,9 +137,9 @@ cc_test(
|
||||||
":test_util",
|
":test_util",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark", #buildcleaner: keep
|
"@highway//:nanobenchmark", #buildcleaner: keep
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -145,9 +155,9 @@ cc_test(
|
||||||
":ops",
|
":ops",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -164,10 +174,10 @@ cc_test(
|
||||||
":threading",
|
":threading",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark",
|
"@highway//:nanobenchmark",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -180,8 +190,8 @@ cc_library(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy", # base.h
|
"@highway//:hwy", # base.h
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -194,10 +204,10 @@ cc_library(
|
||||||
":common",
|
":common",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:profiler",
|
"@highway//:profiler",
|
||||||
"@hwy//:stats",
|
"@highway//:stats",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -208,9 +218,9 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:nanobenchmark", # timer
|
"@highway//:nanobenchmark", # timer
|
||||||
"@hwy//:profiler",
|
"@highway//:profiler",
|
||||||
"@com_google_sentencepiece//:sentencepiece_processor",
|
"@com_google_sentencepiece//:sentencepiece_processor",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -221,7 +231,7 @@ cc_library(
|
||||||
hdrs = ["gemma/kv_cache.h"],
|
hdrs = ["gemma/kv_cache.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":common",
|
":common",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -268,6 +278,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
":allocator",
|
":allocator",
|
||||||
|
":basics",
|
||||||
":common",
|
":common",
|
||||||
":ops",
|
":ops",
|
||||||
":tokenizer",
|
":tokenizer",
|
||||||
|
|
@ -275,12 +286,13 @@ cc_library(
|
||||||
":weights",
|
":weights",
|
||||||
":threading",
|
":threading",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
|
"//compression:sfp",
|
||||||
"//paligemma:image",
|
"//paligemma:image",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:bit_set",
|
"@highway//:bit_set",
|
||||||
"@hwy//:nanobenchmark", # timer
|
"@highway//:nanobenchmark", # timer
|
||||||
"@hwy//:profiler",
|
"@highway//:profiler",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -292,7 +304,7 @@ cc_library(
|
||||||
":common",
|
":common",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":ops",
|
":ops",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -301,7 +313,7 @@ cc_library(
|
||||||
hdrs = ["util/args.h"],
|
hdrs = ["util/args.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -314,9 +326,9 @@ cc_library(
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":threading",
|
":threading",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
"@hwy//:topology",
|
"@highway//:topology",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -333,12 +345,12 @@ cc_library(
|
||||||
":kv_cache",
|
":kv_cache",
|
||||||
":threading",
|
":threading",
|
||||||
# Placeholder for internal dep, do not remove.,
|
# Placeholder for internal dep, do not remove.,
|
||||||
"@benchmark//:benchmark",
|
"@google_benchmark//:benchmark",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:nanobenchmark",
|
"@highway//:nanobenchmark",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
"@hwy//:topology",
|
"@highway//:topology",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -357,8 +369,8 @@ cc_test(
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
":tokenizer",
|
":tokenizer",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -374,9 +386,9 @@ cc_binary(
|
||||||
":threading",
|
":threading",
|
||||||
# Placeholder for internal dep, do not remove.,
|
# Placeholder for internal dep, do not remove.,
|
||||||
"//paligemma:image",
|
"//paligemma:image",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:profiler",
|
"@highway//:profiler",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -391,9 +403,9 @@ cc_binary(
|
||||||
":cross_entropy",
|
":cross_entropy",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:nanobenchmark",
|
"@highway//:nanobenchmark",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
"@nlohmann_json//:json",
|
"@nlohmann_json//:json",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -403,7 +415,7 @@ cc_binary(
|
||||||
srcs = ["evals/benchmarks.cc"],
|
srcs = ["evals/benchmarks.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":benchmark_helper",
|
":benchmark_helper",
|
||||||
"@benchmark//:benchmark",
|
"@google_benchmark//:benchmark",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -418,8 +430,8 @@ cc_binary(
|
||||||
":benchmark_helper",
|
":benchmark_helper",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
"@nlohmann_json//:json",
|
"@nlohmann_json//:json",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -433,9 +445,9 @@ cc_binary(
|
||||||
":benchmark_helper",
|
":benchmark_helper",
|
||||||
":gemma_lib",
|
":gemma_lib",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:profiler",
|
"@highway//:profiler",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
"@nlohmann_json//:json",
|
"@nlohmann_json//:json",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
@ -477,9 +489,9 @@ cc_library(
|
||||||
":prompt",
|
":prompt",
|
||||||
":weights",
|
":weights",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:dot",
|
"@highway//:dot",
|
||||||
"@hwy//:hwy", # base.h
|
"@highway//:hwy", # base.h
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -497,7 +509,7 @@ cc_library(
|
||||||
":prompt",
|
":prompt",
|
||||||
":weights",
|
":weights",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -517,7 +529,7 @@ cc_test(
|
||||||
":weights",
|
":weights",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -544,9 +556,9 @@ cc_test(
|
||||||
":weights",
|
":weights",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -559,8 +571,8 @@ cc_library(
|
||||||
":common",
|
":common",
|
||||||
":weights",
|
":weights",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -583,6 +595,6 @@ cc_test(
|
||||||
":threading",
|
":threading",
|
||||||
":weights",
|
":weights",
|
||||||
"@googletest//:gtest_main",
|
"@googletest//:gtest_main",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
|
||||||
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
set(CMAKE_CXX_STANDARD_REQUIRED ON)
|
||||||
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
|
||||||
|
|
||||||
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bb6c3f36b0c8dde8a8ef98b0f0884f4de820a7ca EXCLUDE_FROM_ALL)
|
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 0ca297227a373710e76dd45e0ad4d68adb6928fe EXCLUDE_FROM_ALL)
|
||||||
FetchContent_MakeAvailable(highway)
|
FetchContent_MakeAvailable(highway)
|
||||||
|
|
||||||
## Note: absl needs to be installed by sentencepiece. This will only happen if
|
## Note: absl needs to be installed by sentencepiece. This will only happen if
|
||||||
|
|
@ -106,11 +106,13 @@ set(SOURCES
|
||||||
ops/matmul-inl.h
|
ops/matmul-inl.h
|
||||||
ops/matvec-inl.h
|
ops/matvec-inl.h
|
||||||
ops/ops-inl.h
|
ops/ops-inl.h
|
||||||
|
ops/sum-inl.h
|
||||||
paligemma/image.cc
|
paligemma/image.cc
|
||||||
paligemma/image.h
|
paligemma/image.h
|
||||||
util/allocator.h
|
util/allocator.h
|
||||||
util/app.h
|
util/app.h
|
||||||
util/args.h
|
util/args.h
|
||||||
|
util/basics.h
|
||||||
util/test_util.h
|
util/test_util.h
|
||||||
util/threading.h
|
util/threading.h
|
||||||
)
|
)
|
||||||
|
|
|
||||||
54
MODULE.bazel
54
MODULE.bazel
|
|
@ -3,37 +3,33 @@ module(
|
||||||
version = "0.1.0",
|
version = "0.1.0",
|
||||||
)
|
)
|
||||||
|
|
||||||
bazel_dep(name = "rules_license", version = "0.0.7")
|
bazel_dep(name = "abseil-cpp", version = "20240722.0")
|
||||||
bazel_dep(name = "googletest", version = "1.14.0")
|
bazel_dep(name = "bazel_skylib", version = "1.6.1")
|
||||||
|
bazel_dep(name = "googletest", version = "1.15.2")
|
||||||
# Copied from Highway because Bazel does not load them transitively
|
bazel_dep(name = "highway", version = "1.1.0")
|
||||||
bazel_dep(name = "bazel_skylib", version = "1.4.1")
|
bazel_dep(name = "nlohmann_json", version = "3.11.3")
|
||||||
|
bazel_dep(name = "platforms", version = "0.0.10")
|
||||||
bazel_dep(name = "rules_cc", version = "0.0.9")
|
bazel_dep(name = "rules_cc", version = "0.0.9")
|
||||||
bazel_dep(name = "platforms", version = "0.0.7")
|
bazel_dep(name = "rules_license", version = "0.0.7")
|
||||||
|
bazel_dep(name = "google_benchmark", version = "1.8.5")
|
||||||
|
|
||||||
|
# Require a more recent version.
|
||||||
|
git_override(
|
||||||
|
module_name = "highway",
|
||||||
|
commit = "0ca297227a373710e76dd45e0ad4d68adb6928fe",
|
||||||
|
remote = "https://github.com/google/highway",
|
||||||
|
)
|
||||||
|
|
||||||
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
|
||||||
|
|
||||||
http_archive(
|
|
||||||
name = "hwy",
|
|
||||||
urls = ["https://github.com/google/highway/archive/refs/tags/1.2.0.zip"],
|
|
||||||
integrity = "sha256-fbtKAGj5hhhBr5Bggtsrj4aIodC2OHb1njB8LGfom8A=", strip_prefix = "highway-1.2.0",
|
|
||||||
)
|
|
||||||
|
|
||||||
http_archive(
|
|
||||||
name = "nlohmann_json",
|
|
||||||
urls = ["https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip"],
|
|
||||||
integrity = "sha256-BAIrBdgG61/3MCPCgLaGl9Erk+G3JnoLIqGjnsdXgGk=",
|
|
||||||
strip_prefix = "json-3.11.3",
|
|
||||||
)
|
|
||||||
|
|
||||||
http_archive(
|
http_archive(
|
||||||
name = "com_google_sentencepiece",
|
name = "com_google_sentencepiece",
|
||||||
|
build_file = "@//bazel:sentencepiece.bazel",
|
||||||
|
patch_args = ["-p1"],
|
||||||
|
patches = ["@//bazel:sentencepiece.patch"],
|
||||||
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
|
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
|
||||||
strip_prefix = "sentencepiece-0.1.96",
|
strip_prefix = "sentencepiece-0.1.96",
|
||||||
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
|
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
|
||||||
build_file = "@//bazel:sentencepiece.bazel",
|
|
||||||
patches = ["@//bazel:sentencepiece.patch"],
|
|
||||||
patch_args = ["-p1"],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# For sentencepiece
|
# For sentencepiece
|
||||||
|
|
@ -56,17 +52,3 @@ cc_library(
|
||||||
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
|
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# ABSL on 2023-10-18
|
|
||||||
http_archive(
|
|
||||||
name = "com_google_absl",
|
|
||||||
sha256 = "f841f78243f179326f2a80b719f2887c38fe226d288ecdc46e2aa091e6aa43bc",
|
|
||||||
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
|
|
||||||
urls = ["https://github.com/abseil/abseil-cpp/archive//9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
|
|
||||||
)
|
|
||||||
# Benchmark
|
|
||||||
http_archive(
|
|
||||||
name = "benchmark",
|
|
||||||
urls = ["https://github.com/google/benchmark/archive/refs/tags/v1.8.2.tar.gz"],
|
|
||||||
integrity = "sha256-KqspgNA3YTf5adkoSPu2gharsHYzA0U0/IxlzE56DpM=",
|
|
||||||
strip_prefix = "benchmark-1.8.2",
|
|
||||||
)
|
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ cc_library(
|
||||||
"src/common.h",
|
"src/common.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"@com_google_absl//absl/base",
|
"@abseil-cpp//absl/base",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -86,12 +86,12 @@ cc_library(
|
||||||
":common",
|
":common",
|
||||||
":sentencepiece_cc_proto",
|
":sentencepiece_cc_proto",
|
||||||
":sentencepiece_model_cc_proto",
|
":sentencepiece_model_cc_proto",
|
||||||
"@com_google_absl//absl/container:flat_hash_map",
|
"@abseil-cpp//absl/container:flat_hash_map",
|
||||||
"@com_google_absl//absl/container:flat_hash_set",
|
"@abseil-cpp//absl/container:flat_hash_set",
|
||||||
"@com_google_absl//absl/memory",
|
"@abseil-cpp//absl/memory",
|
||||||
"@com_google_absl//absl/status",
|
"@abseil-cpp//absl/status",
|
||||||
"@com_google_absl//absl/strings",
|
"@abseil-cpp//absl/strings",
|
||||||
"@com_google_absl//absl/strings:str_format",
|
"@abseil-cpp//absl/strings:str_format",
|
||||||
"@darts_clone",
|
"@darts_clone",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ cc_library(
|
||||||
],
|
],
|
||||||
hdrs = ["io.h"],
|
hdrs = ["io.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
] + FILE_DEPS,
|
] + FILE_DEPS,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -43,8 +43,8 @@ cc_library(
|
||||||
hdrs = ["blob_store.h"],
|
hdrs = ["blob_store.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":io",
|
":io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -55,9 +55,9 @@ cc_library(
|
||||||
"shared.h",
|
"shared.h",
|
||||||
],
|
],
|
||||||
deps = [
|
deps = [
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:stats",
|
"@highway//:stats",
|
||||||
"@hwy//hwy/contrib/sort:vqsort",
|
"@highway//hwy/contrib/sort:vqsort",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -69,8 +69,8 @@ cc_test(
|
||||||
":distortion",
|
":distortion",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//:test_util",
|
"//:test_util",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark", # Unpredictable1
|
"@highway//:nanobenchmark", # Unpredictable1
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -79,7 +79,7 @@ cc_library(
|
||||||
hdrs = ["shared.h"],
|
hdrs = ["shared.h"],
|
||||||
textual_hdrs = ["sfp-inl.h"],
|
textual_hdrs = ["sfp-inl.h"],
|
||||||
deps = [
|
deps = [
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -89,9 +89,9 @@ cc_library(
|
||||||
textual_hdrs = ["nuq-inl.h"],
|
textual_hdrs = ["nuq-inl.h"],
|
||||||
deps = [
|
deps = [
|
||||||
":sfp",
|
":sfp",
|
||||||
"//:allocator",
|
"//:basics",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//hwy/contrib/sort:vqsort",
|
"@highway//hwy/contrib/sort:vqsort",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -103,8 +103,8 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":compress",
|
":compress",
|
||||||
":distortion",
|
":distortion",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -122,9 +122,9 @@ cc_test(
|
||||||
":sfp",
|
":sfp",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//:test_util",
|
"//:test_util",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark",
|
"@highway//:nanobenchmark",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -144,9 +144,9 @@ cc_test(
|
||||||
":sfp",
|
":sfp",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//:test_util",
|
"//:test_util",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:nanobenchmark",
|
"@highway//:nanobenchmark",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -164,11 +164,11 @@ cc_library(
|
||||||
":io",
|
":io",
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
":sfp",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:nanobenchmark",
|
"@highway//:nanobenchmark",
|
||||||
"@hwy//:profiler",
|
"@highway//:profiler",
|
||||||
"@hwy//:stats",
|
"@highway//:stats",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -188,9 +188,9 @@ cc_test(
|
||||||
":test_util",
|
":test_util",
|
||||||
"@googletest//:gtest_main", # buildcleaner: keep
|
"@googletest//:gtest_main", # buildcleaner: keep
|
||||||
"//:test_util",
|
"//:test_util",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -201,10 +201,10 @@ cc_library(
|
||||||
deps = [
|
deps = [
|
||||||
":nuq",
|
":nuq",
|
||||||
":sfp",
|
":sfp",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:stats",
|
"@highway//:stats",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
"@hwy//hwy/contrib/sort:vqsort",
|
"@highway//hwy/contrib/sort:vqsort",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -218,7 +218,7 @@ cc_binary(
|
||||||
"//:args",
|
"//:args",
|
||||||
"//:common",
|
"//:common",
|
||||||
"//:weights",
|
"//:weights",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,7 @@
|
||||||
#include <cmath> // lroundf, only if COMPRESS_STATS
|
#include <cmath> // lroundf, only if COMPRESS_STATS
|
||||||
|
|
||||||
#include "compression/blob_store.h"
|
#include "compression/blob_store.h"
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h" // IWYU pragma: export
|
||||||
#include "compression/distortion.h"
|
#include "compression/distortion.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
@ -494,11 +494,6 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
|
||||||
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
|
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename Packed>
|
|
||||||
constexpr bool IsF32() {
|
|
||||||
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace detail {
|
namespace detail {
|
||||||
|
|
||||||
// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
|
// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
|
||||||
|
|
@ -678,8 +673,8 @@ class Compressor {
|
||||||
template <typename Packed>
|
template <typename Packed>
|
||||||
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
|
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
|
||||||
const float* HWY_RESTRICT weights) {
|
const float* HWY_RESTRICT weights) {
|
||||||
int num_weights = compressed->NumElements();
|
size_t num_weights = compressed->NumElements();
|
||||||
int num_compressed = compressed->NumElements();
|
size_t num_compressed = compressed->NumElements();
|
||||||
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
|
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
|
||||||
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
|
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
|
||||||
num_weights / (1000 * 1000));
|
num_weights / (1000 * 1000));
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "util/allocator.h"
|
#include "util/basics.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@ cc_library(
|
||||||
"//third_party/absl/types:span",
|
"//third_party/absl/types:span",
|
||||||
"//compression:compress",
|
"//compression:compress",
|
||||||
"//compression:io",
|
"//compression:io",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -32,6 +32,11 @@ namespace gcpp {
|
||||||
|
|
||||||
using BF16 = hwy::bfloat16_t;
|
using BF16 = hwy::bfloat16_t;
|
||||||
|
|
||||||
|
template <typename Packed>
|
||||||
|
constexpr bool IsF32() {
|
||||||
|
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
|
||||||
|
}
|
||||||
|
|
||||||
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
|
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
|
||||||
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
|
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
|
||||||
// It supports seeking at a granularity of 1 and decoding to bf16/f32.
|
// It supports seeking at a granularity of 1 and decoding to bf16/f32.
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ cc_binary(
|
||||||
"//:gemma_lib",
|
"//:gemma_lib",
|
||||||
"//:threading",
|
"//:threading",
|
||||||
"//:tokenizer",
|
"//:tokenizer",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:thread_pool",
|
"@highway//:thread_pool",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -20,8 +20,9 @@
|
||||||
|
|
||||||
#include <cmath>
|
#include <cmath>
|
||||||
|
|
||||||
#include "ops/matmul.h" // MatMulEnv
|
#include "compression/shared.h" // BF16
|
||||||
#include "util/allocator.h" // RowVectorBatch
|
#include "ops/matmul.h" // MatMulEnv
|
||||||
|
#include "util/allocator.h" // RowVectorBatch
|
||||||
#include "util/threading.h"
|
#include "util/threading.h"
|
||||||
#include "hwy/base.h" // HWY_DASSERT
|
#include "hwy/base.h" // HWY_DASSERT
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -41,7 +42,7 @@ struct Activations {
|
||||||
RowVectorBatch<float> att_sums;
|
RowVectorBatch<float> att_sums;
|
||||||
|
|
||||||
// Gated FFW
|
// Gated FFW
|
||||||
RowVectorBatch<hwy::bfloat16_t> bf_pre_ffw_rms_out;
|
RowVectorBatch<BF16> bf_pre_ffw_rms_out;
|
||||||
RowVectorBatch<float> C1;
|
RowVectorBatch<float> C1;
|
||||||
RowVectorBatch<float> C2;
|
RowVectorBatch<float> C2;
|
||||||
RowVectorBatch<float> ffw_out;
|
RowVectorBatch<float> ffw_out;
|
||||||
|
|
@ -106,7 +107,7 @@ struct Activations {
|
||||||
att_out = RowVectorBatch<float>(batch_size, kHeads * kQKVDim);
|
att_out = RowVectorBatch<float>(batch_size, kHeads * kQKVDim);
|
||||||
att_sums = RowVectorBatch<float>(batch_size, kModelDim);
|
att_sums = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
|
||||||
bf_pre_ffw_rms_out = RowVectorBatch<hwy::bfloat16_t>(batch_size, kModelDim);
|
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, kModelDim);
|
||||||
C1 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
|
C1 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
|
||||||
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
|
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
|
||||||
ffw_out = RowVectorBatch<float>(batch_size, kModelDim);
|
ffw_out = RowVectorBatch<float>(batch_size, kModelDim);
|
||||||
|
|
|
||||||
|
|
@ -118,8 +118,8 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
||||||
return CallForModel<float, FuncT, TArgs...>( //
|
return CallForModel<float, FuncT, TArgs...>( //
|
||||||
model, std::forward<TArgs>(args)...);
|
model, std::forward<TArgs>(args)...);
|
||||||
case Type::kBF16:
|
case Type::kBF16:
|
||||||
return CallForModel<hwy::bfloat16_t, FuncT, TArgs...>(
|
return CallForModel<BF16, FuncT, TArgs...>(model,
|
||||||
model, std::forward<TArgs>(args)...);
|
std::forward<TArgs>(args)...);
|
||||||
case Type::kSFP:
|
case Type::kSFP:
|
||||||
return CallForModel<SfpStream, FuncT, TArgs...>(
|
return CallForModel<SfpStream, FuncT, TArgs...>(
|
||||||
model, std::forward<TArgs>(args)...);
|
model, std::forward<TArgs>(args)...);
|
||||||
|
|
@ -130,7 +130,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
||||||
|
|
||||||
#define GEMMA_FOREACH_WEIGHT(X, CONFIGT) \
|
#define GEMMA_FOREACH_WEIGHT(X, CONFIGT) \
|
||||||
X(CONFIGT, float) \
|
X(CONFIGT, float) \
|
||||||
X(CONFIGT, hwy::bfloat16_t) \
|
X(CONFIGT, BF16) \
|
||||||
X(CONFIGT, SfpStream)
|
X(CONFIGT, SfpStream)
|
||||||
|
|
||||||
#define GEMMA_FOREACH_CONFIG_AND_WEIGHT(X) \
|
#define GEMMA_FOREACH_CONFIG_AND_WEIGHT(X) \
|
||||||
|
|
@ -205,7 +205,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
|
||||||
GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS); \
|
GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS); \
|
||||||
break; \
|
break; \
|
||||||
case Type::kBF16: \
|
case Type::kBF16: \
|
||||||
GEMMA_DISPATCH_MODEL(MODEL, hwy::bfloat16_t, FUNC, ARGS); \
|
GEMMA_DISPATCH_MODEL(MODEL, BF16, FUNC, ARGS); \
|
||||||
break; \
|
break; \
|
||||||
case Type::kSFP: \
|
case Type::kSFP: \
|
||||||
GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS); \
|
GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS); \
|
||||||
|
|
@ -239,15 +239,15 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
|
||||||
template <typename TConfig>
|
template <typename TConfig>
|
||||||
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
|
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
|
||||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
return hwy::ConvertScalarTo<float>(
|
||||||
Sqrt(static_cast<float>(TConfig::kModelDim))));
|
hwy::ConvertScalarTo<BF16>(Sqrt(static_cast<float>(TConfig::kModelDim))));
|
||||||
}
|
}
|
||||||
|
|
||||||
static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
|
static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
|
||||||
size_t model_dim) {
|
size_t model_dim) {
|
||||||
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
// Round to bf16 to match Gemma's Embedder, which casts before mul.
|
||||||
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
|
return hwy::ConvertScalarTo<float>(
|
||||||
Sqrt(static_cast<float>(model_dim))));
|
hwy::ConvertScalarTo<BF16>(Sqrt(static_cast<float>(model_dim))));
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,7 @@
|
||||||
|
|
||||||
#include <array>
|
#include <array>
|
||||||
|
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
#include "compression/shared.h" // BF16
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -40,7 +40,7 @@ static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
|
||||||
static constexpr size_t kTopK = GEMMA_TOPK;
|
static constexpr size_t kTopK = GEMMA_TOPK;
|
||||||
static constexpr size_t kVocabSize = 256000;
|
static constexpr size_t kVocabSize = 256000;
|
||||||
|
|
||||||
using EmbedderInputT = hwy::bfloat16_t;
|
using EmbedderInputT = BF16;
|
||||||
|
|
||||||
enum class LayerAttentionType {
|
enum class LayerAttentionType {
|
||||||
kGemma,
|
kGemma,
|
||||||
|
|
|
||||||
|
|
@ -763,7 +763,7 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
|
||||||
EmbeddingScaling<TConfig>();
|
EmbeddingScaling<TConfig>();
|
||||||
|
|
||||||
HWY_DASSERT(token >= 0);
|
HWY_DASSERT(token >= 0);
|
||||||
HWY_DASSERT(token < kVocabSize);
|
HWY_DASSERT(token < static_cast<int>(kVocabSize));
|
||||||
|
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
DecompressAndZeroPad(
|
DecompressAndZeroPad(
|
||||||
|
|
@ -1193,14 +1193,15 @@ SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
|
||||||
|
|
||||||
// Fast path for top-1 with no accept_token.
|
// Fast path for top-1 with no accept_token.
|
||||||
if (kTopK == 1 && !runtime_config.accept_token) {
|
if (kTopK == 1 && !runtime_config.accept_token) {
|
||||||
return [](float* logits, size_t vocab_size) -> TokenAndProb {
|
return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||||
PROFILER_ZONE("Gen.Sample Top1");
|
PROFILER_ZONE("Gen.Sample Top1");
|
||||||
return Top1OfSoftmax(logits, vocab_size);
|
return Top1OfSoftmax(logits, vocab_size);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
// General case: Softmax with top-k sampling.
|
// General case: Softmax with top-k sampling.
|
||||||
return [&runtime_config](float* logits, size_t vocab_size) -> TokenAndProb {
|
return [&runtime_config](float* logits,
|
||||||
|
size_t vocab_size) HWY_ATTR -> TokenAndProb {
|
||||||
PROFILER_ZONE("Gen.Sample general");
|
PROFILER_ZONE("Gen.Sample general");
|
||||||
Softmax(logits, vocab_size);
|
Softmax(logits, vocab_size);
|
||||||
const int token = SampleTopK<kTopK>(logits, vocab_size, *runtime_config.gen,
|
const int token = SampleTopK<kTopK>(logits, vocab_size, *runtime_config.gen,
|
||||||
|
|
|
||||||
|
|
@ -28,13 +28,13 @@
|
||||||
#include "gemma/kv_cache.h"
|
#include "gemma/kv_cache.h"
|
||||||
#include "gemma/tokenizer.h"
|
#include "gemma/tokenizer.h"
|
||||||
#include "paligemma/image.h"
|
#include "paligemma/image.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h" // RowVectorBatch
|
||||||
|
#include "util/basics.h" // TokenAndProb
|
||||||
#include "util/threading.h"
|
#include "util/threading.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
#include "hwy/aligned_allocator.h" // Span
|
#include "hwy/aligned_allocator.h" // Span
|
||||||
#include "hwy/base.h" // hwy::bfloat16_t
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
using PromptTokens = hwy::Span<const int>;
|
using PromptTokens = hwy::Span<const int>;
|
||||||
|
|
|
||||||
|
|
@ -108,10 +108,10 @@ struct CompressedLayer {
|
||||||
// do not yet support smaller compressed types, or require at least bf16. When
|
// do not yet support smaller compressed types, or require at least bf16. When
|
||||||
// weights are f32, we also want such tensors to be f32.
|
// weights are f32, we also want such tensors to be f32.
|
||||||
// If weights are complex, this is also complex.
|
// If weights are complex, this is also complex.
|
||||||
using WeightF32OrBF16 = hwy::If<
|
using WeightF32OrBF16 =
|
||||||
hwy::IsSame<Weight, std::complex<double>>(), std::complex<double>,
|
hwy::If<hwy::IsSame<Weight, std::complex<double>>(), std::complex<double>,
|
||||||
hwy::If<hwy::IsSame<Weight, double>(), double,
|
hwy::If<hwy::IsSame<Weight, double>(), double,
|
||||||
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>>>;
|
hwy::If<IsF32<Weight>(), float, BF16>>>;
|
||||||
|
|
||||||
static constexpr size_t kHeads = TConfig::kHeads;
|
static constexpr size_t kHeads = TConfig::kHeads;
|
||||||
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
static constexpr size_t kKVHeads = TConfig::kKVHeads;
|
||||||
|
|
@ -363,9 +363,8 @@ struct CompressedWeights {
|
||||||
|
|
||||||
using Weight = typename TConfig::Weight;
|
using Weight = typename TConfig::Weight;
|
||||||
using WeightF32OrBF16 = typename CompressedLayer<TConfig>::WeightF32OrBF16;
|
using WeightF32OrBF16 = typename CompressedLayer<TConfig>::WeightF32OrBF16;
|
||||||
using WeightF32OrInputT =
|
using WeightF32OrInputT = hwy::If<hwy::IsSame<WeightF32OrBF16, BF16>(),
|
||||||
hwy::If<hwy::IsSame<WeightF32OrBF16, hwy::bfloat16_t>(), EmbedderInputT,
|
EmbedderInputT, WeightF32OrBF16>;
|
||||||
WeightF32OrBF16>;
|
|
||||||
|
|
||||||
MatPtrT<WeightF32OrInputT> embedder_input_embedding;
|
MatPtrT<WeightF32OrInputT> embedder_input_embedding;
|
||||||
MatPtrT<WeightF32OrBF16> final_norm_scale;
|
MatPtrT<WeightF32OrBF16> final_norm_scale;
|
||||||
|
|
|
||||||
|
|
@ -840,7 +840,7 @@ class DotStats {
|
||||||
|
|
||||||
ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
|
ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
|
||||||
// Extremely high error on aarch64.
|
// Extremely high error on aarch64.
|
||||||
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 1250.f);
|
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Backward relative error, lower is better.
|
// Backward relative error, lower is better.
|
||||||
|
|
|
||||||
|
|
@ -19,9 +19,10 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "util/allocator.h" // RowVectorBatch
|
#include "util/allocator.h" // RowVectorBatch
|
||||||
#include "util/threading.h" // PerClusterPools
|
#include "util/threading.h"
|
||||||
|
#include "hwy/aligned_allocator.h" // IWYU pragma: export
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: export
|
||||||
#include "hwy/per_target.h"
|
#include "hwy/per_target.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
|
||||||
180
ops/ops-inl.h
180
ops/ops-inl.h
|
|
@ -28,7 +28,7 @@
|
||||||
#include <type_traits> // std::enable_if_t
|
#include <type_traits> // std::enable_if_t
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
#include "util/allocator.h" // TokenAndProb
|
#include "util/basics.h" // TokenAndProb
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/detect_targets.h"
|
#include "hwy/detect_targets.h"
|
||||||
|
|
@ -44,6 +44,7 @@
|
||||||
|
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
#include "ops/dot-inl.h"
|
#include "ops/dot-inl.h"
|
||||||
|
#include "ops/sum-inl.h"
|
||||||
#include "hwy/contrib/algo/transform-inl.h"
|
#include "hwy/contrib/algo/transform-inl.h"
|
||||||
#include "hwy/contrib/math/math-inl.h"
|
#include "hwy/contrib/math/math-inl.h"
|
||||||
#include "hwy/profiler.h" // also uses SIMD
|
#include "hwy/profiler.h" // also uses SIMD
|
||||||
|
|
@ -507,183 +508,6 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
|
||||||
MulByConstAndAdd(c, x, out, size, size);
|
MulByConstAndAdd(c, x, out, size, size);
|
||||||
}
|
}
|
||||||
|
|
||||||
// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
|
|
||||||
// of f32 sums.
|
|
||||||
struct SumKernelDouble {
|
|
||||||
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
|
|
||||||
// to be `float` in order to have `Raw = double`. Note that if either type is
|
|
||||||
// smaller than `float`, we may demote the other type from `float` to `BF16`.
|
|
||||||
template <typename VT, typename WT>
|
|
||||||
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
|
|
||||||
using State = double;
|
|
||||||
|
|
||||||
// Raw = double
|
|
||||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
|
||||||
HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
|
|
||||||
const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
|
|
||||||
VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const {
|
|
||||||
sum0 = hn::Add(sum0, w0);
|
|
||||||
sum1 = hn::Add(sum1, w1);
|
|
||||||
sum2 = hn::Add(sum2, w2);
|
|
||||||
sum3 = hn::Add(sum3, w3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Raw = BF16
|
|
||||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
|
|
||||||
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
|
|
||||||
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
|
|
||||||
const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1,
|
|
||||||
VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const {
|
|
||||||
const hn::Repartition<float, DRaw> df;
|
|
||||||
using VF = hn::Vec<decltype(df)>;
|
|
||||||
// Reduce to two f32 sums so we can promote them to four f64 vectors.
|
|
||||||
VF sum02, sum13;
|
|
||||||
if constexpr (HWY_NATIVE_DOT_BF16) {
|
|
||||||
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
|
|
||||||
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1);
|
|
||||||
const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1);
|
|
||||||
// Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
|
|
||||||
VF unused0 = hn::Zero(df);
|
|
||||||
VF unused1 = hn::Zero(df);
|
|
||||||
sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0);
|
|
||||||
sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1);
|
|
||||||
} else {
|
|
||||||
// If not native, the multiplication costs extra, so convert to f32.
|
|
||||||
// PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`.
|
|
||||||
const VF fe0 = hn::PromoteEvenTo(df, w0);
|
|
||||||
const VF fe1 = hn::PromoteEvenTo(df, w1);
|
|
||||||
const VF fe2 = hn::PromoteEvenTo(df, w2);
|
|
||||||
const VF fe3 = hn::PromoteEvenTo(df, w3);
|
|
||||||
const VF fo0 = hn::PromoteOddTo(df, w0);
|
|
||||||
const VF fo1 = hn::PromoteOddTo(df, w1);
|
|
||||||
const VF fo2 = hn::PromoteOddTo(df, w2);
|
|
||||||
const VF fo3 = hn::PromoteOddTo(df, w3);
|
|
||||||
const VF fe01 = hn::Add(fe0, fe1);
|
|
||||||
const VF fe23 = hn::Add(fe2, fe3);
|
|
||||||
const VF fo01 = hn::Add(fo0, fo1);
|
|
||||||
const VF fo23 = hn::Add(fo2, fo3);
|
|
||||||
sum02 = hn::Add(fe01, fe23);
|
|
||||||
sum13 = hn::Add(fo01, fo23);
|
|
||||||
}
|
|
||||||
|
|
||||||
const DS ds;
|
|
||||||
const VS d0 = hn::PromoteLowerTo(ds, sum02);
|
|
||||||
const VS d1 = hn::PromoteUpperTo(ds, sum02);
|
|
||||||
const VS d2 = hn::PromoteLowerTo(ds, sum13);
|
|
||||||
const VS d3 = hn::PromoteUpperTo(ds, sum13);
|
|
||||||
|
|
||||||
sum0 = hn::Add(sum0, d0);
|
|
||||||
sum1 = hn::Add(sum1, d1);
|
|
||||||
sum2 = hn::Add(sum2, d2);
|
|
||||||
sum3 = hn::Add(sum3, d3);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Raw = double
|
|
||||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
|
||||||
HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
|
|
||||||
VR& comp0) const {
|
|
||||||
sum0 = hn::Add(sum0, w0);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Raw = BF16
|
|
||||||
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
|
|
||||||
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
|
|
||||||
HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0,
|
|
||||||
VS& extra0) const {
|
|
||||||
const hn::Repartition<float, DRaw> df;
|
|
||||||
using VF = hn::Vec<decltype(df)>;
|
|
||||||
VF f0;
|
|
||||||
if constexpr (HWY_NATIVE_DOT_BF16) {
|
|
||||||
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
|
|
||||||
f0 = hn::WidenMulPairwiseAdd(df, w0, k1);
|
|
||||||
} else {
|
|
||||||
const VF fe0 = hn::PromoteEvenTo(df, w0);
|
|
||||||
const VF fo0 = hn::PromoteOddTo(df, w0);
|
|
||||||
f0 = hn::Add(fe0, fo0);
|
|
||||||
}
|
|
||||||
|
|
||||||
const DS ds;
|
|
||||||
const VS d0 = hn::PromoteLowerTo(ds, f0);
|
|
||||||
const VS d1 = hn::PromoteUpperTo(ds, f0);
|
|
||||||
|
|
||||||
sum0 = hn::Add(sum0, d0);
|
|
||||||
extra0 = hn::Add(extra0, d1);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class DState, class VS = hn::Vec<DState>>
|
|
||||||
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
|
|
||||||
VS& extra0, VS&, VS&, VS&) const {
|
|
||||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
|
||||||
sum0 = hn::Add(sum0, sum1);
|
|
||||||
sum2 = hn::Add(sum2, sum3);
|
|
||||||
sum0 = hn::Add(sum0, extra0); // from Update1
|
|
||||||
sum0 = hn::Add(sum0, sum2);
|
|
||||||
return static_cast<float>(hn::ReduceSum(dd, sum0));
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
|
|
||||||
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
|
|
||||||
// instead of FastTwoSums because the magnitude of the initial sum is not
|
|
||||||
// always greater than the next input, and this does actually change the e2e
|
|
||||||
// generation results. Note that Kahan summation differs in that it first adds
|
|
||||||
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
|
|
||||||
// and comp* here have shorter dependency chains.
|
|
||||||
//
|
|
||||||
// This about as accurate as SumKernelDouble but slower, hence we only use this
|
|
||||||
// if f64 is not supported on this target.
|
|
||||||
struct SumKernelCascaded {
|
|
||||||
template <typename VT, typename WT>
|
|
||||||
using Raw = float;
|
|
||||||
using State = float;
|
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
|
||||||
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
|
||||||
const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
|
|
||||||
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
|
|
||||||
VF& comp3) const {
|
|
||||||
VF serr0, serr1, serr2, serr3;
|
|
||||||
sum0 = TwoSums(df, sum0, w0, serr0);
|
|
||||||
sum1 = TwoSums(df, sum1, w1, serr1);
|
|
||||||
sum2 = TwoSums(df, sum2, w2, serr2);
|
|
||||||
sum3 = TwoSums(df, sum3, w3, serr3);
|
|
||||||
|
|
||||||
comp0 = hn::Add(comp0, serr0);
|
|
||||||
comp1 = hn::Add(comp1, serr1);
|
|
||||||
comp2 = hn::Add(comp2, serr2);
|
|
||||||
comp3 = hn::Add(comp3, serr3);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
|
||||||
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
|
||||||
VF& comp0) const {
|
|
||||||
VF serr0;
|
|
||||||
sum0 = TwoSums(df, sum0, w0, serr0);
|
|
||||||
|
|
||||||
comp0 = hn::Add(comp0, serr0);
|
|
||||||
}
|
|
||||||
|
|
||||||
template <class DF, class VF = hn::Vec<DF>>
|
|
||||||
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
|
||||||
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
|
||||||
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
|
||||||
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
|
|
||||||
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
|
|
||||||
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
|
|
||||||
return ReduceCascadedSums(df, sum0, comp0);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
using SumKernelDefault =
|
|
||||||
hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;
|
|
||||||
|
|
||||||
template <class D, typename VT>
|
|
||||||
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
|
|
||||||
using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
|
|
||||||
const hn::Repartition<Raw, D> d_raw;
|
|
||||||
return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault());
|
|
||||||
}
|
|
||||||
|
|
||||||
// See below for a specialized version for top-1 sampling.
|
// See below for a specialized version for top-1 sampling.
|
||||||
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
|
||||||
const size_t mask_pos) {
|
const size_t mask_pos) {
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,217 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
// Include guard for SIMD code.
|
||||||
|
#if defined(THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE) == defined(HWY_TARGET_TOGGLE)
|
||||||
|
#ifdef THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE
|
||||||
|
#undef THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE
|
||||||
|
#else
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE
|
||||||
|
#endif
|
||||||
|
|
||||||
|
#include "compression/compress-inl.h"
|
||||||
|
|
||||||
|
HWY_BEFORE_NAMESPACE();
|
||||||
|
namespace gcpp {
|
||||||
|
namespace HWY_NAMESPACE {
|
||||||
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
|
// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
|
||||||
|
// of f32 sums.
|
||||||
|
struct SumKernelDouble {
|
||||||
|
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
|
||||||
|
// to be `float` in order to have `Raw = double`. Note that if either type is
|
||||||
|
// smaller than `float`, we may demote the other type from `float` to `BF16`.
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
|
||||||
|
using State = double;
|
||||||
|
|
||||||
|
// Raw = double
|
||||||
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
||||||
|
HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
|
||||||
|
const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
|
||||||
|
VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const {
|
||||||
|
sum0 = hn::Add(sum0, w0);
|
||||||
|
sum1 = hn::Add(sum1, w1);
|
||||||
|
sum2 = hn::Add(sum2, w2);
|
||||||
|
sum3 = hn::Add(sum3, w3);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw = BF16
|
||||||
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
|
||||||
|
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
|
||||||
|
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
|
||||||
|
const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1,
|
||||||
|
VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const {
|
||||||
|
const hn::Repartition<float, DRaw> df;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
// Reduce to two f32 sums so we can promote them to four f64 vectors.
|
||||||
|
VF sum02, sum13;
|
||||||
|
if constexpr (HWY_NATIVE_DOT_BF16) {
|
||||||
|
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
|
||||||
|
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1);
|
||||||
|
const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1);
|
||||||
|
// Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
|
||||||
|
VF unused0 = hn::Zero(df);
|
||||||
|
VF unused1 = hn::Zero(df);
|
||||||
|
sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0);
|
||||||
|
sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1);
|
||||||
|
} else {
|
||||||
|
// If not native, the multiplication costs extra, so convert to f32.
|
||||||
|
// PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`.
|
||||||
|
const VF fe0 = hn::PromoteEvenTo(df, w0);
|
||||||
|
const VF fe1 = hn::PromoteEvenTo(df, w1);
|
||||||
|
const VF fe2 = hn::PromoteEvenTo(df, w2);
|
||||||
|
const VF fe3 = hn::PromoteEvenTo(df, w3);
|
||||||
|
const VF fo0 = hn::PromoteOddTo(df, w0);
|
||||||
|
const VF fo1 = hn::PromoteOddTo(df, w1);
|
||||||
|
const VF fo2 = hn::PromoteOddTo(df, w2);
|
||||||
|
const VF fo3 = hn::PromoteOddTo(df, w3);
|
||||||
|
const VF fe01 = hn::Add(fe0, fe1);
|
||||||
|
const VF fe23 = hn::Add(fe2, fe3);
|
||||||
|
const VF fo01 = hn::Add(fo0, fo1);
|
||||||
|
const VF fo23 = hn::Add(fo2, fo3);
|
||||||
|
sum02 = hn::Add(fe01, fe23);
|
||||||
|
sum13 = hn::Add(fo01, fo23);
|
||||||
|
}
|
||||||
|
|
||||||
|
const DS ds;
|
||||||
|
const VS d0 = hn::PromoteLowerTo(ds, sum02);
|
||||||
|
const VS d1 = hn::PromoteUpperTo(ds, sum02);
|
||||||
|
const VS d2 = hn::PromoteLowerTo(ds, sum13);
|
||||||
|
const VS d3 = hn::PromoteUpperTo(ds, sum13);
|
||||||
|
|
||||||
|
sum0 = hn::Add(sum0, d0);
|
||||||
|
sum1 = hn::Add(sum1, d1);
|
||||||
|
sum2 = hn::Add(sum2, d2);
|
||||||
|
sum3 = hn::Add(sum3, d3);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw = double
|
||||||
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
|
||||||
|
HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
|
||||||
|
VR& comp0) const {
|
||||||
|
sum0 = hn::Add(sum0, w0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw = BF16
|
||||||
|
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
|
||||||
|
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
|
||||||
|
HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0,
|
||||||
|
VS& extra0) const {
|
||||||
|
const hn::Repartition<float, DRaw> df;
|
||||||
|
using VF = hn::Vec<decltype(df)>;
|
||||||
|
VF f0;
|
||||||
|
if constexpr (HWY_NATIVE_DOT_BF16) {
|
||||||
|
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
|
||||||
|
f0 = hn::WidenMulPairwiseAdd(df, w0, k1);
|
||||||
|
} else {
|
||||||
|
const VF fe0 = hn::PromoteEvenTo(df, w0);
|
||||||
|
const VF fo0 = hn::PromoteOddTo(df, w0);
|
||||||
|
f0 = hn::Add(fe0, fo0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const DS ds;
|
||||||
|
const VS d0 = hn::PromoteLowerTo(ds, f0);
|
||||||
|
const VS d1 = hn::PromoteUpperTo(ds, f0);
|
||||||
|
|
||||||
|
sum0 = hn::Add(sum0, d0);
|
||||||
|
extra0 = hn::Add(extra0, d1);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DState, class VS = hn::Vec<DState>>
|
||||||
|
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
|
||||||
|
VS& extra0, VS&, VS&, VS&) const {
|
||||||
|
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||||
|
sum0 = hn::Add(sum0, sum1);
|
||||||
|
sum2 = hn::Add(sum2, sum3);
|
||||||
|
sum0 = hn::Add(sum0, extra0); // from Update1
|
||||||
|
sum0 = hn::Add(sum0, sum2);
|
||||||
|
return static_cast<float>(hn::ReduceSum(dd, sum0));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
|
||||||
|
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
|
||||||
|
// instead of FastTwoSums because the magnitude of the initial sum is not
|
||||||
|
// always greater than the next input, and this does actually change the e2e
|
||||||
|
// generation results. Note that Kahan summation differs in that it first adds
|
||||||
|
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
|
||||||
|
// and comp* here have shorter dependency chains.
|
||||||
|
//
|
||||||
|
// This about as accurate as SumKernelDouble but slower, hence we only use this
|
||||||
|
// if f64 is not supported on this target.
|
||||||
|
struct SumKernelCascaded {
|
||||||
|
template <typename VT, typename WT>
|
||||||
|
using Raw = float;
|
||||||
|
using State = float;
|
||||||
|
|
||||||
|
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
||||||
|
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
|
||||||
|
const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
|
||||||
|
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
|
||||||
|
VF& comp3) const {
|
||||||
|
VF serr0, serr1, serr2, serr3;
|
||||||
|
sum0 = TwoSums(df, sum0, w0, serr0);
|
||||||
|
sum1 = TwoSums(df, sum1, w1, serr1);
|
||||||
|
sum2 = TwoSums(df, sum2, w2, serr2);
|
||||||
|
sum3 = TwoSums(df, sum3, w3, serr3);
|
||||||
|
|
||||||
|
comp0 = hn::Add(comp0, serr0);
|
||||||
|
comp1 = hn::Add(comp1, serr1);
|
||||||
|
comp2 = hn::Add(comp2, serr2);
|
||||||
|
comp3 = hn::Add(comp3, serr3);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
|
||||||
|
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
|
||||||
|
VF& comp0) const {
|
||||||
|
VF serr0;
|
||||||
|
sum0 = TwoSums(df, sum0, w0, serr0);
|
||||||
|
|
||||||
|
comp0 = hn::Add(comp0, serr0);
|
||||||
|
}
|
||||||
|
|
||||||
|
template <class DF, class VF = hn::Vec<DF>>
|
||||||
|
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
|
||||||
|
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
|
||||||
|
// Reduction tree: sum of all accumulators by pairs, then across lanes.
|
||||||
|
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
|
||||||
|
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
|
||||||
|
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
|
||||||
|
return ReduceCascadedSums(df, sum0, comp0);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
using SumKernelDefault =
|
||||||
|
hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;
|
||||||
|
|
||||||
|
template <class D, typename VT>
|
||||||
|
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
|
||||||
|
using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
|
||||||
|
const hn::Repartition<Raw, D> d_raw;
|
||||||
|
return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault());
|
||||||
|
}
|
||||||
|
|
||||||
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
} // namespace HWY_NAMESPACE
|
||||||
|
} // namespace gcpp
|
||||||
|
HWY_AFTER_NAMESPACE();
|
||||||
|
|
||||||
|
#endif // NOLINT
|
||||||
|
|
@ -11,7 +11,7 @@ cc_library(
|
||||||
name = "image",
|
name = "image",
|
||||||
srcs = ["image.cc"],
|
srcs = ["image.cc"],
|
||||||
hdrs = ["image.h"],
|
hdrs = ["image.h"],
|
||||||
deps = ["@hwy//:hwy"],
|
deps = ["@highway//:hwy"],
|
||||||
)
|
)
|
||||||
|
|
||||||
cc_test(
|
cc_test(
|
||||||
|
|
@ -39,7 +39,7 @@ cc_test(
|
||||||
"//:common",
|
"//:common",
|
||||||
"//:gemma_lib",
|
"//:gemma_lib",
|
||||||
"//:tokenizer",
|
"//:tokenizer",
|
||||||
"@hwy//:hwy",
|
"@highway//:hwy",
|
||||||
"@hwy//:hwy_test_util",
|
"@highway//:hwy_test_util",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -19,30 +19,11 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h" // IWYU pragma: export
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
#if HWY_IS_MSAN
|
|
||||||
#include <sanitizer/msan_interface.h>
|
|
||||||
#endif
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
|
|
||||||
#if HWY_IS_MSAN
|
|
||||||
__msan_check_mem_is_initialized(ptr, size);
|
|
||||||
#else
|
|
||||||
(void)ptr;
|
|
||||||
(void)size;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
// Shared between gemma.h and ops-inl.h.
|
|
||||||
struct TokenAndProb {
|
|
||||||
int token;
|
|
||||||
float prob;
|
|
||||||
};
|
|
||||||
|
|
||||||
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
|
|
|
||||||
|
|
@ -82,6 +82,9 @@ class AppArgs : public ArgsBase<AppArgs> {
|
||||||
visitor(max_threads, "num_threads", size_t{0},
|
visitor(max_threads, "num_threads", size_t{0},
|
||||||
"Maximum number of threads to use; default 0 = unlimited.", 2);
|
"Maximum number of threads to use; default 0 = unlimited.", 2);
|
||||||
visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||||
|
// These can be used to partition CPU sockets/packages and their
|
||||||
|
// clusters/CCXs across several program instances. The default is to use
|
||||||
|
// all available resources.
|
||||||
visitor(skip_packages, "skip_packages", size_t{0},
|
visitor(skip_packages, "skip_packages", size_t{0},
|
||||||
"Index of the first socket to use; default 0 = unlimited.", 2);
|
"Index of the first socket to use; default 0 = unlimited.", 2);
|
||||||
visitor(max_packages, "max_packages", size_t{0},
|
visitor(max_packages, "max_packages", size_t{0},
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,49 @@
|
||||||
|
// Copyright 2024 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// https://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
|
||||||
|
|
||||||
|
// IWYU pragma: begin_exports
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include "hwy/base.h" // HWY_IS_MSAN
|
||||||
|
// IWYU pragma: end_exports
|
||||||
|
|
||||||
|
#if HWY_IS_MSAN
|
||||||
|
#include <sanitizer/msan_interface.h>
|
||||||
|
#endif
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
|
||||||
|
#if HWY_IS_MSAN
|
||||||
|
__msan_check_mem_is_initialized(ptr, size);
|
||||||
|
#else
|
||||||
|
(void)ptr;
|
||||||
|
(void)size;
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
|
||||||
|
// Shared between gemma.h and ops-inl.h.
|
||||||
|
struct TokenAndProb {
|
||||||
|
int token;
|
||||||
|
float prob;
|
||||||
|
};
|
||||||
|
|
||||||
|
} // namespace gcpp
|
||||||
|
|
||||||
|
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
|
||||||
|
|
@ -381,7 +381,7 @@ class BoundedTopology {
|
||||||
LPS enabled_lps; // LPs not disabled via OS, taskset, or numactl.
|
LPS enabled_lps; // LPs not disabled via OS, taskset, or numactl.
|
||||||
bool missing_cluster = false;
|
bool missing_cluster = false;
|
||||||
|
|
||||||
if (HWY_LIKELY(have_threading_support)) {
|
if (HWY_LIKELY(have_threading_support && !topology_.packages.empty())) {
|
||||||
(void)GetThreadAffinity(enabled_lps); // failure = all disabled
|
(void)GetThreadAffinity(enabled_lps); // failure = all disabled
|
||||||
|
|
||||||
// No effect if topology is unknown or `enabled_lps` is empty.
|
// No effect if topology is unknown or `enabled_lps` is empty.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue