Minor cleanup, Windows+Bazel build fixes

add app.h comment
compress-inl: remove unused typedef
gemma-inl: add missing HWY_ATTR and cast
separate sum-inl.h and basics.h headers
replace more hwy::bfloat16_t with BF16
update include pragmas
update dot_test thresholds
update Highway version in Bazel for HWY_RCAST_ALIGNED fix
PiperOrigin-RevId: 684464326
This commit is contained in:
Jan Wassenberg 2024-10-10 09:04:19 -07:00 committed by Copybara-Service
parent 85958f5fd3
commit 6ab3ff5bde
25 changed files with 475 additions and 403 deletions

View File

@ -20,11 +20,19 @@ licenses(["notice"])
exports_files(["LICENSE"])
cc_library(
name = "basics",
hdrs = ["util/basics.h"],
deps = [
"@highway//:hwy",
],
)
cc_library(
name = "allocator",
hdrs = ["util/allocator.h"],
deps = [
"@hwy//:hwy",
"@highway//:hwy",
],
)
@ -32,9 +40,9 @@ cc_library(
name = "test_util",
hdrs = ["util/test_util.h"],
deps = [
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:stats",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:stats",
],
)
@ -42,9 +50,9 @@ cc_library(
name = "threading",
hdrs = ["util/threading.h"],
deps = [
"@hwy//:hwy",
"@hwy//:thread_pool",
"@hwy//:topology",
"@highway//:hwy",
"@highway//:thread_pool",
"@highway//:topology",
],
)
@ -54,8 +62,8 @@ cc_test(
deps = [
":threading",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
@ -66,6 +74,7 @@ cc_library(
],
textual_hdrs = [
"ops/dot-inl.h",
"ops/sum-inl.h",
"ops/fp_arith-inl.h",
"ops/matmul-inl.h",
"ops/matvec-inl.h",
@ -73,14 +82,15 @@ cc_library(
],
deps = [
":allocator",
":basics",
":threading",
"//compression:compress",
"@hwy//:algo",
"@hwy//:hwy",
"@hwy//:math",
"@hwy//:matvec",
"@hwy//:profiler",
"@hwy//:thread_pool",
"@highway//:algo",
"@highway//:hwy",
"@highway//:math",
"@highway//:matvec",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
@ -102,12 +112,12 @@ cc_test(
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:test_util",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", # buildcleaner: keep
"@hwy//:profiler",
"@hwy//:stats",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark", # buildcleaner: keep
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
)
@ -127,9 +137,9 @@ cc_test(
":test_util",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark", # buildcleaner: keep
],
)
@ -145,9 +155,9 @@ cc_test(
":ops",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
@ -164,10 +174,10 @@ cc_test(
":threading",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)
@ -180,8 +190,8 @@ cc_library(
],
deps = [
"//compression:compress",
"@hwy//:hwy", # base.h
"@hwy//:thread_pool",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
],
)
@ -194,10 +204,10 @@ cc_library(
":common",
"//compression:compress",
"//compression:io",
"@hwy//:hwy",
"@hwy//:profiler",
"@hwy//:stats",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
)
@ -208,9 +218,9 @@ cc_library(
deps = [
":common",
"//compression:io",
"@hwy//:hwy",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@highway//:hwy",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@com_google_sentencepiece//:sentencepiece_processor",
],
)
@ -221,7 +231,7 @@ cc_library(
hdrs = ["gemma/kv_cache.h"],
deps = [
":common",
"@hwy//:hwy",
"@highway//:hwy",
],
)
@ -268,6 +278,7 @@ cc_library(
],
deps = [
":allocator",
":basics",
":common",
":ops",
":tokenizer",
@ -275,12 +286,13 @@ cc_library(
":weights",
":threading",
"//compression:io",
"//compression:sfp",
"//paligemma:image",
"@hwy//:hwy",
"@hwy//:bit_set",
"@hwy//:nanobenchmark", # timer
"@hwy//:profiler",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:bit_set",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
],
)
@ -292,7 +304,7 @@ cc_library(
":common",
":gemma_lib",
":ops",
"@hwy//:hwy",
"@highway//:hwy",
],
)
@ -301,7 +313,7 @@ cc_library(
hdrs = ["util/args.h"],
deps = [
"//compression:io",
"@hwy//:hwy",
"@highway//:hwy",
],
)
@ -314,9 +326,9 @@ cc_library(
":gemma_lib",
":threading",
"//compression:io",
"@hwy//:hwy",
"@hwy//:thread_pool",
"@hwy//:topology",
"@highway//:hwy",
"@highway//:thread_pool",
"@highway//:topology",
],
)
@ -333,12 +345,12 @@ cc_library(
":kv_cache",
":threading",
# Placeholder for internal dep, do not remove.,
"@benchmark//:benchmark",
"@google_benchmark//:benchmark",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
"@hwy//:topology",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
"@highway//:topology",
],
)
@ -357,8 +369,8 @@ cc_test(
":gemma_lib",
":tokenizer",
"@googletest//:gtest_main",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
@ -374,9 +386,9 @@ cc_binary(
":threading",
# Placeholder for internal dep, do not remove.,
"//paligemma:image",
"@hwy//:hwy",
"@hwy//:profiler",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
],
)
@ -391,9 +403,9 @@ cc_binary(
":cross_entropy",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
"@nlohmann_json//:json",
],
)
@ -403,7 +415,7 @@ cc_binary(
srcs = ["evals/benchmarks.cc"],
deps = [
":benchmark_helper",
"@benchmark//:benchmark",
"@google_benchmark//:benchmark",
],
)
@ -418,8 +430,8 @@ cc_binary(
":benchmark_helper",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:thread_pool",
"@nlohmann_json//:json",
],
)
@ -433,9 +445,9 @@ cc_binary(
":benchmark_helper",
":gemma_lib",
"//compression:io",
"@hwy//:hwy",
"@hwy//:profiler",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:profiler",
"@highway//:thread_pool",
"@nlohmann_json//:json",
],
)
@ -477,9 +489,9 @@ cc_library(
":prompt",
":weights",
"//compression:compress",
"@hwy//:dot",
"@hwy//:hwy", # base.h
"@hwy//:thread_pool",
"@highway//:dot",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
],
)
@ -497,7 +509,7 @@ cc_library(
":prompt",
":weights",
"//compression:compress",
"@hwy//:hwy",
"@highway//:hwy",
],
)
@ -517,7 +529,7 @@ cc_test(
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@hwy//:thread_pool",
"@highway//:thread_pool",
],
)
@ -544,9 +556,9 @@ cc_test(
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
@ -559,8 +571,8 @@ cc_library(
":common",
":weights",
"//compression:compress",
"@hwy//:hwy",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
@ -583,6 +595,6 @@ cc_test(
":threading",
":weights",
"@googletest//:gtest_main",
"@hwy//:thread_pool",
"@highway//:thread_pool",
],
)

View File

@ -22,7 +22,7 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG bb6c3f36b0c8dde8a8ef98b0f0884f4de820a7ca EXCLUDE_FROM_ALL)
FetchContent_Declare(highway GIT_REPOSITORY https://github.com/google/highway.git GIT_TAG 0ca297227a373710e76dd45e0ad4d68adb6928fe EXCLUDE_FROM_ALL)
FetchContent_MakeAvailable(highway)
## Note: absl needs to be installed by sentencepiece. This will only happen if
@ -106,11 +106,13 @@ set(SOURCES
ops/matmul-inl.h
ops/matvec-inl.h
ops/ops-inl.h
ops/sum-inl.h
paligemma/image.cc
paligemma/image.h
util/allocator.h
util/app.h
util/args.h
util/basics.h
util/test_util.h
util/threading.h
)

View File

@ -3,37 +3,33 @@ module(
version = "0.1.0",
)
bazel_dep(name = "rules_license", version = "0.0.7")
bazel_dep(name = "googletest", version = "1.14.0")
# Copied from Highway because Bazel does not load them transitively
bazel_dep(name = "bazel_skylib", version = "1.4.1")
bazel_dep(name = "abseil-cpp", version = "20240722.0")
bazel_dep(name = "bazel_skylib", version = "1.6.1")
bazel_dep(name = "googletest", version = "1.15.2")
bazel_dep(name = "highway", version = "1.1.0")
bazel_dep(name = "nlohmann_json", version = "3.11.3")
bazel_dep(name = "platforms", version = "0.0.10")
bazel_dep(name = "rules_cc", version = "0.0.9")
bazel_dep(name = "platforms", version = "0.0.7")
bazel_dep(name = "rules_license", version = "0.0.7")
bazel_dep(name = "google_benchmark", version = "1.8.5")
# Require a more recent version.
git_override(
module_name = "highway",
commit = "0ca297227a373710e76dd45e0ad4d68adb6928fe",
remote = "https://github.com/google/highway",
)
http_archive = use_repo_rule("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
http_archive(
name = "hwy",
urls = ["https://github.com/google/highway/archive/refs/tags/1.2.0.zip"],
integrity = "sha256-fbtKAGj5hhhBr5Bggtsrj4aIodC2OHb1njB8LGfom8A=", strip_prefix = "highway-1.2.0",
)
http_archive(
name = "nlohmann_json",
urls = ["https://github.com/nlohmann/json/archive/refs/tags/v3.11.3.zip"],
integrity = "sha256-BAIrBdgG61/3MCPCgLaGl9Erk+G3JnoLIqGjnsdXgGk=",
strip_prefix = "json-3.11.3",
)
http_archive(
name = "com_google_sentencepiece",
build_file = "@//bazel:sentencepiece.bazel",
patch_args = ["-p1"],
patches = ["@//bazel:sentencepiece.patch"],
sha256 = "8409b0126ebd62b256c685d5757150cf7fcb2b92a2f2b98efb3f38fc36719754",
strip_prefix = "sentencepiece-0.1.96",
urls = ["https://github.com/google/sentencepiece/archive/refs/tags/v0.1.96.zip"],
build_file = "@//bazel:sentencepiece.bazel",
patches = ["@//bazel:sentencepiece.patch"],
patch_args = ["-p1"],
)
# For sentencepiece
@ -56,17 +52,3 @@ cc_library(
"https://github.com/s-yata/darts-clone/archive/e40ce4627526985a7767444b6ed6893ab6ff8983.zip",
],
)
# ABSL on 2023-10-18
http_archive(
name = "com_google_absl",
sha256 = "f841f78243f179326f2a80b719f2887c38fe226d288ecdc46e2aa091e6aa43bc",
strip_prefix = "abseil-cpp-9687a8ea750bfcddf790372093245a1d041b21a3",
urls = ["https://github.com/abseil/abseil-cpp/archive/9687a8ea750bfcddf790372093245a1d041b21a3.tar.gz"],
)
# Benchmark
http_archive(
name = "benchmark",
urls = ["https://github.com/google/benchmark/archive/refs/tags/v1.8.2.tar.gz"],
integrity = "sha256-KqspgNA3YTf5adkoSPu2gharsHYzA0U0/IxlzE56DpM=",
strip_prefix = "benchmark-1.8.2",
)

View File

@ -42,7 +42,7 @@ cc_library(
"src/common.h",
],
deps = [
"@com_google_absl//absl/base",
"@abseil-cpp//absl/base",
],
)
@ -86,12 +86,12 @@ cc_library(
":common",
":sentencepiece_cc_proto",
":sentencepiece_model_cc_proto",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/memory",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:str_format",
"@abseil-cpp//absl/container:flat_hash_map",
"@abseil-cpp//absl/container:flat_hash_set",
"@abseil-cpp//absl/memory",
"@abseil-cpp//absl/status",
"@abseil-cpp//absl/strings",
"@abseil-cpp//absl/strings:str_format",
"@darts_clone",
],
)

View File

@ -33,7 +33,7 @@ cc_library(
],
hdrs = ["io.h"],
deps = [
"@hwy//:hwy",
"@highway//:hwy",
] + FILE_DEPS,
)
@ -43,8 +43,8 @@ cc_library(
hdrs = ["blob_store.h"],
deps = [
":io",
"@hwy//:hwy",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
@ -55,9 +55,9 @@ cc_library(
"shared.h",
],
deps = [
"@hwy//:hwy",
"@hwy//:stats",
"@hwy//hwy/contrib/sort:vqsort",
"@highway//:hwy",
"@highway//:stats",
"@highway//hwy/contrib/sort:vqsort",
],
)
@ -69,8 +69,8 @@ cc_test(
":distortion",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark", # Unpredictable1
"@highway//:hwy_test_util",
"@highway//:nanobenchmark", # Unpredictable1
],
)
@ -79,7 +79,7 @@ cc_library(
hdrs = ["shared.h"],
textual_hdrs = ["sfp-inl.h"],
deps = [
"@hwy//:hwy",
"@highway//:hwy",
],
)
@ -89,9 +89,9 @@ cc_library(
textual_hdrs = ["nuq-inl.h"],
deps = [
":sfp",
"//:allocator",
"@hwy//:hwy",
"@hwy//hwy/contrib/sort:vqsort",
"//:basics",
"@highway//:hwy",
"@highway//hwy/contrib/sort:vqsort",
],
)
@ -103,8 +103,8 @@ cc_library(
deps = [
":compress",
":distortion",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)
@ -122,9 +122,9 @@ cc_test(
":sfp",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)
@ -144,9 +144,9 @@ cc_test(
":sfp",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)
@ -164,11 +164,11 @@ cc_library(
":io",
":nuq",
":sfp",
"@hwy//:hwy",
"@hwy//:nanobenchmark",
"@hwy//:profiler",
"@hwy//:stats",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:profiler",
"@highway//:stats",
"@highway//:thread_pool",
],
)
@ -188,9 +188,9 @@ cc_test(
":test_util",
"@googletest//:gtest_main", # buildcleaner: keep
"//:test_util",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
@ -201,10 +201,10 @@ cc_library(
deps = [
":nuq",
":sfp",
"@hwy//:hwy",
"@hwy//:stats",
"@hwy//:thread_pool",
"@hwy//hwy/contrib/sort:vqsort",
"@highway//:hwy",
"@highway//:stats",
"@highway//:thread_pool",
"@highway//hwy/contrib/sort:vqsort",
],
)
@ -218,7 +218,7 @@ cc_binary(
"//:args",
"//:common",
"//:weights",
"@hwy//:hwy",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

View File

@ -24,7 +24,7 @@
#include <cmath> // lroundf, only if COMPRESS_STATS
#include "compression/blob_store.h"
#include "compression/compress.h"
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
@ -494,11 +494,6 @@ void Compress2(DF df, VF raw0, VF raw1, const PackedSpan<Packed>& packed,
Traits::Store2(df, raw0, raw1, packed, packed_ofs);
}
template <typename Packed>
constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}
namespace detail {
// Compile-time-only check that `DRaw` and `Packed` are compatible. This makes
@ -678,8 +673,8 @@ class Compressor {
template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
const float* HWY_RESTRICT weights) {
int num_weights = compressed->NumElements();
int num_compressed = compressed->NumElements();
size_t num_weights = compressed->NumElements();
size_t num_compressed = compressed->NumElements();
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
num_weights / (1000 * 1000));

View File

@ -22,7 +22,7 @@
#include <stdio.h>
#include "compression/shared.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "hwy/base.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_

View File

@ -16,8 +16,8 @@ cc_library(
"//third_party/absl/types:span",
"//compression:compress",
"//compression:io",
"@hwy//:hwy",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

View File

@ -32,6 +32,11 @@ namespace gcpp {
using BF16 = hwy::bfloat16_t;
template <typename Packed>
constexpr bool IsF32() {
return hwy::IsSame<hwy::RemoveCvRef<Packed>, float>();
}
// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
// inputs that combines the advantages of e4m3 and e5m2 into a single format.
// It supports seeking at a granularity of 1 and decoding to bf16/f32.

View File

@ -17,7 +17,7 @@ cc_binary(
"//:gemma_lib",
"//:threading",
"//:tokenizer",
"@hwy//:hwy",
"@hwy//:thread_pool",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

View File

@ -20,8 +20,9 @@
#include <cmath>
#include "ops/matmul.h" // MatMulEnv
#include "util/allocator.h" // RowVectorBatch
#include "compression/shared.h" // BF16
#include "ops/matmul.h" // MatMulEnv
#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h"
#include "hwy/base.h" // HWY_DASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -41,7 +42,7 @@ struct Activations {
RowVectorBatch<float> att_sums;
// Gated FFW
RowVectorBatch<hwy::bfloat16_t> bf_pre_ffw_rms_out;
RowVectorBatch<BF16> bf_pre_ffw_rms_out;
RowVectorBatch<float> C1;
RowVectorBatch<float> C2;
RowVectorBatch<float> ffw_out;
@ -106,7 +107,7 @@ struct Activations {
att_out = RowVectorBatch<float>(batch_size, kHeads * kQKVDim);
att_sums = RowVectorBatch<float>(batch_size, kModelDim);
bf_pre_ffw_rms_out = RowVectorBatch<hwy::bfloat16_t>(batch_size, kModelDim);
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(batch_size, kModelDim);
C1 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
C2 = RowVectorBatch<float>(batch_size, kFFHiddenDim);
ffw_out = RowVectorBatch<float>(batch_size, kModelDim);

View File

@ -118,8 +118,8 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
return CallForModel<float, FuncT, TArgs...>( //
model, std::forward<TArgs>(args)...);
case Type::kBF16:
return CallForModel<hwy::bfloat16_t, FuncT, TArgs...>(
model, std::forward<TArgs>(args)...);
return CallForModel<BF16, FuncT, TArgs...>(model,
std::forward<TArgs>(args)...);
case Type::kSFP:
return CallForModel<SfpStream, FuncT, TArgs...>(
model, std::forward<TArgs>(args)...);
@ -130,7 +130,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
#define GEMMA_FOREACH_WEIGHT(X, CONFIGT) \
X(CONFIGT, float) \
X(CONFIGT, hwy::bfloat16_t) \
X(CONFIGT, BF16) \
X(CONFIGT, SfpStream)
#define GEMMA_FOREACH_CONFIG_AND_WEIGHT(X) \
@ -205,7 +205,7 @@ decltype(auto) CallForModelAndWeight(Model model, Type weight,
GEMMA_DISPATCH_MODEL(MODEL, float, FUNC, ARGS); \
break; \
case Type::kBF16: \
GEMMA_DISPATCH_MODEL(MODEL, hwy::bfloat16_t, FUNC, ARGS); \
GEMMA_DISPATCH_MODEL(MODEL, BF16, FUNC, ARGS); \
break; \
case Type::kSFP: \
GEMMA_DISPATCH_MODEL(MODEL, SfpStream, FUNC, ARGS); \
@ -239,15 +239,15 @@ static GEMMA_CONSTEXPR_SQRT HWY_INLINE float Sqrt(float x) { return sqrtf(x); }
template <typename TConfig>
GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling() {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(TConfig::kModelDim))));
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(Sqrt(static_cast<float>(TConfig::kModelDim))));
}
static HWY_INLINE GEMMA_CONSTEXPR_EMBSCALING float EmbeddingScaling(
size_t model_dim) {
// Round to bf16 to match Gemma's Embedder, which casts before mul.
return hwy::ConvertScalarTo<float>(hwy::ConvertScalarTo<hwy::bfloat16_t>(
Sqrt(static_cast<float>(model_dim))));
return hwy::ConvertScalarTo<float>(
hwy::ConvertScalarTo<BF16>(Sqrt(static_cast<float>(model_dim))));
}
template <class TConfig>

View File

@ -22,7 +22,7 @@
#include <array>
#include "hwy/base.h" // hwy::bfloat16_t
#include "compression/shared.h" // BF16
namespace gcpp {
@ -40,7 +40,7 @@ static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
static constexpr size_t kTopK = GEMMA_TOPK;
static constexpr size_t kVocabSize = 256000;
using EmbedderInputT = hwy::bfloat16_t;
using EmbedderInputT = BF16;
enum class LayerAttentionType {
kGemma,

View File

@ -763,7 +763,7 @@ HWY_NOINLINE void EmbedToken(int token, size_t batch_idx, size_t pos,
EmbeddingScaling<TConfig>();
HWY_DASSERT(token >= 0);
HWY_DASSERT(token < kVocabSize);
HWY_DASSERT(token < static_cast<int>(kVocabSize));
const hn::ScalableTag<float> df;
DecompressAndZeroPad(
@ -1193,14 +1193,15 @@ SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
// Fast path for top-1 with no accept_token.
if (kTopK == 1 && !runtime_config.accept_token) {
return [](float* logits, size_t vocab_size) -> TokenAndProb {
return [](float* logits, size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample Top1");
return Top1OfSoftmax(logits, vocab_size);
};
}
// General case: Softmax with top-k sampling.
return [&runtime_config](float* logits, size_t vocab_size) -> TokenAndProb {
return [&runtime_config](float* logits,
size_t vocab_size) HWY_ATTR -> TokenAndProb {
PROFILER_ZONE("Gen.Sample general");
Softmax(logits, vocab_size);
const int token = SampleTopK<kTopK>(logits, vocab_size, *runtime_config.gen,

View File

@ -28,13 +28,13 @@
#include "gemma/kv_cache.h"
#include "gemma/tokenizer.h"
#include "paligemma/image.h"
#include "util/allocator.h"
#include "util/allocator.h" // RowVectorBatch
#include "util/basics.h" // TokenAndProb
#include "util/threading.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h" // hwy::bfloat16_t
namespace gcpp {
using PromptTokens = hwy::Span<const int>;

View File

@ -108,10 +108,10 @@ struct CompressedLayer {
// do not yet support smaller compressed types, or require at least bf16. When
// weights are f32, we also want such tensors to be f32.
// If weights are complex, this is also complex.
using WeightF32OrBF16 = hwy::If<
hwy::IsSame<Weight, std::complex<double>>(), std::complex<double>,
hwy::If<hwy::IsSame<Weight, double>(), double,
hwy::If<hwy::IsSame<Weight, float>(), float, hwy::bfloat16_t>>>;
using WeightF32OrBF16 =
hwy::If<hwy::IsSame<Weight, std::complex<double>>(), std::complex<double>,
hwy::If<hwy::IsSame<Weight, double>(), double,
hwy::If<IsF32<Weight>(), float, BF16>>>;
static constexpr size_t kHeads = TConfig::kHeads;
static constexpr size_t kKVHeads = TConfig::kKVHeads;
@ -363,9 +363,8 @@ struct CompressedWeights {
using Weight = typename TConfig::Weight;
using WeightF32OrBF16 = typename CompressedLayer<TConfig>::WeightF32OrBF16;
using WeightF32OrInputT =
hwy::If<hwy::IsSame<WeightF32OrBF16, hwy::bfloat16_t>(), EmbedderInputT,
WeightF32OrBF16>;
using WeightF32OrInputT = hwy::If<hwy::IsSame<WeightF32OrBF16, BF16>(),
EmbedderInputT, WeightF32OrBF16>;
MatPtrT<WeightF32OrInputT> embedder_input_embedding;
MatPtrT<WeightF32OrBF16> final_norm_scale;

View File

@ -840,7 +840,7 @@ class DotStats {
ASSERT_INSIDE(kPairwise, 4.5E-4, s_rels[kPairwise].GeometricMean(), 1.5E-2);
// Extremely high error on aarch64.
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 1250.f);
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_rels[kPairwise].Max(), 2E3f);
}
// Backward relative error, lower is better.

View File

@ -19,9 +19,10 @@
#include <stddef.h>
#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h" // PerClusterPools
#include "util/threading.h"
#include "hwy/aligned_allocator.h" // IWYU pragma: export
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/contrib/thread_pool/thread_pool.h" // IWYU pragma: export
#include "hwy/per_target.h"
namespace gcpp {

View File

@ -28,7 +28,7 @@
#include <type_traits> // std::enable_if_t
#include "compression/compress.h"
#include "util/allocator.h" // TokenAndProb
#include "util/basics.h" // TokenAndProb
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
@ -44,6 +44,7 @@
#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
#include "ops/sum-inl.h"
#include "hwy/contrib/algo/transform-inl.h"
#include "hwy/contrib/math/math-inl.h"
#include "hwy/profiler.h" // also uses SIMD
@ -507,183 +508,6 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
MulByConstAndAdd(c, x, out, size, size);
}
// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
// of f32 sums.
struct SumKernelDouble {
// Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
// to be `float` in order to have `Raw = double`. Note that if either type is
// smaller than `float`, we may demote the other type from `float` to `BF16`.
template <typename VT, typename WT>
using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
using State = double;
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const {
sum0 = hn::Add(sum0, w0);
sum1 = hn::Add(sum1, w1);
sum2 = hn::Add(sum2, w2);
sum3 = hn::Add(sum3, w3);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1,
VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
// Reduce to two f32 sums so we can promote them to four f64 vectors.
VF sum02, sum13;
if constexpr (HWY_NATIVE_DOT_BF16) {
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1);
const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1);
// Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
VF unused0 = hn::Zero(df);
VF unused1 = hn::Zero(df);
sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0);
sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1);
} else {
// If not native, the multiplication costs extra, so convert to f32.
// PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`.
const VF fe0 = hn::PromoteEvenTo(df, w0);
const VF fe1 = hn::PromoteEvenTo(df, w1);
const VF fe2 = hn::PromoteEvenTo(df, w2);
const VF fe3 = hn::PromoteEvenTo(df, w3);
const VF fo0 = hn::PromoteOddTo(df, w0);
const VF fo1 = hn::PromoteOddTo(df, w1);
const VF fo2 = hn::PromoteOddTo(df, w2);
const VF fo3 = hn::PromoteOddTo(df, w3);
const VF fe01 = hn::Add(fe0, fe1);
const VF fe23 = hn::Add(fe2, fe3);
const VF fo01 = hn::Add(fo0, fo1);
const VF fo23 = hn::Add(fo2, fo3);
sum02 = hn::Add(fe01, fe23);
sum13 = hn::Add(fo01, fo23);
}
const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, sum02);
const VS d1 = hn::PromoteUpperTo(ds, sum02);
const VS d2 = hn::PromoteLowerTo(ds, sum13);
const VS d3 = hn::PromoteUpperTo(ds, sum13);
sum0 = hn::Add(sum0, d0);
sum1 = hn::Add(sum1, d1);
sum2 = hn::Add(sum2, d2);
sum3 = hn::Add(sum3, d3);
}
// Raw = double
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
VR& comp0) const {
sum0 = hn::Add(sum0, w0);
}
// Raw = BF16
template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0,
VS& extra0) const {
const hn::Repartition<float, DRaw> df;
using VF = hn::Vec<decltype(df)>;
VF f0;
if constexpr (HWY_NATIVE_DOT_BF16) {
const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
f0 = hn::WidenMulPairwiseAdd(df, w0, k1);
} else {
const VF fe0 = hn::PromoteEvenTo(df, w0);
const VF fo0 = hn::PromoteOddTo(df, w0);
f0 = hn::Add(fe0, fo0);
}
const DS ds;
const VS d0 = hn::PromoteLowerTo(ds, f0);
const VS d1 = hn::PromoteUpperTo(ds, f0);
sum0 = hn::Add(sum0, d0);
extra0 = hn::Add(extra0, d1);
}
template <class DState, class VS = hn::Vec<DState>>
HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
VS& extra0, VS&, VS&, VS&) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
sum0 = hn::Add(sum0, sum1);
sum2 = hn::Add(sum2, sum3);
sum0 = hn::Add(sum0, extra0); // from Update1
sum0 = hn::Add(sum0, sum2);
return static_cast<float>(hn::ReduceSum(dd, sum0));
}
};
// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
// instead of FastTwoSums because the magnitude of the initial sum is not
// always greater than the next input, and this does actually change the e2e
// generation results. Note that Kahan summation differs in that it first adds
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
// and comp* here have shorter dependency chains.
//
// This is about as accurate as SumKernelDouble but slower, hence we only use this
// if f64 is not supported on this target.
struct SumKernelCascaded {
template <typename VT, typename WT>
using Raw = float;
using State = float;
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
VF& comp3) const {
VF serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, sum0, w0, serr0);
sum1 = TwoSums(df, sum1, w1, serr1);
sum2 = TwoSums(df, sum2, w2, serr2);
sum3 = TwoSums(df, sum3, w3, serr3);
comp0 = hn::Add(comp0, serr0);
comp1 = hn::Add(comp1, serr1);
comp2 = hn::Add(comp2, serr2);
comp3 = hn::Add(comp3, serr3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
VF serr0;
sum0 = TwoSums(df, sum0, w0, serr0);
comp0 = hn::Add(comp0, serr0);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
return ReduceCascadedSums(df, sum0, comp0);
}
};
using SumKernelDefault =
hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;
template <class D, typename VT>
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
const hn::Repartition<Raw, D> d_raw;
return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault());
}
// See below for a specialized version for top-1 sampling.
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos) {

217
ops/sum-inl.h Normal file
View File

@ -0,0 +1,217 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <stddef.h>
#include "hwy/base.h"
// Include guard for SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_SUM_TOGGLE
#endif
#include "compression/compress-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// f64 Add, called for f32 inputs promoted to f64. Runs at about half the speed
// of f32 sums.
struct SumKernelDouble {
  // Only `CompressTraits<float>` can `Decompress2` to `double`, so both have
  // to be `float` in order to have `Raw = double`. Note that if either type is
  // smaller than `float`, we may demote the other type from `float` to `BF16`.
  template <typename VT, typename WT>
  using Raw = hwy::If<IsF32<VT>() && IsF32<WT>(), double, BF16>;
  // Accumulators are always f64, regardless of `Raw`.
  using State = double;

  // Raw = double: plain f64 adds into four independent accumulators, which
  // shortens dependency chains. The unused VR/comp arguments exist to match
  // the kernel interface shared with compensated kernels.
  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
  HWY_INLINE void Update4(DRaw /*dd*/, const VR w0, const VR w1, const VR w2,
                          const VR w3, VR, VR, VR, VR, VR& sum0, VR& sum1,
                          VR& sum2, VR& sum3, VR&, VR&, VR&, VR&) const {
    sum0 = hn::Add(sum0, w0);
    sum1 = hn::Add(sum1, w1);
    sum2 = hn::Add(sum2, w2);
    sum3 = hn::Add(sum3, w3);
  }

  // Raw = BF16: first reduce the four BF16 vectors to two f32 sums, then
  // promote those into the four f64 accumulators.
  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
            class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
  HWY_INLINE void Update4(DRaw dr, const VR w0, const VR w1, const VR w2,
                          const VR w3, VR, VR, VR, VR, VS& sum0, VS& sum1,
                          VS& sum2, VS& sum3, VS&, VS&, VS&, VS&) const {
    const hn::Repartition<float, DRaw> df;
    using VF = hn::Vec<decltype(df)>;
    // Reduce to two f32 sums so we can promote them to four f64 vectors.
    VF sum02, sum13;
    if constexpr (HWY_NATIVE_DOT_BF16) {
      // Widening dot with a vector of 1.0 sums adjacent BF16 pairs into f32.
      const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
      const VF prod0 = hn::WidenMulPairwiseAdd(df, w0, k1);
      const VF prod1 = hn::WidenMulPairwiseAdd(df, w1, k1);
      // Fuse WidenMulPairwiseAdd plus Add into ReorderWidenMulAccumulate.
      VF unused0 = hn::Zero(df);
      VF unused1 = hn::Zero(df);
      sum02 = hn::ReorderWidenMulAccumulate(df, w2, k1, prod0, unused0);
      sum13 = hn::ReorderWidenMulAccumulate(df, w3, k1, prod1, unused1);
    } else {
      // If not native, the multiplication costs extra, so convert to f32.
      // PromoteEvenTo is cheaper than PromoteUpperTo especially on `SVE`.
      const VF fe0 = hn::PromoteEvenTo(df, w0);
      const VF fe1 = hn::PromoteEvenTo(df, w1);
      const VF fe2 = hn::PromoteEvenTo(df, w2);
      const VF fe3 = hn::PromoteEvenTo(df, w3);
      const VF fo0 = hn::PromoteOddTo(df, w0);
      const VF fo1 = hn::PromoteOddTo(df, w1);
      const VF fo2 = hn::PromoteOddTo(df, w2);
      const VF fo3 = hn::PromoteOddTo(df, w3);
      const VF fe01 = hn::Add(fe0, fe1);
      const VF fe23 = hn::Add(fe2, fe3);
      const VF fo01 = hn::Add(fo0, fo1);
      const VF fo23 = hn::Add(fo2, fo3);
      sum02 = hn::Add(fe01, fe23);
      sum13 = hn::Add(fo01, fo23);
    }
    const DS ds;
    // Each f32 vector promotes to two f64 vectors (lower and upper halves).
    const VS d0 = hn::PromoteLowerTo(ds, sum02);
    const VS d1 = hn::PromoteUpperTo(ds, sum02);
    const VS d2 = hn::PromoteLowerTo(ds, sum13);
    const VS d3 = hn::PromoteUpperTo(ds, sum13);
    sum0 = hn::Add(sum0, d0);
    sum1 = hn::Add(sum1, d1);
    sum2 = hn::Add(sum2, d2);
    sum3 = hn::Add(sum3, d3);
  }

  // Raw = double, single-vector tail. `v0` and `comp0` are unused: plain f64
  // addition requires no second operand nor compensation term.
  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_F64_D(DRaw)>
  HWY_INLINE void Update1(DRaw /*dd*/, const VR w0, const VR v0, VR& sum0,
                          VR& comp0) const {
    sum0 = hn::Add(sum0, w0);
  }

  // Raw = BF16, single-vector tail. The upper-half promotion goes into
  // `extra0`, which `Reduce` folds back in.
  template <class DRaw, class VR = hn::Vec<DRaw>, HWY_IF_BF16_D(DRaw),
            class DS = hn::Repartition<double, DRaw>, class VS = hn::Vec<DS>>
  HWY_INLINE void Update1(DRaw dr, const VR w0, VR, VS& sum0,
                          VS& extra0) const {
    const hn::Repartition<float, DRaw> df;
    using VF = hn::Vec<decltype(df)>;
    VF f0;
    if constexpr (HWY_NATIVE_DOT_BF16) {
      const VR k1 = hn::Set(dr, hwy::ConvertScalarTo<BF16>(1.0f));
      f0 = hn::WidenMulPairwiseAdd(df, w0, k1);
    } else {
      const VF fe0 = hn::PromoteEvenTo(df, w0);
      const VF fo0 = hn::PromoteOddTo(df, w0);
      f0 = hn::Add(fe0, fo0);
    }
    const DS ds;
    const VS d0 = hn::PromoteLowerTo(ds, f0);
    const VS d1 = hn::PromoteUpperTo(ds, f0);
    sum0 = hn::Add(sum0, d0);
    extra0 = hn::Add(extra0, d1);
  }

  // Folds the four accumulators (plus `extra0` from Update1) into a single
  // float. The last four VS& arguments are unused.
  template <class DState, class VS = hn::Vec<DState>>
  HWY_INLINE float Reduce(DState dd, VS& sum0, VS& sum1, VS& sum2, VS& sum3,
                          VS& extra0, VS&, VS&, VS&) const {
    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    sum0 = hn::Add(sum0, sum1);
    sum2 = hn::Add(sum2, sum3);
    sum0 = hn::Add(sum0, extra0);  // from Update1
    sum0 = hn::Add(sum0, sum2);
    return static_cast<float>(hn::ReduceSum(dd, sum0));
  }
};
// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
// instead of FastTwoSums because the magnitude of the initial sum is not
// always greater than the next input, and this does actually change the e2e
// generation results. Note that Kahan summation differs in that it first adds
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
// and comp* here have shorter dependency chains.
//
// This is about as accurate as SumKernelDouble but slower, hence we only use
// this if f64 is not supported on this target.
struct SumKernelCascaded {
  // Inputs are always decompressed to f32.
  template <typename VT, typename WT>
  using Raw = float;
  using State = float;

  // Adds w0..w3 to sum0..sum3 via TwoSums; the rounding error of each add is
  // accumulated into the matching comp0..comp3 compensation term. The unused
  // VF arguments exist to match the shared kernel interface.
  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
  HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
                          const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
                          VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
                          VF& comp3) const {
    VF serr0, serr1, serr2, serr3;
    sum0 = TwoSums(df, sum0, w0, serr0);
    sum1 = TwoSums(df, sum1, w1, serr1);
    sum2 = TwoSums(df, sum2, w2, serr2);
    sum3 = TwoSums(df, sum3, w3, serr3);
    comp0 = hn::Add(comp0, serr0);
    comp1 = hn::Add(comp1, serr1);
    comp2 = hn::Add(comp2, serr2);
    comp3 = hn::Add(comp3, serr3);
  }

  // Single-vector tail; `v0` is unused.
  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
  HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
                          VF& comp0) const {
    VF serr0;
    sum0 = TwoSums(df, sum0, w0, serr0);
    comp0 = hn::Add(comp0, serr0);
  }

  // Merges the four compensated (sum, comp) pairs and reduces across lanes.
  template <class DF, class VF = hn::Vec<DF>>
  HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
                          VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
    // Reduction tree: sum of all accumulators by pairs, then across lanes.
    AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
    AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
    AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
    return ReduceCascadedSums(df, sum0, comp0);
  }
};
// Default summation kernel: f64 accumulation when the target supports it,
// otherwise cascaded (compensated) f32 summation.
using SumKernelDefault =
    hwy::If<HWY_HAVE_FLOAT64, SumKernelDouble, SumKernelCascaded>;

// Returns the sum of `vec[0, num)` as float. Inputs are decompressed and
// accumulated in `Raw` precision (f64 if available on this target, else f32).
// NOTE: `d` itself is unused; only its type `D` determines the vector shape
// via `Repartition`.
template <class D, typename VT>
HWY_INLINE float Sum(D d, const VT* HWY_RESTRICT vec, size_t num) {
  using Raw = hwy::If<HWY_HAVE_FLOAT64, double, float>;
  const hn::Repartition<Raw, D> d_raw;
  return DecompressAndCall(d_raw, MakeSpan(vec, num), SumKernelDefault());
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif // NOLINT

View File

@ -11,7 +11,7 @@ cc_library(
name = "image",
srcs = ["image.cc"],
hdrs = ["image.h"],
deps = ["@hwy//:hwy"],
deps = ["@highway//:hwy"],
)
cc_test(
@ -39,7 +39,7 @@ cc_test(
"//:common",
"//:gemma_lib",
"//:tokenizer",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
],
)

View File

@ -19,30 +19,11 @@
#include <stddef.h>
#include <stdint.h>
#include "hwy/aligned_allocator.h"
#include "hwy/aligned_allocator.h" // IWYU pragma: export
#include "hwy/base.h"
#if HWY_IS_MSAN
#include <sanitizer/msan_interface.h>
#endif
namespace gcpp {
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
#if HWY_IS_MSAN
__msan_check_mem_is_initialized(ptr, size);
#else
(void)ptr;
(void)size;
#endif
}
// Shared between gemma.h and ops-inl.h.
struct TokenAndProb {
int token;
float prob;
};
using ByteStorageT = hwy::AlignedFreeUniquePtr<uint8_t[]>;
template <typename T>

View File

@ -82,6 +82,9 @@ class AppArgs : public ArgsBase<AppArgs> {
visitor(max_threads, "num_threads", size_t{0},
"Maximum number of threads to use; default 0 = unlimited.", 2);
visitor(pin, "pin", -1, "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
// These can be used to partition CPU sockets/packages and their
// clusters/CCXs across several program instances. The default is to use
// all available resources.
visitor(skip_packages, "skip_packages", size_t{0},
"Index of the first socket to use; default 0 = unlimited.", 2);
visitor(max_packages, "max_packages", size_t{0},

49
util/basics.h Normal file
View File

@ -0,0 +1,49 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_
// IWYU pragma: begin_exports
#include <stddef.h>
#include <stdint.h>
#include "hwy/base.h" // HWY_IS_MSAN
// IWYU pragma: end_exports
#if HWY_IS_MSAN
#include <sanitizer/msan_interface.h>
#endif
namespace gcpp {
// Under MemorySanitizer, asserts that all `size` bytes at `ptr` are
// initialized (aborts with a report otherwise); a no-op in all other builds.
static inline void MaybeCheckInitialized(const void* ptr, size_t size) {
#if HWY_IS_MSAN
  __msan_check_mem_is_initialized(ptr, size);
#else
  // Silence unused-parameter warnings in non-MSan builds.
  (void)ptr;
  (void)size;
#endif
}
// Pairs a token with a probability. Shared between gemma.h and ops-inl.h.
struct TokenAndProb {
  int token;   // token id
  float prob;  // NOTE(review): presumably the probability assigned to
               // `token`; confirm whether this is a probability or log-prob.
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_BASICS_H_

View File

@ -381,7 +381,7 @@ class BoundedTopology {
LPS enabled_lps; // LPs not disabled via OS, taskset, or numactl.
bool missing_cluster = false;
if (HWY_LIKELY(have_threading_support)) {
if (HWY_LIKELY(have_threading_support && !topology_.packages.empty())) {
(void)GetThreadAffinity(enabled_lps); // failure = all disabled
// No effect if topology is unknown or `enabled_lps` is empty.