mirror of https://github.com/google/gemma.cpp.git
Major refactor of allocator/args:

- Use the new ThreadingContext2 instead of monostate/init in each frontend.
- Add ThreadingArgs (replaces AppArgs).
- backprop: use the Packed() accessor and MakePacked factory, plus row-based access to allow for stride.
- compress_weights: remove; moving to a py-only exporter instead.
- Move MatPtr to mat.h and revise the interface: generic MatOwner, rename accessors to Packed*, support stride/row accessors, fix RowPtr stride.
- Add TypeBits(Type).
- Move GenerateMat to test_util-inl for sharing between the matmul test and benchmark.
- Move internal init to gemma.cc to avoid duplication.
- Rename GemmaEnv model_ to gemma_ to disambiguate from the upcoming ModelStorage.
- Remove --compressed_weights; use --weights instead.
- tensor_index: add ExtentsFromInfo and TensorIndexLLM/Img.
- Allocator: use a normal unique_ptr for AllocBytes so users can call it directly.
- threading: use -> because AlignedPtr no longer assumes arrays.

PiperOrigin-RevId: 745918637
This commit is contained in:
parent bef91a3f03
commit 8532da47f7

Changed: BUILD.bazel (220 lines)
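The recurring pattern in the diffs below is the new util/mat.h interface: tensors are created through the MakePacked factory instead of the MatStorageT constructor, raw pointers come from the Packed() accessor instead of data(), and ZeroInit/CopyMat become free functions. A minimal sketch of the new style, with names taken from the diff below (the exact signatures in util/mat.h may differ):

    // Sketch only: illustrates the factory + accessor style used in this commit.
    auto logits = MakePacked<float>("logits", /*rows=*/seq_len, /*cols=*/vocab_size);
    ZeroInit(logits);                              // was logits.ZeroInit()
    float* HWY_RESTRICT packed = logits.Packed();  // was logits.data()
    CopyMat(logits, logits_copy);                  // was memcpy(dst.data(), src.data(), ...)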
@@ -19,7 +19,10 @@ license(
 # Dual-licensed Apache 2 and 3-clause BSD.
 licenses(["notice"])
 
-exports_files(["LICENSE"])
+exports_files([
+    "LICENSE",
+    ".github/workflows/build.yml",
+])
 
 cc_library(
     name = "basics",
@@ -29,6 +32,16 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "args",
+    hdrs = ["util/args.h"],
+    deps = [
+        ":basics",
+        "//compression:io",  # Path
+        "@highway//:hwy",
+    ],
+)
+
 # Split from :threading to break a circular dependency with :allocator.
 cc_library(
     name = "topology",
@@ -59,6 +72,7 @@ cc_library(
     hdrs = ["util/threading.h"],
     deps = [
         ":allocator",
+        ":args",
         ":basics",
         ":topology",
         # Placeholder for container detection, do not remove
@@ -68,14 +82,26 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "threading_context",
+    srcs = ["util/threading_context.cc"],
+    hdrs = ["util/threading_context.h"],
+    deps = [
+        ":allocator",
+        ":args",
+        ":basics",
+        ":threading",
+        ":topology",
+    ],
+)
+
 cc_test(
     name = "threading_test",
     srcs = ["util/threading_test.cc"],
     deps = [
-        ":allocator",
         ":basics",
-        ":threading",
+        ":threading_context",
-        "@googletest//:gtest_main",
+        "@googletest//:gtest_main",  # buildcleaner: keep
         "@highway//:auto_tune",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
@@ -97,6 +123,65 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "common",
+    srcs = [
+        "gemma/common.cc",
+        "gemma/configs.cc",
+        "gemma/tensor_index.cc",
+    ],
+    hdrs = [
+        "gemma/common.h",
+        "gemma/configs.h",
+        "gemma/tensor_index.h",
+    ],
+    deps = [
+        ":basics",
+        "//compression:fields",
+        "//compression:sfp",
+        "@highway//:hwy",  # base.h
+    ],
+)
+
+cc_test(
+    name = "configs_test",
+    srcs = ["gemma/configs_test.cc"],
+    deps = [
+        ":common",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "@highway//:hwy",
+    ],
+)
+
+cc_test(
+    name = "tensor_index_test",
+    srcs = ["gemma/tensor_index_test.cc"],
+    deps = [
+        ":basics",
+        ":common",
+        ":weights",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "//compression:compress",
+        "@highway//:hwy",
+    ],
+)
+
+cc_library(
+    name = "mat",
+    srcs = ["util/mat.cc"],
+    hdrs = ["util/mat.h"],
+    deps = [
+        ":allocator",
+        ":basics",
+        ":common",
+        ":threading_context",
+        "//compression:fields",
+        "//compression:sfp",
+        "@highway//:hwy",
+        "@highway//:profiler",
+    ],
+)
+
 # For building all tests in one command, so we can test several.
 test_suite(
     name = "ops_tests",
@@ -123,8 +208,9 @@ cc_library(
     deps = [
         ":allocator",
         ":basics",
+        ":mat",
         ":threading",
-        ":topology",
+        ":threading_context",
         "//compression:compress",
         "@highway//:algo",
         "@highway//:bit_set",
@@ -148,10 +234,9 @@ cc_test(
     tags = ["ops_tests"],
     deps = [
         ":allocator",
-        ":app",
         ":ops",
         ":test_util",
-        ":threading",
+        ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
         "//compression:test_util",
@@ -174,13 +259,13 @@ cc_test(
     tags = ["ops_tests"],
     deps = [
         ":allocator",
-        ":app",
+        ":basics",
         ":common",
+        ":mat",
         ":ops",
         ":test_util",
-        ":threading",
+        ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
-        "//compression:compress",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:nanobenchmark",  #buildcleaner: keep
@@ -196,6 +281,7 @@ cc_test(
     # for test_suite.
     tags = ["ops_tests"],
     deps = [
+        ":mat",
         ":ops",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
@@ -214,12 +300,13 @@ cc_test(
     # for test_suite.
     tags = ["ops_tests"],
    deps = [
-        ":allocator",
         ":basics",
+        ":mat",
         ":ops",
-        ":threading",
+        ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
+        "//compression:test_util",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:thread_pool",
@@ -238,12 +325,12 @@ cc_test(
         "ops_tests",  # for test_suite.
     ],
     deps = [
-        ":allocator",
         ":basics",
         ":ops",
-        ":threading",
+        ":threading_context",
         "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:compress",
+        "//compression:test_util",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:nanobenchmark",
@@ -252,55 +339,13 @@ cc_test(
     ],
 )
 
-cc_library(
-    name = "common",
-    srcs = [
-        "gemma/common.cc",
-        "gemma/configs.cc",
-        "gemma/tensor_index.cc",
-    ],
-    hdrs = [
-        "gemma/common.h",
-        "gemma/configs.h",
-        "gemma/tensor_index.h",
-    ],
-    deps = [
-        ":basics",
-        "//compression:fields",
-        "//compression:sfp",
-        "@highway//:hwy",  # base.h
-    ],
-)
-
-cc_test(
-    name = "configs_test",
-    srcs = ["gemma/configs_test.cc"],
-    deps = [
-        ":common",
-        "@googletest//:gtest_main",
-        "@highway//:hwy",
-    ],
-)
-
-cc_test(
-    name = "tensor_index_test",
-    srcs = ["gemma/tensor_index_test.cc"],
-    deps = [
-        ":basics",
-        ":common",
-        ":weights",
-        "@googletest//:gtest_main",
-        "//compression:compress",
-        "@highway//:hwy",
-    ],
-)
-
 cc_library(
     name = "weights",
     srcs = ["gemma/weights.cc"],
     hdrs = ["gemma/weights.h"],
     deps = [
         ":common",
+        ":mat",
         "//compression:blob_store",
         "//compression:compress",
         "//compression:io",
@@ -361,16 +406,17 @@ cc_library(
         ":basics",
         ":common",
         ":ops",
+        ":mat",
         ":tokenizer",
         ":kv_cache",
         ":weights",
         ":threading",
-        "//compression:compress",
+        ":threading_context",
+        # Placeholder for internal dep, do not remove.,
         "//compression:io",
         "//compression:sfp",
         "//paligemma:image",
         "@highway//:hwy",
-        "@highway//:bit_set",
         "@highway//:nanobenchmark",  # timer
         "@highway//:profiler",
         "@highway//:thread_pool",
@@ -390,25 +436,14 @@ cc_library(
 )
 
 cc_library(
-    name = "args",
+    name = "gemma_args",
-    hdrs = ["util/args.h"],
+    hdrs = ["gemma/gemma_args.h"],
-    deps = [
-        ":basics",
-        "//compression:io",
-        "@highway//:hwy",
-    ],
-)
-
-cc_library(
-    name = "app",
-    hdrs = ["util/app.h"],
     deps = [
         ":args",
         ":basics",
         ":common",
         ":gemma_lib",
         ":ops",
-        ":threading",
         "//compression:io",
         "//compression:sfp",
         "@highway//:hwy",
@@ -420,20 +455,15 @@ cc_library(
     srcs = ["evals/benchmark_helper.cc"],
     hdrs = ["evals/benchmark_helper.h"],
     deps = [
-        ":app",
-        ":args",
-        ":common",
         ":cross_entropy",
+        ":gemma_args",
         ":gemma_lib",
-        ":kv_cache",
         ":ops",
-        ":threading",
+        ":threading_context",
-        # Placeholder for internal dep, do not remove.,
         "@google_benchmark//:benchmark",
         "//compression:compress",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
-        "@highway//:topology",
     ],
 )
 
@@ -451,7 +481,7 @@ cc_test(
         ":benchmark_helper",
         ":common",
         ":gemma_lib",
-        "@googletest//:gtest_main",
+        "@googletest//:gtest_main",  # buildcleaner: keep
         "@highway//:hwy",
         "@highway//:hwy_test_util",
     ],
@@ -470,8 +500,7 @@ cc_test(
         ":benchmark_helper",
         ":common",
         ":gemma_lib",
-        ":tokenizer",
-        "@googletest//:gtest_main",
+        "@googletest//:gtest_main",  # buildcleaner: keep
         "@highway//:hwy",
         "@highway//:hwy_test_util",
     ],
@@ -481,14 +510,13 @@ cc_binary(
     name = "gemma",
     srcs = ["gemma/run.cc"],
     deps = [
-        ":app",
         ":args",
         ":benchmark_helper",
         ":common",
+        ":gemma_args",
         ":gemma_lib",
         ":ops",
-        ":threading",
+        ":threading_context",
-        # Placeholder for internal dep, do not remove.,
         "//compression:sfp",
         "//paligemma:image",
         "@highway//:hwy",
@@ -594,10 +622,10 @@ cc_library(
     deps = [
         ":allocator",
         ":common",
+        ":mat",
         ":ops",
         ":prompt",
         ":weights",
-        "//compression:compress",
         "@highway//:dot",
         "@highway//:hwy",  # base.h
         "@highway//:thread_pool",
@@ -614,9 +642,9 @@ cc_library(
     ],
     deps = [
         ":common",
+        ":mat",
         ":prompt",
         ":weights",
-        "//compression:compress",
         "@highway//:hwy",
     ],
 )
@@ -631,11 +659,11 @@ cc_test(
     deps = [
         ":backprop_scalar",
         ":common",
+        ":mat",
         ":prompt",
         ":sampler",
         ":weights",
-        "@googletest//:gtest_main",
+        "@googletest//:gtest_main",  # buildcleaner: keep
-        "//compression:compress",
         "@highway//:thread_pool",
     ],
 )
@@ -652,17 +680,16 @@ cc_test(
         "mem": "28g",
     },
     deps = [
-        ":allocator",
         ":backprop",
         ":backprop_scalar",
         ":common",
+        ":mat",
         ":ops",
         ":prompt",
         ":sampler",
-        ":threading",
+        ":threading_context",
         ":weights",
-        "@googletest//:gtest_main",
+        "@googletest//:gtest_main",  # buildcleaner: keep
-        "//compression:compress",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:thread_pool",
@@ -676,6 +703,7 @@ cc_library(
     deps = [
        ":allocator",
         ":common",
+        ":mat",
         ":weights",
         "//compression:compress",
         "@highway//:hwy",
@@ -685,9 +713,7 @@ cc_library(
 
 cc_test(
     name = "optimize_test",
-    srcs = [
+    srcs = ["backprop/optimize_test.cc"],
-        "backprop/optimize_test.cc",
-    ],
     exec_properties = {
         # Avoid linker OOMs when building with sanitizer instrumentation.
         "mem": "28g",
@@ -704,7 +730,7 @@ cc_test(
         ":sampler",
         ":threading",
         ":weights",
-        "@googletest//:gtest_main",
+        "@googletest//:gtest_main",  # buildcleaner: keep
         "//compression:sfp",
         "@highway//:thread_pool",
     ],
@@ -75,6 +75,7 @@ set(SOURCES
   gemma/common.h
   gemma/configs.cc
   gemma/configs.h
+  gemma/gemma_args.h
   gemma/gemma-inl.h
   gemma/gemma.cc
   gemma/gemma.h
@@ -102,15 +103,17 @@ set(SOURCES
   paligemma/image.h
   util/allocator.cc
   util/allocator.h
-  util/app.h
-  util/args.h
   util/basics.h
+  util/mat.cc
+  util/mat.h
   util/test_util.h
   util/threading.cc
   util/threading.h
+  util/threading_context.cc
+  util/threading_context.h
   util/topology.cc
   util/topology.h
 )
 
 if(NOT CMAKE_BUILD_TYPE)
   set(CMAKE_BUILD_TYPE "Release")
@@ -197,8 +200,5 @@ endif() # GEMMA_ENABLE_TESTS
 
 ## Tools
 
-add_executable(compress_weights compression/compress_weights.cc)
-target_link_libraries(compress_weights libgemma hwy hwy_contrib)
-
 add_executable(migrate_weights compression/migrate_weights.cc)
 target_link_libraries(migrate_weights libgemma hwy hwy_contrib)
@@ -20,24 +20,30 @@
 
 #include <vector>
 
-#include "compression/compress.h"  // MatStorageT
 #include "gemma/configs.h"  // ModelConfig
+#include "util/mat.h"  // MatStorageT
 
 namespace gcpp {
 
 template <typename T>
 struct ForwardLayer {
   ForwardLayer(const LayerConfig& config, size_t seq_len)
-      : input("input", seq_len, config.model_dim),
-        pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim),
-        qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim),
-        att("att", seq_len * config.heads, seq_len),
-        att_out("att_out", seq_len * config.heads, config.qkv_dim),
-        att_post1("att_post1", seq_len, config.model_dim),
-        attention_out("attention_out", seq_len, config.model_dim),
-        bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim),
-        ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2),
-        ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim),
+      : input(MakePacked<T>("input", seq_len, config.model_dim)),
+        pre_att_rms_out(
+            MakePacked<T>("pre_att_rms_out", seq_len, config.model_dim)),
+        qkv(MakePacked<T>("qkv", seq_len * (config.heads + 2), config.qkv_dim)),
+        att(MakePacked<T>("att", seq_len * config.heads, seq_len)),
+        att_out(
+            MakePacked<T>("att_out", seq_len * config.heads, config.qkv_dim)),
+        att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)),
+        attention_out(
+            MakePacked<T>("attention_out", seq_len, config.model_dim)),
+        bf_pre_ffw_rms_out(
+            MakePacked<T>("bf_preFF_rms_out", seq_len, config.model_dim)),
+        ffw_hidden(
+            MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)),
+        ffw_hidden_gated(
+            MakePacked<T>("ffw_hidden_gated", seq_len, config.ff_hidden_dim)),
         layer_config(config) {}
 
   MatStorageT<T> input;
@@ -56,12 +62,12 @@ struct ForwardLayer {
 template <typename T>
 struct ForwardPass {
   ForwardPass(const ModelConfig& config)
-      : final_layer_output("final_layer_output", config.seq_len,
-                           config.model_dim),
-        final_norm_output("final_norm_output", config.seq_len,
-                          config.model_dim),
-        logits("logits", config.seq_len, config.vocab_size),
-        probs("probs", config.seq_len, config.vocab_size),
+      : final_layer_output(
+            MakePacked<T>("fin_layer_out", config.seq_len, config.model_dim)),
+        final_norm_output(
+            MakePacked<T>("fin_norm_out", config.seq_len, config.model_dim)),
+        logits(MakePacked<T>("logits", config.seq_len, config.vocab_size)),
+        probs(MakePacked<T>("probs", config.seq_len, config.vocab_size)),
         weights_config(config) {
     for (const auto& layer_config : config.layer_configs) {
       layers.emplace_back(layer_config, config.seq_len);
@@ -128,7 +128,7 @@ static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
                  HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
 }
 
-static HWY_NOINLINE void RMSNormVJP(
+static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormVJP(
     const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
     const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
     float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
@@ -153,10 +153,9 @@ static HWY_NOINLINE void RMSNormVJP(
   }
 }
 
-static HWY_NOINLINE void InputEmbeddingVJP(
-    const float* weights, const std::vector<int>& prompt,
-    const float scaling, const float* HWY_RESTRICT v,
-    float* HWY_RESTRICT grad, size_t model_dim) {
+static HWY_NOINLINE HWY_MAYBE_UNUSED void InputEmbeddingVJP(
+    const float* weights, const std::vector<int>& prompt, const float scaling,
+    const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) {
   HWY_ASSERT(!prompt.empty());
   for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
     int token = prompt[pos];
@@ -182,17 +181,18 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
       static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
   HWY_ASSERT(num_tokens <= seq_len);
 
-  MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
+  MatMulVJP(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(),
             next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
-            grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);
+            grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), pool);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     const size_t hidden_offset = pos * ff_hidden_dim * 2;
-    const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
+    const float* HWY_RESTRICT f_out =
+        forward.ffw_hidden.Packed() + hidden_offset;
     const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
     const float* HWY_RESTRICT b_out_gated =
-        backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
+        backward.ffw_hidden_gated.Packed() + pos * ff_hidden_dim;
-    float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
+    float* HWY_RESTRICT b_out = backward.ffw_hidden.Packed() + hidden_offset;
     float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag<float>;
@@ -206,38 +206,39 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
     }
   }
 
-  MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
-            backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
-            num_tokens, grad.gating_einsum_w.data(),
-            backward.bf_pre_ffw_rms_out.data(), pool);
-  RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
-             backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
-             grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
+  MatMulVJP(weights.gating_einsum_w.Packed(),
+            forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
+            model_dim, ff_hidden_dim * 2, num_tokens,
+            grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(),
             pool);
+  RMSNormVJP(
+      weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
+      backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens,
+      grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     AddFrom(next_layer_grad + pos * model_dim,
-            backward.attention_out.data() + pos * model_dim, model_dim);
+            backward.attention_out.Packed() + pos * model_dim, model_dim);
   }
 
-  backward.qkv.ZeroInit();
+  ZeroInit(backward.qkv);
 
-  MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
-                     backward.attention_out.data(), heads, qkv_dim, model_dim,
-                     num_tokens, grad.attn_vec_einsum_w.data(),
-                     backward.att_out.data(), pool);
+  MultiHeadMatMulVJP(
+      weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
+      backward.attention_out.Packed(), heads, qkv_dim, model_dim, num_tokens,
+      grad.attn_vec_einsum_w.Packed(), backward.att_out.Packed(), pool);
 
   for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
       const size_t aoffset = head * seq_len + pos * heads * seq_len;
-      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
+      const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
       const float* HWY_RESTRICT b_att_out =
-          backward.att_out.data() + (pos * heads + head) * qkv_dim;
+          backward.att_out.Packed() + (pos * heads + head) * qkv_dim;
-      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
+      float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
         const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
-        const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
+        const float* HWY_RESTRICT f_v2 = forward.qkv.Packed() + v2offs;
-        float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
+        float* HWY_RESTRICT b_v2 = backward.qkv.Packed() + v2offs;
         b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
         MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
       }
@@ -247,8 +248,8 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
   for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
       const size_t aoffset = head * seq_len + pos * heads * seq_len;
-      const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
+      const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
-      float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
+      float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
       SoftmaxVJP(f_head_att, b_head_att, pos + 1);
     }
   }
@@ -257,13 +258,13 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
     for (size_t pos = 0; pos < num_tokens; ++pos) {
      const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
      const size_t aoffs = head * seq_len + pos * heads * seq_len;
-      const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
+      const float* HWY_RESTRICT f_q = forward.qkv.Packed() + qoffs;
-      const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
+      const float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffs;
-      float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
+      float* HWY_RESTRICT b_q = backward.qkv.Packed() + qoffs;
       for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
         const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
-        const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
+        const float* HWY_RESTRICT f_k2 = forward.qkv.Packed() + k2offs;
-        float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
+        float* HWY_RESTRICT b_k2 = backward.qkv.Packed() + k2offs;
         MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
         MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
       }
@@ -272,28 +273,30 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
 
   for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
     float* HWY_RESTRICT b_kv =
-        backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
+        backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim;
     Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
   }
 
   for (size_t head = 0; head < heads; ++head) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
       float* HWY_RESTRICT b_q =
-          backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
+          backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim;
       MulByConst(query_scale, b_q, qkv_dim);
       Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
     }
   }
 
-  MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
-            backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
-            grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
-  RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
-             backward.pre_att_rms_out.data(), model_dim, num_tokens,
-             grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
+  MatMulVJP(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
+            backward.qkv.Packed(), model_dim, (heads + 2) * qkv_dim, num_tokens,
+            grad.qkv_einsum_w.Packed(), backward.pre_att_rms_out.Packed(),
+            pool);
+  RMSNormVJP(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
+             backward.pre_att_rms_out.Packed(), model_dim, num_tokens,
+             grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
+             pool);
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(backward.attention_out.data() + pos * model_dim,
+    AddFrom(backward.attention_out.Packed() + pos * model_dim,
-            backward.input.data() + pos * model_dim, model_dim);
+            backward.input.Packed() + pos * model_dim, model_dim);
   }
 }
 
@@ -353,47 +356,48 @@ void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
   HWY_DASSERT(prompt.context_size < prompt.tokens.size());
   const size_t num_tokens = prompt.tokens.size() - 1;
 
-  CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
+  CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
                        kVocabSize);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
-               backward.logits.data() + pos * kVocabSize,
-               kVocabSize);
+    SoftmaxVJP(forward.probs.Packed() + pos * kVocabSize,
+               backward.logits.Packed() + pos * kVocabSize, kVocabSize);
   }
 
   if (config.final_cap > 0.0f) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
-                 backward.logits.data() + pos * kVocabSize, kVocabSize);
+      SoftcapVJP(config.final_cap, forward.logits.Packed() + pos * kVocabSize,
+                 backward.logits.Packed() + pos * kVocabSize, kVocabSize);
     }
   }
 
-  MatMulVJP(weights.embedder_input_embedding.data(),
-            forward.final_norm_output.data(), backward.logits.data(), model_dim,
-            kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
-            backward.final_norm_output.data(), pool);
+  MatMulVJP(weights.embedder_input_embedding.Packed(),
+            forward.final_norm_output.Packed(), backward.logits.Packed(),
+            model_dim, kVocabSize, num_tokens,
+            grad.embedder_input_embedding.Packed(),
+            backward.final_norm_output.Packed(), pool);
 
-  RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
-             backward.final_norm_output.data(), model_dim, num_tokens,
-             grad.final_norm_scale.data(), backward.final_layer_output.data(),
-             pool);
+  RMSNormVJP(weights.final_norm_scale.Packed(),
+             forward.final_layer_output.Packed(),
+             backward.final_norm_output.Packed(), model_dim, num_tokens,
+             grad.final_norm_scale.Packed(),
+             backward.final_layer_output.Packed(), pool);
 
   for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
     auto layer_config = config.layer_configs[layer];
     // TODO(szabadka) Implement Griffin layer vjp.
     HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
     float* next_layer_grad = layer + 1 < kLayers
-                                 ? backward.layers[layer + 1].input.data()
+                                 ? backward.layers[layer + 1].input.Packed()
-                                 : backward.final_layer_output.data();
+                                 : backward.final_layer_output.Packed();
     LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
              num_tokens, *grad.GetLayer(layer), backward.layers[layer],
             inv_timescale, pool);
   }
 
-  InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
-                    kEmbScaling, backward.layers[0].input.data(),
-                    grad.embedder_input_embedding.data(), model_dim);
+  InputEmbeddingVJP(weights.embedder_input_embedding.Packed(), prompt.tokens,
+                    kEmbScaling, backward.layers[0].input.Packed(),
+                    grad.embedder_input_embedding.Packed(), model_dim);
 }
 
 // NOLINTNEXTLINE(google-readability-namespace-comments)
@@ -17,9 +17,8 @@
 
 #include "backprop/activations.h"
 #include "backprop/prompt.h"
-#include "gemma/common.h"
 #include "gemma/weights.h"
-#include "util/allocator.h"
+#include "util/mat.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 // Compiles this file for multiple architectures via "foreach_target.h", to
@@ -19,7 +19,7 @@
 #include "backprop/activations.h"
 #include "backprop/prompt.h"
 #include "gemma/weights.h"
-#include "util/allocator.h"
+#include "util/mat.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
@@ -211,62 +211,65 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
   const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
   const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
 
-  MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
-             grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
-             kFFHiddenDim, num_tokens);
+  MatMulVJPT(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), dy,
+             grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(),
+             model_dim, kFFHiddenDim, num_tokens);
 
-  GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
-               backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
+  GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(),
+               backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens);
 
-  MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
-             backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
-             backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
+  MatMulVJPT(weights.gating_einsum_w.Packed(),
+             forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
+             grad.gating_einsum_w.Packed(),
+             backward.bf_pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
              num_tokens);
 
-  RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
-              backward.bf_pre_ffw_rms_out.data(),
-              grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
-              model_dim, num_tokens);
+  RMSNormVJPT(
+      weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
+      backward.bf_pre_ffw_rms_out.Packed(), grad.pre_ffw_norm_scale.Packed(),
+      backward.attention_out.Packed(), model_dim, num_tokens);
 
-  AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
+  AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim);
 
-  MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
-                      backward.attention_out.data(),
-                      grad.attn_vec_einsum_w.data(), backward.att_out.data(),
-                      kHeads, model_dim, qkv_dim, num_tokens);
+  MultiHeadMatMulVJPT(
+      weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
+      backward.attention_out.Packed(), grad.attn_vec_einsum_w.Packed(),
+      backward.att_out.Packed(), kHeads, model_dim, qkv_dim, num_tokens);
 
-  MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
-                    backward.att_out.data(), backward.qkv.data(),
-                    backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
+  MixByAttentionVJP(forward.qkv.Packed(), forward.att.Packed(),
+                    backward.att_out.Packed(), backward.qkv.Packed(),
+                    backward.att.Packed(), num_tokens, kHeads, qkv_dim,
+                    seq_len);
 
-  MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
-                    seq_len);
+  MaskedSoftmaxVJPT(forward.att.Packed(), backward.att.Packed(), num_tokens,
+                    kHeads, seq_len);
 
-  MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
-                     backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
+  MaskedAttentionVJP(forward.qkv.Packed(), backward.att.Packed(),
+                     backward.qkv.Packed(), num_tokens, kHeads, qkv_dim,
+                     seq_len);
 
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
+    T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
     MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
   }
 
   for (int pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
+    T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
     for (size_t h = 0; h <= kHeads; ++h) {
       Rope(qkv + h * qkv_dim, qkv_dim, -pos);
     }
   }
 
-  MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
-             backward.qkv.data(), grad.qkv_einsum_w.data(),
-             backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
-             num_tokens);
+  MatMulVJPT(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
+             backward.qkv.Packed(), grad.qkv_einsum_w.Packed(),
+             backward.pre_att_rms_out.Packed(), (kHeads + 2) * qkv_dim,
+             model_dim, num_tokens);
-  RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
-              backward.pre_att_rms_out.data(),
-              grad.pre_attention_norm_scale.data(), backward.input.data(),
+  RMSNormVJPT(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
+              backward.pre_att_rms_out.Packed(),
+              grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
              model_dim, num_tokens);
 
-  AddFromT(backward.attention_out.data(), backward.input.data(),
+  AddFromT(backward.attention_out.Packed(), backward.input.Packed(),
           num_tokens * model_dim);
 }
@@ -307,41 +310,42 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
   const std::vector<int> tokens = prompt.tokens;
   const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
 
-  CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
+  CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
                        vocab_size);
 
-  SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
+  SoftmaxVJPT(forward.probs.Packed(), backward.logits.Packed(), vocab_size,
               num_tokens);
 
   if (config.final_cap > 0.0f) {
     for (size_t i = 0; i < num_tokens; ++i) {
-      SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
-                  backward.logits.data() + i * vocab_size, vocab_size);
+      SoftcapVJPT(config.final_cap, forward.logits.Packed() + i * vocab_size,
+                  backward.logits.Packed() + i * vocab_size, vocab_size);
     }
   }
 
-  MatMulVJPT(
-      weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
-      backward.logits.data(), grad.embedder_input_embedding.data(),
-      backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
+  MatMulVJPT(weights.embedder_input_embedding.Packed(),
+             forward.final_norm_output.Packed(), backward.logits.Packed(),
+             grad.embedder_input_embedding.Packed(),
+             backward.final_norm_output.Packed(), vocab_size, model_dim,
+             num_tokens);
 
-  RMSNormVJPT(weights.final_norm_scale.data(),
-              forward.final_layer_output.data(),
-              backward.final_norm_output.data(), grad.final_norm_scale.data(),
-              backward.final_layer_output.data(), model_dim, num_tokens);
+  RMSNormVJPT(
+      weights.final_norm_scale.Packed(), forward.final_layer_output.Packed(),
+      backward.final_norm_output.Packed(), grad.final_norm_scale.Packed(),
+      backward.final_layer_output.Packed(), model_dim, num_tokens);
 
   for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
     T* next_layer_grad = layer + 1 < layers
-                             ? backward.layers[layer + 1].input.data()
+                             ? backward.layers[layer + 1].input.Packed()
-                             : backward.final_layer_output.data();
+                             : backward.final_layer_output.Packed();
     LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
              *grad.GetLayer(layer), backward.layers[layer], num_tokens);
   }
 
   const T kEmbScaling = EmbeddingScaling(model_dim);
-  InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
-                     kEmbScaling, backward.layers[0].input.data(),
-                     grad.embedder_input_embedding.data(), model_dim);
+  InputEmbeddingVJPT(weights.embedder_input_embedding.Packed(), tokens,
+                     kEmbScaling, backward.layers[0].input.Packed(),
+                     grad.embedder_input_embedding.Packed(), model_dim);
 }
 
 }  // namespace gcpp
@ -31,9 +31,9 @@
|
||||||
#include "backprop/prompt.h"
|
#include "backprop/prompt.h"
|
||||||
#include "backprop/sampler.h"
|
#include "backprop/sampler.h"
|
||||||
#include "backprop/test_util.h"
|
#include "backprop/test_util.h"
|
||||||
#include "compression/compress.h"
|
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
|
#include "util/mat.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -44,14 +44,14 @@ TEST(BackPropTest, MatMulVJP) {
|
||||||
std::mt19937 gen(42);
|
std::mt19937 gen(42);
|
||||||
using T = double;
|
using T = double;
|
||||||
using TC = std::complex<T>;
|
using TC = std::complex<T>;
|
||||||
MatStorageT<T> weights("weights", kRows, kCols);
|
auto weights = MakePacked<T>("weights", kRows, kCols);
|
||||||
MatStorageT<T> x("x", kTokens, kCols);
|
auto x = MakePacked<T>("x", kTokens, kCols);
|
||||||
MatStorageT<T> grad("grad", kRows, kCols);
|
auto grad = MakePacked<T>("grad", kRows, kCols);
|
||||||
MatStorageT<T> dx("dx", kTokens, kCols);
|
auto dx = MakePacked<T>("dx", kTokens, kCols);
|
||||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
|
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
|
||||||
MatStorageT<TC> c_x("c_x", kTokens, kCols);
|
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
|
||||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||||
MatStorageT<T> dy("dy", kTokens, kRows);
|
auto dy = MakePacked<T>("dy", kTokens, kRows);
|
||||||
|
|
||||||
for (int iter = 0; iter < 10; ++iter) {
|
for (int iter = 0; iter < 10; ++iter) {
|
||||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||||
|
|
@ -60,12 +60,13 @@ TEST(BackPropTest, MatMulVJP) {
|
||||||
Complexify(weights, c_weights);
|
Complexify(weights, c_weights);
|
||||||
Complexify(x, c_x);
|
Complexify(x, c_x);
|
||||||
auto func = [&]() {
|
auto func = [&]() {
|
||||||
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
|
MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
|
||||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
kTokens);
|
||||||
|
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||||
};
|
};
|
||||||
grad.ZeroInit();
|
ZeroInit(grad);
|
||||||
MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
|
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
||||||
kRows, kCols, kTokens);
|
dx.Packed(), kRows, kCols, kTokens);
|
||||||
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
|
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
|
||||||
TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
|
TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
|
||||||
}
|
}
|
||||||
|
|
@ -79,14 +80,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
|
||||||
std::mt19937 gen(42);
|
std::mt19937 gen(42);
|
||||||
using T = double;
|
using T = double;
|
||||||
using TC = std::complex<T>;
|
using TC = std::complex<T>;
|
||||||
MatStorageT<T> weights("weights", kRows, kCols * kHeads);
|
auto weights = MakePacked<T>("weights", kRows, kCols * kHeads);
|
||||||
MatStorageT<T> x("x", kTokens, kCols * kHeads);
|
auto x = MakePacked<T>("x", kTokens, kCols * kHeads);
|
||||||
MatStorageT<T> grad("grad", kRows, kCols * kHeads);
|
auto grad = MakePacked<T>("grad", kRows, kCols * kHeads);
|
||||||
MatStorageT<T> dx("dx", kTokens, kCols * kHeads);
|
auto dx = MakePacked<T>("dx", kTokens, kCols * kHeads);
|
||||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
|
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
|
||||||
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
|
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
|
||||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||||
MatStorageT<T> dy("dy", kTokens, kRows);
|
+  auto dy = MakePacked<T>("dy", kTokens, kRows);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(weights, 1.0 * (1 << iter), gen);

@@ -95,13 +96,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
     Complexify(weights, c_weights);
     Complexify(x, c_x);
     auto func = [&]() {
-      MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
-                      kCols, kTokens);
-      return DotT(dy.data(), c_y.data(), kTokens * kRows);
+      MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
+                      kRows, kCols, kTokens);
+      return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
     };
-    grad.ZeroInit();
-    MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
-                        dx.data(), kHeads, kRows, kCols, kTokens);
+    ZeroInit(grad);
+    MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
+                        grad.Packed(), dx.Packed(), kHeads, kRows, kCols,
+                        kTokens);
     TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
     TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
   }

@@ -113,14 +115,14 @@ TEST(BackPropTest, RMSNormVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> weights("weights", N, 1);
-  MatStorageT<T> grad("grad", N, 1);
-  MatStorageT<T> x("x", K, N);
-  MatStorageT<T> dx("dx", K, N);
-  MatStorageT<T> dy("dy", K, N);
-  MatStorageT<TC> c_weights("c_weights", N, 1);
-  MatStorageT<TC> c_x("c_x", K, N);
-  MatStorageT<TC> c_y("c_y", K, N);
+  auto weights = MakePacked<T>("weights", N, 1);
+  auto grad = MakePacked<T>("grad", N, 1);
+  auto x = MakePacked<T>("x", K, N);
+  auto dx = MakePacked<T>("dx", K, N);
+  auto dy = MakePacked<T>("dy", K, N);
+  auto c_weights = MakePacked<TC>("c_weights", N, 1);
+  auto c_x = MakePacked<TC>("c_x", K, N);
+  auto c_y = MakePacked<TC>("c_y", K, N);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(weights, 1.0 * (1 << iter), gen);

@@ -129,12 +131,12 @@ TEST(BackPropTest, RMSNormVJP) {
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
-      RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
-      return DotT(dy.data(), c_y.data(), K * N);
+      RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
+      return DotT(dy.Packed(), c_y.Packed(), K * N);
     };
-    grad.ZeroInit();
-    RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
-                N, K);
+    ZeroInit(grad);
+    RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
+                dx.Packed(), N, K);
     TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
     TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
   }

@@ -145,24 +147,24 @@ TEST(BackPropTest, SoftmaxVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> x("x", N, 1);
-  MatStorageT<T> dx("dx", N, 1);
-  MatStorageT<T> dy("dy", N, 1);
-  MatStorageT<TC> c_x("c_x", N, 1);
-  MatStorageT<TC> c_y("c_y", N, 1);
+  auto x = MakePacked<T>("x", N, 1);
+  auto dx = MakePacked<T>("dx", N, 1);
+  auto dy = MakePacked<T>("dy", N, 1);
+  auto c_x = MakePacked<TC>("c_x", N, 1);
+  auto c_y = MakePacked<TC>("c_y", N, 1);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(x, 1.0 * (1 << iter), gen);
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
-      memcpy(c_y.data(), c_x.data(), c_x.SizeBytes());
-      Softmax(c_y.data(), N);
-      return DotT(dy.data(), c_y.data(), N);
+      CopyMat(c_x, c_y);
+      Softmax(c_y.Packed(), N);
+      return DotT(dy.Packed(), c_y.Packed(), N);
     };
-    Softmax(x.data(), N);
-    memcpy(dx.data(), dy.data(), dx.SizeBytes());
-    SoftmaxVJPT(x.data(), dx.data(), N);
+    Softmax(x.Packed(), N);
+    CopyMat(dy, dx);
+    SoftmaxVJPT(x.Packed(), dx.Packed(), N);
     TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
   }
 }

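Throughout these tests, hand-sized memcpy calls (SizeBytes(), N * sizeof(...)) and the member ZeroInit() give way to the CopyMat and ZeroInit free functions, so the element count comes from the matrix metadata instead of a caller-computed byte size. Below is a minimal standalone sketch of that idea; the Mat struct is illustrative only, not the repo's MatPtrT/MatStorageT.

#include <algorithm>
#include <cstddef>
#include <vector>

// Illustrative packed row-major matrix; stands in for the real MatPtrT.
template <typename T>
struct Mat {
  size_t rows = 0, cols = 0;
  std::vector<T> packed;  // rows * cols elements, no row padding
};

template <typename T>
void ZeroInit(Mat<T>& m) {
  std::fill(m.packed.begin(), m.packed.end(), T{});
}

template <typename T>
void CopyMat(const Mat<T>& from, Mat<T>& to) {
  // The size is taken from `from` itself, so a stale or mistyped byte count
  // (the failure mode of the old memcpy calls) cannot occur here.
  to.rows = from.rows;
  to.cols = from.cols;
  to.packed = from.packed;
}

Usage then mirrors the rewritten test code above, e.g. CopyMat(c_x, c_y); and ZeroInit(grad);.
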
@@ -175,26 +177,25 @@ TEST(BackPropTest, MaskedSoftmaxVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> x("x", N, 1);
-  MatStorageT<T> dy("dy", N, 1);
-  MatStorageT<T> dx("dx", N, 1);
-  MatStorageT<TC> c_x("c_x", N, 1);
-  MatStorageT<TC> c_y("c_y", N, 1);
-  dx.ZeroInit();
+  auto x = MakePacked<T>("x", N, 1);
+  auto dy = MakePacked<T>("dy", N, 1);
+  auto dx = MakePacked<T>("dx", N, 1);
+  auto c_x = MakePacked<TC>("c_x", N, 1);
+  auto c_y = MakePacked<TC>("c_y", N, 1);
+  ZeroInit(dx);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(x, 1.0 * (1 << iter), gen);
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
-      memcpy(c_y.data(), c_x.data(),
-             kTokens * kHeads * kSeqLen * sizeof(c_x.At(0)));
-      MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
-      return DotT(dy.data(), c_y.data(), N);
+      CopyMat(c_x, c_y);
+      MaskedSoftmax(c_y.Packed(), kTokens, kHeads, kSeqLen);
+      return DotT(dy.Packed(), c_y.Packed(), N);
     };
-    MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
-    memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0)));
-    MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
+    MaskedSoftmax(x.Packed(), kTokens, kHeads, kSeqLen);
+    CopyMat(dy, dx);
+    MaskedSoftmaxVJPT(x.Packed(), dx.Packed(), kTokens, kHeads, kSeqLen);
     TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
   }
 }

@@ -204,11 +205,11 @@ TEST(BackPropTest, SoftcapVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> x("x", N, 1);
-  MatStorageT<T> dx("dx", N, 1);
-  MatStorageT<T> dy("dy", N, 1);
-  MatStorageT<TC> c_x("c_x", N, 1);
-  MatStorageT<TC> c_y("c_y", N, 1);
+  auto x = MakePacked<T>("x", N, 1);
+  auto dx = MakePacked<T>("dx", N, 1);
+  auto dy = MakePacked<T>("dy", N, 1);
+  auto c_x = MakePacked<TC>("c_x", N, 1);
+  auto c_y = MakePacked<TC>("c_y", N, 1);

   constexpr float kCap = 30.0f;
   for (int iter = 0; iter < 10; ++iter) {

@@ -216,13 +217,13 @@ TEST(BackPropTest, SoftcapVJP) {
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
-      memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0)));
-      Softcap(kCap, c_y.data(), N);
-      return DotT(dy.data(), c_y.data(), N);
+      CopyMat(c_x, c_y);
+      Softcap(kCap, c_y.Packed(), N);
+      return DotT(dy.Packed(), c_y.Packed(), N);
     };
-    Softcap(kCap, x.data(), N);
-    memcpy(dx.data(), dy.data(), dx.SizeBytes());
-    SoftcapVJPT(kCap, x.data(), dx.data(), N);
+    Softcap(kCap, x.Packed(), N);
+    CopyMat(dy, dx);
+    SoftcapVJPT(kCap, x.Packed(), dx.Packed(), N);
     TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
   }
 }

@@ -233,9 +234,9 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> x("x", K, V);
-  MatStorageT<T> dx("dx", K, V);
-  MatStorageT<TC> c_x("c_x", K, V);
+  auto x = MakePacked<T>("x", K, V);
+  auto dx = MakePacked<T>("dx", K, V);
+  auto c_x = MakePacked<TC>("c_x", K, V);
   Prompt prompt;
   prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };

@@ -243,13 +244,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
   for (int iter = 0; iter < 10; ++iter) {
     prompt.context_size = 1 + (iter % 6);
     RandInit(x, 1.0 * (1 << iter), gen);
-    Softcap(kCap, x.data(), V * K);
-    Softmax(x.data(), V, K);
-    CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
+    Softcap(kCap, x.Packed(), V * K);
+    Softmax(x.Packed(), V, K);
+    CrossEntropyLossGrad(x.Packed(), dx.Packed(), prompt, V);
     Complexify(x, c_x);
-    auto func = [&]() {
-      return CrossEntropyLoss(c_x.data(), prompt, V);
-    };
+    auto func = [&]() { return CrossEntropyLoss(c_x.Packed(), prompt, V); };
     TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
   }
 }

@@ -260,21 +259,21 @@ TEST(BackPropTest, GatedGeluVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> x("x", K, 2 * N);
-  MatStorageT<T> dx("dx", K, 2 * N);
-  MatStorageT<T> dy("dy", K, N);
-  MatStorageT<TC> c_x("c_x", K, 2 * N);
-  MatStorageT<TC> c_y("c_y", K, N);
+  auto x = MakePacked<T>("x", K, 2 * N);
+  auto dx = MakePacked<T>("dx", K, 2 * N);
+  auto dy = MakePacked<T>("dy", K, N);
+  auto c_x = MakePacked<TC>("c_x", K, 2 * N);
+  auto c_y = MakePacked<TC>("c_y", K, N);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(x, 1.0, gen);
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
-      GatedGelu(c_x.data(), c_y.data(), N, K);
-      return DotT(dy.data(), c_y.data(), N * K);
+      GatedGelu(c_x.Packed(), c_y.Packed(), N, K);
+      return DotT(dy.Packed(), c_y.Packed(), N * K);
     };
-    GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
+    GatedGeluVJP(x.Packed(), dy.Packed(), dx.Packed(), N, K);
     TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
   }
 }

@@ -289,25 +288,25 @@ TEST(BackPropTest, MaskedAttentionVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> x("x", kQKVSize, 1);
-  MatStorageT<T> dx("dx", kQKVSize, 1);
-  MatStorageT<T> dy("dy", kOutSize, 1);
-  MatStorageT<TC> c_x("c_x", kQKVSize, 1);
-  MatStorageT<TC> c_y("c_y", kOutSize, 1);
-  dx.ZeroInit();
-  c_y.ZeroInit();
+  auto x = MakePacked<T>("x", kQKVSize, 1);
+  auto dx = MakePacked<T>("dx", kQKVSize, 1);
+  auto dy = MakePacked<T>("dy", kOutSize, 1);
+  auto c_x = MakePacked<TC>("c_x", kQKVSize, 1);
+  auto c_y = MakePacked<TC>("c_y", kOutSize, 1);
+  ZeroInit(dx);
+  ZeroInit(c_y);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(x, 1.0, gen);
     Complexify(x, c_x);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
-      MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
+      MaskedAttention(c_x.Packed(), c_y.Packed(), kTokens, kHeads, kQKVDim,
                       kSeqLen);
-      return DotT(dy.data(), c_y.data(), kOutSize);
+      return DotT(dy.Packed(), c_y.Packed(), kOutSize);
     };
-    MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
-                       kTokens, kHeads, kQKVDim, kSeqLen);
+    MaskedAttentionVJP(x.Packed(), dy.Packed(), dx.Packed(), kTokens, kHeads,
+                       kQKVDim, kSeqLen);
     TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
   }
 }

@@ -323,17 +322,17 @@ TEST(BackPropTest, MixByAttentionVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> qkv("qkv", kQKVSize, 1);
-  MatStorageT<T> dqkv("dqkv", kQKVSize, 1);
-  MatStorageT<T> attn("attn", kAttnSize, 1);
-  MatStorageT<T> dattn("dattn", kAttnSize, 1);
-  MatStorageT<T> dy("dy", kOutSize, 1);
-  MatStorageT<TC> c_qkv("c_qkv", kQKVSize, 1);
-  MatStorageT<TC> c_attn("c_attn", kAttnSize, 1);
-  MatStorageT<TC> c_y("c_y", kOutSize, 1);
-  dqkv.ZeroInit();
-  dattn.ZeroInit();
-  c_y.ZeroInit();
+  auto qkv = MakePacked<T>("qkv", kQKVSize, 1);
+  auto dqkv = MakePacked<T>("dqkv", kQKVSize, 1);
+  auto attn = MakePacked<T>("attn", kAttnSize, 1);
+  auto dattn = MakePacked<T>("dattn", kAttnSize, 1);
+  auto dy = MakePacked<T>("dy", kOutSize, 1);
+  auto c_qkv = MakePacked<TC>("c_qkv", kQKVSize, 1);
+  auto c_attn = MakePacked<TC>("c_attn", kAttnSize, 1);
+  auto c_y = MakePacked<TC>("c_y", kOutSize, 1);
+  ZeroInit(dqkv);
+  ZeroInit(dattn);
+  ZeroInit(c_y);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(qkv, 1.0, gen);

@@ -342,12 +341,12 @@ TEST(BackPropTest, MixByAttentionVJP) {
     Complexify(attn, c_attn);
     RandInit(dy, 1.0, gen);
     auto func = [&]() {
-      MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
-                     kTokens, kHeads, kQKVDim, kSeqLen);
-      return DotT(dy.data(), c_y.data(), kOutSize);
+      MixByAttention(c_qkv.Packed(), c_attn.Packed(), c_y.Packed(), kTokens,
+                     kHeads, kQKVDim, kSeqLen);
+      return DotT(dy.Packed(), c_y.Packed(), kOutSize);
     };
-    MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
-                      dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
+    MixByAttentionVJP(qkv.Packed(), attn.Packed(), dy.Packed(), dqkv.Packed(),
+                      dattn.Packed(), kTokens, kHeads, kQKVDim, kSeqLen);
     TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
     TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
   }

@@ -360,11 +359,11 @@ TEST(BackPropTest, InputEmbeddingVJP) {
   std::mt19937 gen(42);
   using T = double;
   using TC = std::complex<T>;
-  MatStorageT<T> weights("weights", kVocabSize, kModelDim);
-  MatStorageT<T> grad("grad", kVocabSize, kModelDim);
-  MatStorageT<T> dy("dy", kSeqLen, kModelDim);
-  MatStorageT<TC> c_weights("c_weights", kVocabSize, kModelDim);
-  MatStorageT<TC> c_y("c_y", kSeqLen, kModelDim);
+  auto weights = MakePacked<T>("weights", kVocabSize, kModelDim);
+  auto grad = MakePacked<T>("grad", kVocabSize, kModelDim);
+  auto dy = MakePacked<T>("dy", kSeqLen, kModelDim);
+  auto c_weights = MakePacked<TC>("c_weights", kVocabSize, kModelDim);
+  auto c_y = MakePacked<TC>("c_y", kSeqLen, kModelDim);
   std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
   size_t num_tokens = tokens.size() - 1;

@@ -373,12 +372,13 @@ TEST(BackPropTest, InputEmbeddingVJP) {
     RandInit(dy, 1.0, gen);
     Complexify(weights, c_weights);
     auto func = [&]() {
-      InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
-      return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
-    };
-    grad.ZeroInit();
-    InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
-                       kModelDim);
+      InputEmbedding(c_weights.Packed(), tokens, TC(3.0), c_y.Packed(),
+                     kModelDim);
+      return DotT(dy.Packed(), c_y.Packed(), num_tokens * kModelDim);
+    };
+    ZeroInit(grad);
+    InputEmbeddingVJPT(weights.Packed(), tokens, 3.0, dy.Packed(),
+                       grad.Packed(), kModelDim);
     TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
   }
 }

@@ -410,8 +410,7 @@ TEST(BackPropTest, LayerVJP) {
   using T = double;
   using TC = std::complex<T>;
   ModelConfig config = TestConfig();
-  TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
-                           /*reshape_att=*/false);
+  const TensorIndex tensor_index = TensorIndexLLM(config, size_t{0});
   const size_t kOutputSize = config.seq_len * config.model_dim;
   LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
   LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);

@@ -419,15 +418,15 @@ TEST(BackPropTest, LayerVJP) {
   ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
   LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
   ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
-  MatStorageT<T> y("y", kOutputSize, 1);
-  MatStorageT<T> dy("dy", kOutputSize, 1);
-  MatStorageT<TC> c_y("c_y", kOutputSize, 1);
+  auto y = MakePacked<T>("y", kOutputSize, 1);
+  auto dy = MakePacked<T>("dy", kOutputSize, 1);
+  auto c_y = MakePacked<TC>("c_y", kOutputSize, 1);
   const size_t num_tokens = 3;
-  std::vector<MatStorage> layer_storage;
+  std::vector<MatOwner> layer_storage;
   weights.Allocate(layer_storage);
   grad.Allocate(layer_storage);
   c_weights.Allocate(layer_storage);
-  backward.input.ZeroInit();
+  ZeroInit(backward.input);

   for (size_t iter = 0; iter < 10; ++iter) {
     RandInit(weights, 1.0, gen);

@@ -436,12 +435,12 @@ TEST(BackPropTest, LayerVJP) {
     Complexify(weights, c_weights);
     Complexify(forward.input, c_forward.input);
     auto func = [&]() {
-      ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
-      return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
+      ApplyLayer(c_weights, c_forward, num_tokens, c_y.Packed());
+      return DotT(dy.Packed(), c_y.Packed(), num_tokens * config.model_dim);
     };
     grad.ZeroInit(/*layer_idx=*/0);
-    ApplyLayer(weights, forward, num_tokens, y.data());
-    LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
+    ApplyLayer(weights, forward, num_tokens, y.Packed());
+    LayerVJP(weights, forward, dy.Packed(), grad, backward, num_tokens);
     TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11,
                  __LINE__);
     TestGradient(grad, c_weights, func, 1e-11);

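The declarations in the tests above all move from MatStorageT<T> m("name", rows, cols) to auto m = MakePacked<T>("name", rows, cols), a factory that returns an owning matrix whose storage is packed (stride equal to cols). The following is a rough standalone sketch of such a factory, under the assumption that the real MakePacked/MatOwner in util/mat.h differ in detail; it only illustrates the shape of the API.

#include <cstddef>
#include <string>
#include <vector>

// Illustrative owning matrix; the real types are MatPtrT/MatOwner in util/mat.h.
template <typename T>
struct PackedMat {
  std::string name;
  size_t rows = 0, cols = 0;
  std::vector<T> storage;  // rows * cols elements, contiguous
  T* Packed() { return storage.data(); }
  const T* Packed() const { return storage.data(); }
};

template <typename T>
PackedMat<T> MakePacked(const char* name, size_t rows, size_t cols) {
  PackedMat<T> m;
  m.name = name;
  m.rows = rows;
  m.cols = cols;
  m.storage.resize(rows * cols);  // packed: stride equals cols
  return m;
}

The Packed() accessor is what the rewritten calls such as dy.Packed() refer to throughout the diff.
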
@@ -33,8 +33,10 @@
 #include "backprop/test_util.h"
 #include "gemma/configs.h"
 #include "ops/ops.h"
-#include "util/threading.h"
+#include "util/mat.h"
+#include "util/threading_context.h"
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"

 // clang-format off
 #undef HWY_TARGET_INCLUDE

@@ -46,33 +48,45 @@
 // After highway.h
 #include "backprop/backward-inl.h"
 #include "backprop/forward-inl.h"
-#include "compression/compress.h"
 #include "ops/ops-inl.h"
-#include "util/allocator.h"

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {

+hwy::ThreadPool& ThreadHostileGetPool() {
+  // Assume this is only called at the top level, i.e. not in a thread. Then we
+  // can safely call `SetArgs` only once, because it would assert otherwise.
+  // This is preferable to calling `ThreadHostileInvalidate`, because we would
+  // repeat the topology initialization for every test.
+  if (!ThreadingContext2::IsInitialized()) {
+    gcpp::ThreadingArgs threading_args;
+    threading_args.max_packages = 1;
+    threading_args.max_clusters = 8;
+    threading_args.pin = Tristate::kFalse;
+    ThreadingContext2::SetArgs(threading_args);
+  }
+  return ThreadingContext2::Get().pools.Pool();
+}
+
 void TestMatMulVJP() {
   static const size_t kRows = 8;
   static const size_t kCols = 64;
   static const size_t kTokens = 5;
-  const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
-  Allocator::Init(topology);
-  gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
+  hwy::ThreadPool& pool = ThreadHostileGetPool();
   std::mt19937 gen(42);
-  MatStorageT<float> weights("weights", kRows, kCols);
-  MatStorageT<float> x("x", kTokens, kCols);
-  MatStorageT<float> dy("dy", kTokens, kRows);
-  MatStorageT<float> grad("grad", kRows, kCols);
-  MatStorageT<float> dx("dx", kTokens, kCols);
-  MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols);
-  MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols);
+  auto weights = MakePacked<float>("weights", kRows, kCols);
+  auto x = MakePacked<float>("x", kTokens, kCols);
+  auto dy = MakePacked<float>("dy", kTokens, kRows);
+  auto grad = MakePacked<float>("grad", kRows, kCols);
+  auto dx = MakePacked<float>("dx", kTokens, kCols);
+  auto grad_scalar = MakePacked<float>("grad_scalar", kRows, kCols);
+  auto dx_scalar = MakePacked<float>("dx_scalar", kTokens, kCols);
   using TC = std::complex<double>;
-  MatStorageT<TC> c_weights("c_weights", kRows, kCols);
-  MatStorageT<TC> c_x("c_x", kTokens, kCols);
-  MatStorageT<TC> c_y("c_y", kTokens, kRows);
+  auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
+  auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
+  auto c_y = MakePacked<TC>("c_y", kTokens, kRows);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(weights, 1.0f * (1 << iter), gen);

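The ThreadHostileGetPool helper added above funnels every test in this translation unit through a single process-wide ThreadingContext2: the first caller initializes it via SetArgs, later callers simply reuse the shared pools. A hedged usage sketch follows; the test body is hypothetical, and only the helper itself and the pool.Run signature are taken from the hunks above.

#include <cstddef>
#include <cstdint>

#include "hwy/contrib/thread_pool/thread_pool.h"

// Hypothetical caller; any test in this file could look like this.
void SomeVJPTestBody() {
  // First call builds the ThreadingContext2 (1 package, up to 8 clusters, no
  // pinning); subsequent calls skip initialization and return the same pool.
  hwy::ThreadPool& pool = ThreadHostileGetPool();
  pool.Run(0, 16, [](uint64_t task, size_t thread) {
    // Per-task work, e.g. one row of a finite-difference gradient check.
    (void)task;
    (void)thread;
  });
}
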
@@ -81,19 +95,20 @@ void TestMatMulVJP() {
     Complexify(weights, c_weights);
     Complexify(x, c_x);
     auto func = [&]() {
-      MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
-      return DotT(dy.data(), c_y.data(), kTokens * kRows);
+      MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
+              kTokens);
+      return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
     };

-    grad.ZeroInit();
-    MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
-              grad.data(), dx.data(), pools.Pool());
+    ZeroInit(grad);
+    MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens,
+              grad.Packed(), dx.Packed(), pool);
     TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
     TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);

-    grad_scalar.ZeroInit();
-    MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
-               dx_scalar.data(), kRows, kCols, kTokens);
+    ZeroInit(grad_scalar);
+    MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
+               dx_scalar.Packed(), kRows, kCols, kTokens);
     TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__);
     TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
   }

@@ -104,21 +119,19 @@ void TestMultiHeadMatMulVJP() {
   static const size_t kCols = 16;
   static const size_t kHeads = 4;
   static const size_t kTokens = 3;
-  const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
-  Allocator::Init(topology);
-  gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
+  hwy::ThreadPool& pool = ThreadHostileGetPool();
   std::mt19937 gen(42);
-  MatStorageT<float> weights("weights", kRows, kCols * kHeads);
-  MatStorageT<float> x("x", kTokens, kCols * kHeads);
-  MatStorageT<float> grad("grad", kRows, kCols * kHeads);
-  MatStorageT<float> dx("dx", kTokens, kCols * kHeads);
-  MatStorageT<float> dy("dy", kTokens, kRows);
-  MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols * kHeads);
-  MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols * kHeads);
+  auto weights = MakePacked<float>("weights", kRows, kCols * kHeads);
+  auto x = MakePacked<float>("x", kTokens, kCols * kHeads);
+  auto grad = MakePacked<float>("grad", kRows, kCols * kHeads);
+  auto dx = MakePacked<float>("dx", kTokens, kCols * kHeads);
+  auto dy = MakePacked<float>("dy", kTokens, kRows);
+  auto grad_scalar = MakePacked<float>("grad_scalar", kRows, kCols * kHeads);
+  auto dx_scalar = MakePacked<float>("dx_scalar", kTokens, kCols * kHeads);
   using TC = std::complex<double>;
-  MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
-  MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
-  MatStorageT<TC> c_y("c_y", kTokens, kRows);
+  auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
+  auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
+  auto c_y = MakePacked<TC>("c_y", kTokens, kRows);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(weights, 1.0f * (1 << iter), gen);

@@ -127,20 +140,21 @@ void TestMultiHeadMatMulVJP() {
     Complexify(weights, c_weights);
     Complexify(x, c_x);
     auto func = [&]() {
-      MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
-                      kCols, kTokens);
-      return DotT(dy.data(), c_y.data(), kTokens * kRows);
+      MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
+                      kRows, kCols, kTokens);
+      return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
     };

-    grad.ZeroInit();
-    MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
-                       kRows, kTokens, grad.data(), dx.data(), pools.Pool());
+    ZeroInit(grad);
+    MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols,
+                       kRows, kTokens, grad.Packed(), dx.Packed(), pool);
     TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
     TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);

-    grad_scalar.ZeroInit();
-    MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
-                        dx_scalar.data(), kHeads, kRows, kCols, kTokens);
+    ZeroInit(grad_scalar);
+    MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
+                        grad_scalar.Packed(), dx_scalar.Packed(), kHeads, kRows,
+                        kCols, kTokens);
     TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
     TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
   }

@@ -149,21 +163,19 @@ void TestMultiHeadMatMulVJP() {
 void TestRMSNormVJP() {
   static const size_t K = 2;
   static const size_t N = 64;
-  const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
-  Allocator::Init(topology);
-  gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
+  hwy::ThreadPool& pool = ThreadHostileGetPool();
   std::mt19937 gen(42);
-  MatStorageT<float> weights("weights", N, 1);
-  MatStorageT<float> x("x", K, N);
-  MatStorageT<float> grad("grad", N, 1);
-  MatStorageT<float> dx("dx", K, N);
-  MatStorageT<float> dy("dy", K, N);
-  MatStorageT<float> grad_scalar("grad_scalar", N, 1);
-  MatStorageT<float> dx_scalar("dx_scalar", K, N);
+  auto weights = MakePacked<float>("weights", N, 1);
+  auto x = MakePacked<float>("x", K, N);
+  auto grad = MakePacked<float>("grad", N, 1);
+  auto dx = MakePacked<float>("dx", K, N);
+  auto dy = MakePacked<float>("dy", K, N);
+  auto grad_scalar = MakePacked<float>("grad_scalar", N, 1);
+  auto dx_scalar = MakePacked<float>("dx_scalar", K, N);
   using TC = std::complex<double>;
-  MatStorageT<TC> c_weights("c_weights", N, 1);
-  MatStorageT<TC> c_x("c_x", K, N);
-  MatStorageT<TC> c_y("c_y", K, N);
+  auto c_weights = MakePacked<TC>("c_weights", N, 1);
+  auto c_x = MakePacked<TC>("c_x", K, N);
+  auto c_y = MakePacked<TC>("c_y", K, N);

   for (int iter = 0; iter < 10; ++iter) {
     RandInit(weights, 1.0f * (1 << iter), gen);

@@ -172,19 +184,19 @@ void TestRMSNormVJP() {
     Complexify(weights, c_weights);
     Complexify(x, c_x);
     auto func = [&]() {
-      RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
-      return DotT(dy.data(), c_y.data(), K * N);
+      RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
+      return DotT(dy.Packed(), c_y.Packed(), K * N);
     };

-    grad.ZeroInit();
-    RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
-               dx.data(), pools.Pool());
+    ZeroInit(grad);
+    RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(),
+               dx.Packed(), pool);
     TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
     TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);

-    grad_scalar.ZeroInit();
-    RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
-                dx_scalar.data(), N, K);
+    ZeroInit(grad_scalar);
+    RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
+                dx_scalar.Packed(), N, K);
     TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
     TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
   }

@@ -215,9 +227,7 @@ static ModelConfig TestConfig() {

 void TestEndToEnd() {
   std::mt19937 gen(42);
-  const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
-  Allocator::Init(topology);
-  gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
+  hwy::ThreadPool& pool = ThreadHostileGetPool();
   ModelConfig config = TestConfig();
   WeightsWrapper<float> weights(config);
   WeightsWrapper<float> grad(config);

@@ -232,7 +242,7 @@ void TestEndToEnd() {
   std::vector<Prompt> batch = training_task.SampleBatch(3, gen);

   RowVectorBatch<float> inv_timescale = CreateInvTimescale(
-      config.layer_configs[0].qkv_dim,
+      ThreadingContext2::Get().allocator, config.layer_configs[0].qkv_dim,
       config.layer_configs[0].post_qk == PostQKType::HalfRope);
   for (const Prompt& prompt : batch) {
     ReverseSequenceSampler::LogPrompt(prompt);

@@ -242,13 +252,13 @@ void TestEndToEnd() {

     float loss1 = CrossEntropyLossForwardPass(
         prompt.tokens, prompt.context_size, weights.get(), forward1,
-        inv_timescale, pools.Pool());
+        inv_timescale, pool);

     EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);

     grad.ZeroInit();
     CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
-                                    backward, inv_timescale, pools.Pool());
+                                    backward, inv_timescale, pool);

     Complexify(weights.get(), c_weights.get());
     auto func = [&]() {

@@ -20,7 +20,7 @@

 #include <complex>

-#include "compression/compress.h"  // MatStorageT
+#include "util/mat.h"

 namespace gcpp {

@@ -60,7 +60,9 @@ void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {

 template <typename T>
 void MulByConstAndAddT(T c, const MatPtrT<T>& x, MatPtrT<T>& out) {
-  MulByConstAndAddT(c, x.data(), out.data(), x.NumElements());
+  for (size_t r = 0; r < x.Rows(); ++r) {
+    MulByConstAndAddT(c, x.Row(r), out.Row(r), x.Cols());
+  }
 }

 template<typename T>

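The MulByConstAndAddT change just above switches from one flat pass over NumElements() to a per-row loop. Row-wise iteration stays correct once a matrix row may be padded, i.e. its stride can exceed its logical width, which is exactly the stride/row access this refactor introduces. A standalone sketch of the idea, using an illustrative strided view rather than the real MatPtrT:

#include <cstddef>

// Illustrative strided view: row r starts at data[r * stride]; only the first
// `cols` entries of each row are meaningful, the remainder is padding.
struct StridedView {
  float* data;
  size_t rows, cols, stride;  // invariant: stride >= cols
  float* Row(size_t r) const { return data + r * stride; }
};

// out[r][i] += c0 * x[r][i]. Walking rows and then cols never reads or writes
// the padding that a single flat loop over the whole allocation would touch.
void MulByConstAndAdd(float c0, const StridedView& x, const StridedView& out) {
  for (size_t r = 0; r < x.rows; ++r) {
    const float* x_row = x.Row(r);
    float* out_row = out.Row(r);
    for (size_t i = 0; i < x.cols; ++i) out_row[i] += c0 * x_row[i];
  }
}
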
@@ -28,6 +28,7 @@
 #include "gemma/configs.h"
 #include "gemma/weights.h"
 #include "util/allocator.h"
+#include "util/mat.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"

@@ -50,16 +51,17 @@ HWY_BEFORE_NAMESPACE();
 namespace gcpp {
 namespace HWY_NAMESPACE {

-template <typename ArrayT>
-void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
+template <typename T>
+void InputEmbedding(const MatPtrT<T>& weights, const std::vector<int>& prompt,
                     const float scaling, float* HWY_RESTRICT output,
                     size_t model_dim, size_t vocab_size) {
   const hn::ScalableTag<float> df;
   HWY_ASSERT(!prompt.empty());
   for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
     int token = prompt[pos];
-    DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size),
-                         token * model_dim, output + pos * model_dim,
+    const auto span = weights.Span();
+    HWY_ASSERT(span.num == model_dim * vocab_size);
+    DecompressAndZeroPad(df, span, token * model_dim, output + pos * model_dim,
                          model_dim);
     MulByConst(scaling, output + pos * model_dim, model_dim);
   }

@@ -109,27 +111,27 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
       static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
   HWY_ASSERT(num_tokens <= kSeqLen);

-  ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
-               activations.input.data(), model_dim, num_tokens,
-               activations.pre_att_rms_out.data(), pool);
+  ApplyRMSNorm(weights.pre_attention_norm_scale.Packed(),
+               activations.input.Packed(), model_dim, num_tokens,
+               activations.pre_att_rms_out.Packed(), pool);

   for (size_t pos = 0; pos < num_tokens; ++pos) {
     MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
-           activations.pre_att_rms_out.data() + pos * model_dim,
-           activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
+           activations.pre_att_rms_out.Packed() + pos * model_dim,
+           activations.qkv.Packed() + pos * (kHeads + 2) * kQKVDim, pool);
   }
   const size_t num_tasks = kHeads * num_tokens;

   for (size_t pos = 0; pos < num_tokens; ++pos) {
     float* HWY_RESTRICT k =
-        activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
+        activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
     Rope(k, kQKVDim, inv_timescale.Const(), pos);
   }
   pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
     const size_t head = task % kHeads;
     const size_t pos = task / kHeads;
     float* HWY_RESTRICT q =
-        activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
+        activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
     Rope(q, kQKVDim, inv_timescale.Const(), pos);
     MulByConst(query_scale, q, kQKVDim);
   });

@@ -138,12 +140,12 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
     const size_t head = task % kHeads;
     const size_t pos = task / kHeads;
     const float* HWY_RESTRICT q =
-        activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
+        activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
     float* HWY_RESTRICT head_att =
-        activations.att.data() + (pos * kHeads + head) * kSeqLen;
+        activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
       const float* HWY_RESTRICT k2 =
-          activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
+          activations.qkv.Packed() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
       const float score = Dot(q, k2, kQKVDim);
       head_att[pos2] = score;
     }

@@ -153,7 +155,7 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
     const size_t head = task % kHeads;
     const size_t pos = task / kHeads;
     float* HWY_RESTRICT head_att =
-        activations.att.data() + (pos * kHeads + head) * kSeqLen;
+        activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
     Softmax(head_att, pos + 1);
   });

@@ -161,51 +163,51 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
     const size_t head = task % kHeads;
     const size_t pos = task / kHeads;
     const float* HWY_RESTRICT head_att =
-        activations.att.data() + (pos * kHeads + head) * kSeqLen;
+        activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
     float* HWY_RESTRICT att_out =
-        activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
+        activations.att_out.Packed() + (pos * kHeads + head) * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-      float* HWY_RESTRICT v2 =
-          activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
+      float* HWY_RESTRICT v2 = activations.qkv.Packed() +
+                               (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
       MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
   });

-  activations.attention_out.ZeroInit();
+  ZeroInit(activations.attention_out);
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     for (size_t head = 0; head < kHeads; ++head) {
-      MatVec(
-          weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
-          kQKVDim,
-          activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
-          activations.att_post1.data() + pos * model_dim, pool);
-      AddFrom(activations.att_post1.data() + pos * model_dim,
-              activations.attention_out.data() + pos * model_dim, model_dim);
+      MatVec(weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
+             kQKVDim,
+             activations.att_out.Packed() + pos * kHeads * kQKVDim +
+                 head * kQKVDim,
+             activations.att_post1.Packed() + pos * model_dim, pool);
+      AddFrom(activations.att_post1.Packed() + pos * model_dim,
+              activations.attention_out.Packed() + pos * model_dim, model_dim);
     }
   }

   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(activations.input.data() + pos * model_dim,
-            activations.attention_out.data() + pos * model_dim, model_dim);
+    AddFrom(activations.input.Packed() + pos * model_dim,
+            activations.attention_out.Packed() + pos * model_dim, model_dim);
   }

-  ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
-               activations.attention_out.data(), model_dim, num_tokens,
-               activations.bf_pre_ffw_rms_out.data(), pool);
+  ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(),
+               activations.attention_out.Packed(), model_dim, num_tokens,
+               activations.bf_pre_ffw_rms_out.Packed(), pool);
   const size_t kFFHiddenDim = config.ff_hidden_dim;
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
-           activations.bf_pre_ffw_rms_out.data() + pos * model_dim,
-           activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
+           activations.bf_pre_ffw_rms_out.Packed() + pos * model_dim,
+           activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool);
   }
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     const size_t hidden_offset = pos * kFFHiddenDim * 2;
     const float* HWY_RESTRICT out =
-        activations.ffw_hidden.data() + hidden_offset;
+        activations.ffw_hidden.Packed() + hidden_offset;
     const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
     float* HWY_RESTRICT out_gated =
-        activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
+        activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim;
     namespace hn = hwy::HWY_NAMESPACE;
     using DF = hn::ScalableTag<float>;
     DF df;

@@ -217,11 +219,11 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
   }
   for (size_t pos = 0; pos < num_tokens; ++pos) {
     MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
-           activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
+           activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim,
            output + pos * model_dim, pool);
   }
   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    AddFrom(activations.attention_out.data() + pos * model_dim,
+    AddFrom(activations.attention_out.Packed() + pos * model_dim,
             output + pos * model_dim, model_dim);
   }
 }

@@ -247,44 +249,43 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
   const size_t num_tokens = prompt.size() - 1;

   InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
-                 forward.layers[0].input.data(), model_dim, vocab_size);
+                 forward.layers[0].input.Packed(), model_dim, vocab_size);

   for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
     auto type = config.layer_configs[layer].type;
     // TODO(szabadka) Implement Griffin layer.
     HWY_ASSERT(type == LayerAttentionType::kGemma);
     float* HWY_RESTRICT output = layer + 1 < layers
-                                     ? forward.layers[layer + 1].input.data()
-                                     : forward.final_layer_output.data();
+                                     ? forward.layers[layer + 1].input.Packed()
+                                     : forward.final_layer_output.Packed();
     ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
                       num_tokens, output, inv_timescale, pool);
   }

-  ApplyRMSNorm(weights.final_norm_scale.data(),
-               forward.final_layer_output.data(), model_dim, num_tokens,
-               forward.final_norm_output.data(), pool);
+  ApplyRMSNorm(weights.final_norm_scale.Packed(),
+               forward.final_layer_output.Packed(), model_dim, num_tokens,
+               forward.final_norm_output.Packed(), pool);

   for (size_t pos = 0; pos < num_tokens; ++pos) {
     MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
-           forward.final_norm_output.data() + pos * model_dim,
-           forward.logits.data() + pos * vocab_size, pool);
+           forward.final_norm_output.Packed() + pos * model_dim,
+           forward.logits.Packed() + pos * vocab_size, pool);
   }

   if (config.final_cap > 0.0f) {
     for (size_t pos = 0; pos < num_tokens; ++pos) {
-      LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size,
-                    vocab_size);
+      LogitsSoftCap(config.final_cap,
+                    forward.logits.Packed() + pos * vocab_size, vocab_size);
     }
   }

-  hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
-                 num_tokens * vocab_size * sizeof(forward.logits.At(0)));
+  CopyMat(forward.logits, forward.probs);

   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    Softmax(forward.probs.data() + pos * vocab_size, vocab_size);
+    Softmax(forward.probs.Packed() + pos * vocab_size, vocab_size);
   }

-  return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
+  return CrossEntropyLoss(forward.probs.Packed(), prompt, context_size,
                           vocab_size, pool);
 }

@@ -17,9 +17,7 @@

 #include "backprop/activations.h"
 #include "backprop/prompt.h"
-#include "gemma/common.h"
-#include "gemma/configs.h"
-#include "util/allocator.h"
+#include "util/mat.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"

 // Compiles this file for multiple architectures via "foreach_target.h", to

@@ -180,54 +180,59 @@ void ApplyLayer(const LayerWeightsPtrs<T>& weights,
   const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
   static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));

-  RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
-           activations.pre_att_rms_out.data(), model_dim, num_tokens);
+  RMSNormT(weights.pre_attention_norm_scale.Packed(),
+           activations.input.Packed(), activations.pre_att_rms_out.Packed(),
+           model_dim, num_tokens);

-  MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
-          activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens);
+  MatMulT(weights.qkv_einsum_w.Packed(), activations.pre_att_rms_out.Packed(),
+          activations.qkv.Packed(), (heads + 2) * qkv_dim, model_dim,
+          num_tokens);

   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
+    T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
     for (size_t h = 0; h <= heads; ++h) {
       Rope(qkv + h * qkv_dim, qkv_dim, pos);
     }
   }

   for (size_t pos = 0; pos < num_tokens; ++pos) {
-    T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
+    T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
     MulByConstT(query_scale, qkv, heads * qkv_dim);
   }

-  MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens,
-                  heads, qkv_dim, seq_len);
+  MaskedAttention(activations.qkv.Packed(), activations.att.Packed(),
+                  num_tokens, heads, qkv_dim, seq_len);

-  MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len);
+  MaskedSoftmax(activations.att.Packed(), num_tokens, heads, seq_len);

-  MixByAttention(activations.qkv.data(), activations.att.data(),
-                 activations.att_out.data(), num_tokens, heads, qkv_dim,
+  MixByAttention(activations.qkv.Packed(), activations.att.Packed(),
+                 activations.att_out.Packed(), num_tokens, heads, qkv_dim,
                  seq_len);

-  MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
-                  activations.attention_out.data(), heads, model_dim, qkv_dim,
+  MultiHeadMatMul(weights.attn_vec_einsum_w.Packed(),
+                  activations.att_out.Packed(),
+                  activations.attention_out.Packed(), heads, model_dim, qkv_dim,
                   num_tokens);

-  AddFromT(activations.input.data(), activations.attention_out.data(),
+  AddFromT(activations.input.Packed(), activations.attention_out.Packed(),
           num_tokens * model_dim);

-  RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
-           activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens);
+  RMSNormT(weights.pre_ffw_norm_scale.Packed(),
+           activations.attention_out.Packed(),
+           activations.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens);

-  MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
-          activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim,
+  MatMulT(weights.gating_einsum_w.Packed(),
+          activations.bf_pre_ffw_rms_out.Packed(),
+          activations.ffw_hidden.Packed(), ff_hidden_dim * 2, model_dim,
           num_tokens);

-  GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
-            ff_hidden_dim, num_tokens);
+  GatedGelu(activations.ffw_hidden.Packed(),
+            activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens);

-  MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output,
-          model_dim, ff_hidden_dim, num_tokens);
+  MatMulT(weights.linear_w.Packed(), activations.ffw_hidden_gated.Packed(),
+          output, model_dim, ff_hidden_dim, num_tokens);

-  AddFromT(activations.attention_out.data(), output, num_tokens * model_dim);
+  AddFromT(activations.attention_out.Packed(), output, num_tokens * model_dim);
 }

 template<typename T>

@@ -258,35 +263,35 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
   const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;

   const T kEmbScaling = EmbeddingScaling(model_dim);
-  InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling,
-                 forward.layers[0].input.data(), model_dim);
+  InputEmbedding(weights.embedder_input_embedding.Packed(), tokens, kEmbScaling,
+                 forward.layers[0].input.Packed(), model_dim);

   for (size_t layer = 0; layer < layers; ++layer) {
-    T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data()
-                                   : forward.final_layer_output.data();
+    T* output = layer + 1 < layers ? forward.layers[layer + 1].input.Packed()
+                                   : forward.final_layer_output.Packed();
     ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
                output);
   }

-  RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(),
-           forward.final_norm_output.data(), model_dim, num_tokens);
+  RMSNormT(weights.final_norm_scale.Packed(),
+           forward.final_layer_output.Packed(),
+           forward.final_norm_output.Packed(), model_dim, num_tokens);

-  MatMulT(weights.embedder_input_embedding.data(),
-          forward.final_norm_output.data(), forward.logits.data(), vocab_size,
-          model_dim, num_tokens);
+  MatMulT(weights.embedder_input_embedding.Packed(),
+          forward.final_norm_output.Packed(), forward.logits.Packed(),
+          vocab_size, model_dim, num_tokens);

   for (size_t pos = 0; pos < num_tokens; ++pos) {
     if (config.final_cap > 0.0f) {
-      Softcap(config.final_cap, forward.logits.data() + pos * vocab_size,
+      Softcap(config.final_cap, forward.logits.Packed() + pos * vocab_size,
               vocab_size);
     }
   }

-  memcpy(forward.probs.data(), forward.logits.data(),
-         num_tokens * vocab_size * sizeof(forward.logits.At(0)));
-  Softmax(forward.probs.data(), vocab_size, num_tokens);
+  CopyMat(forward.logits, forward.probs);
+  Softmax(forward.probs.Packed(), vocab_size, num_tokens);

-  return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size);
+  return CrossEntropyLoss(forward.probs.Packed(), prompt, vocab_size);
 }

 } // namespace gcpp

@@ -41,11 +41,14 @@
 namespace gcpp {
 
 TEST(OptimizeTest, GradientDescent) {
-  const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
-  Allocator::Init(topology);
-  NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
-  MatMulEnv env(topology, pools);
-  hwy::ThreadPool& pool = pools.Pool();
+  gcpp::ThreadingArgs threading_args;
+  threading_args.max_packages = 1;
+  threading_args.max_clusters = 1;
+  threading_args.pin = Tristate::kFalse;
+  ThreadingContext2::SetArgs(threading_args);
+  MatMulEnv env(ThreadingContext2::Get());
+  const Allocator2& allocator = env.ctx.allocator;
+  hwy::ThreadPool& pool = env.ctx.pools.Pool();
   std::mt19937 gen(42);
 
   const ModelInfo info = {
 
@@ -64,7 +67,7 @@ TEST(OptimizeTest, GradientDescent) {
   KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
 
   RowVectorBatch<float> inv_timescale = CreateInvTimescale(
-      config.layer_configs[0].qkv_dim,
+      allocator, config.layer_configs[0].qkv_dim,
       config.layer_configs[0].post_qk == PostQKType::HalfRope);
 
   Gemma gemma(GemmaTokenizer(), info, env);
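For orientation, the new setup reads roughly as follows when written out straight. This is a minimal sketch based only on the calls visible in the hunk above (ThreadingArgs, ThreadingContext2::SetArgs/Get, MatMulEnv, env.ctx); it assumes SetArgs() must run before the first Get(), and elides the header that declares MatMulEnv.

#include "hwy/contrib/thread_pool/thread_pool.h"
#include "util/threading_context.h"

// Sketch only: mirrors the test above; assumes SetArgs() precedes the first Get().
void SetUpSingleClusterEnv() {
  gcpp::ThreadingArgs args;
  args.max_packages = 1;                   // restrict to one package
  args.max_clusters = 1;                   // and one cluster
  args.pin = gcpp::Tristate::kFalse;       // do not pin threads
  gcpp::ThreadingContext2::SetArgs(args);  // configure the process-wide context

  gcpp::MatMulEnv env(gcpp::ThreadingContext2::Get());
  const gcpp::Allocator2& allocator = env.ctx.allocator;
  hwy::ThreadPool& pool = env.ctx.pools.Pool();  // outermost pool
  (void)allocator;
  (void)pool;
}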
@@ -18,9 +18,9 @@
 #include <cmath>
 
 #include "compression/compress.h"
-#include "gemma/common.h"
 #include "gemma/weights.h"
 #include "util/allocator.h"
+#include "util/mat.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
@@ -39,11 +39,11 @@ class AdamUpdater {
 
   void operator()(const char* name, const MatPtr& grad, MatPtr& weights,
                   MatPtr& grad_m, MatPtr& grad_v) {
-    const float* HWY_RESTRICT g = grad.data<float>();
-    float* HWY_RESTRICT w = weights.data<float>();
-    float* HWY_RESTRICT m = grad_m.data<float>();
-    float* HWY_RESTRICT v = grad_v.data<float>();
-    for (size_t i = 0; i < grad.NumElements(); ++i) {
+    const float* HWY_RESTRICT g = grad.RowT<float>(0);
+    float* HWY_RESTRICT w = weights.RowT<float>(0);
+    float* HWY_RESTRICT m = grad_m.RowT<float>(0);
+    float* HWY_RESTRICT v = grad_v.RowT<float>(0);
+    for (size_t i = 0; i < grad.Extents().Area(); ++i) {
       m[i] *= beta1_;
       m[i] += cbeta1_ * g[i];
       v[i] *= beta2_;
@@ -24,21 +24,13 @@
 #include <vector>
 
 #include "gtest/gtest.h"
-#include "compression/compress.h"
 #include "gemma/configs.h"
 #include "gemma/weights.h"
+#include "util/mat.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace gcpp {
 
-template <typename T>
-void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
-  std::normal_distribution<T> dist(0.0, stddev);
-  for (size_t i = 0; i < x.NumElements(); ++i) {
-    x.At(i) = dist(gen);
-  }
-}
-
 // TODO: make a member of Layer<T>.
 template <typename T>
 void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
 
@@ -62,8 +54,12 @@ void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
 
 template <typename T, typename U>
 void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
-  for (size_t i = 0; i < x.NumElements(); ++i) {
-    c_x.At(i) = std::complex<U>(x.At(i), 0.0);
+  for (size_t r = 0; r < x.Rows(); ++r) {
+    const T* row = x.Row(r);
+    std::complex<U>* c_row = c_x.Row(r);
+    for (size_t c = 0; c < x.Cols(); ++c) {
+      c_row[c] = std::complex<U>(row[c], 0.0);
+    }
   }
 }
@@ -87,14 +83,14 @@ void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
   }
 }
 
-// Somewhat duplicates ModelWeightsStorage, but that has neither double nor
+// Somewhat duplicates WeightsOwner, but that has neither double nor
 // complex types allowed and it would cause code bloat to add them there.
 template <typename T>
 class WeightsWrapper {
  public:
   explicit WeightsWrapper(const ModelConfig& config)
       : pool_(0), weights_(config) {
-    weights_.Allocate(data_, pool_);
+    weights_.Allocate(owners_, pool_);
   }
 
   const ModelWeightsPtrs<T>& get() const { return weights_; }
 
@@ -106,7 +102,7 @@ class WeightsWrapper {
 
  private:
   hwy::ThreadPool pool_;
-  std::vector<MatStorage> data_;
+  std::vector<MatOwner> owners_;
   ModelWeightsPtrs<T> weights_;
 };
@@ -116,13 +112,18 @@ void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
   double sum0 = 0;
   double sum1 = 0;
   double sum01 = 0;
-  for (size_t i = 0; i < actual.NumElements(); ++i) {
-    sum0 += actual.At(i) * actual.At(i);
-    sum1 += expected.At(i) * expected.At(i);
-    sum01 += actual.At(i) * expected.At(i);
-    ASSERT_NEAR(actual.At(i), expected.At(i),
-                std::max(max_abs_err, std::abs(expected.At(i)) * max_rel_err))
-        << "line: " << line << " dim=" << expected.NumElements() << " i=" << i;
+  for (size_t r = 0; r < actual.Rows(); ++r) {
+    const T* actual_row = actual.Row(r);
+    const U* expected_row = expected.Row(r);
+    for (size_t c = 0; c < actual.Cols(); ++c) {
+      sum0 += actual_row[c] * actual_row[c];
+      sum1 += expected_row[c] * expected_row[c];
+      sum01 += actual_row[c] * expected_row[c];
+      ASSERT_NEAR(
+          actual_row[c], expected_row[c],
+          std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err))
+          << "line: " << line << " r " << r << " c " << c;
+    }
   }
   if (sum0 > 1e-40) {
     double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
@@ -148,15 +149,19 @@ void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
 template <typename FUNC, typename T, typename U>
 void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
                   FUNC func, U step, T max_abs_err, T max_rel_err, int line) {
-  MatStorageT<T> exp_grad("exp_grad", x.Rows(), x.Cols());
+  MatStorageT<T> exp_grad = MakePacked<T>("exp_grad", x.Rows(), x.Cols());
   const U inv_step = 1.0 / step;
-  for (size_t i = 0; i < x.NumElements(); ++i) {
-    const U x0 = std::real(x.At(i));
-    const std::complex<U> x1 = std::complex<U>(x0, step);
-    x.At(i) = x1;
-    const std::complex<U> f1 = func();
-    exp_grad.At(i) = std::imag(f1) * inv_step;
-    x.At(i) = x0;
+  for (size_t r = 0; r < x.Rows(); ++r) {
+    std::complex<U>* x_row = x.Row(r);
+    T* exp_row = exp_grad.Row(r);
+    for (size_t c = 0; c < x.Cols(); ++c) {
+      const U x0 = std::real(x_row[c]);
+      const std::complex<U> x1 = std::complex<U>(x0, step);
+      x_row[c] = x1;
+      const std::complex<U> f1 = func();
+      exp_row[c] = std::imag(f1) * inv_step;
+      x_row[c] = x0;
+    }
   }
   TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
 }
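Why these loops switch from flat element indices to rows: once a matrix may be allocated with row padding, Stride() can exceed Cols(), so the flat index i = r * Cols() + c no longer maps to memory. Iterating via row pointers stays correct for both packed and padded layouts. A minimal sketch, using only the Rows()/Cols()/Row() accessors seen in this diff; the helper name is illustrative, not part of the codebase.

#include <stddef.h>

#include "util/mat.h"  // MatPtrT; header name as introduced by this commit

// Visits every real element, row by row, regardless of the row stride.
template <typename T, class Func>
void ForEachElement(const gcpp::MatPtrT<T>& m, const Func& func) {
  for (size_t r = 0; r < m.Rows(); ++r) {
    const T* row = m.Row(r);  // start of row r, stride-aware
    for (size_t c = 0; c < m.Cols(); ++c) {
      func(r, c, row[c]);  // never touches padding between rows
    }
  }
}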
@@ -146,6 +146,7 @@ cc_library(
         ":distortion",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
+        "@highway//:thread_pool",
     ],
 )
 
@@ -209,6 +210,7 @@ cc_library(
         "//:allocator",
         "//:basics",
         "//:common",
+        "//:mat",
         "@highway//:hwy",
         "@highway//:nanobenchmark",
         "@highway//:stats",
@@ -252,22 +254,6 @@ cc_library(
     ],
 )
 
-cc_binary(
-    name = "compress_weights",
-    srcs = ["compress_weights.cc"],
-    deps = [
-        ":compress",
-        ":io",
-        "//:allocator",
-        "//:args",
-        "//:common",
-        "//:tokenizer",
-        "//:weights",
-        "@highway//:hwy",
-        "@highway//:thread_pool",
-    ],
-)
-
 cc_binary(
     name = "blob_compare",
     srcs = ["blob_compare.cc"],
 
@@ -277,9 +263,11 @@ cc_binary(
         "//:allocator",
         "//:basics",
         "//:threading",
+        "//:threading_context",
         "@highway//:hwy",
         "@highway//:hwy_test_util",
         "@highway//:nanobenchmark",
+        "@highway//:thread_pool",
     ],
 )
@@ -287,7 +275,6 @@ cc_binary(
     name = "migrate_weights",
     srcs = ["migrate_weights.cc"],
     deps = [
-        "//:app",
        "//:args",
        "//:benchmark_helper",
        "//:gemma_lib",
@@ -25,8 +25,10 @@
 #include "util/allocator.h"
 #include "util/basics.h"  // IndexRange
 #include "util/threading.h"
+#include "util/threading_context.h"
 #include "hwy/aligned_allocator.h"  // Span
 #include "hwy/base.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/timer.h"
 
 namespace gcpp {
 
@@ -202,15 +204,13 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
   if (!CompareKeys(reader1, reader2)) return;
 
   // Single allocation, avoid initializing the memory.
-  BoundedTopology topology;
-  Allocator::Init(topology);
-  NestedPools pools(topology);
   const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
   BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
   size_t pos = 0;
   BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos);
   BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos);
 
+  NestedPools& pools = ThreadingContext2::Get().pools;
   ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools);
 
   CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);
@@ -29,6 +29,7 @@
 #include "compression/compress.h"  // IWYU pragma: export
 #include "compression/distortion.h"
 #include "gemma/configs.h"
+#include "util/mat.h"
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 
@@ -379,7 +380,7 @@ struct CompressTraits<SfpStream> {
   using Packed = SfpStream;
 
   // Callers are responsible for scaling `raw` such that its magnitudes do not
-  // exceed `SfpStream::kMax`. See CompressedArray::scale().
+  // exceed `SfpStream::kMax`. See CompressedArray::Scale().
   template <class DF, HWY_IF_F32_D(DF)>
   static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
                                   size_t num, CompressPerThread& tls,
 
@@ -522,8 +523,7 @@ HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num,
                                CompressWorkingSet& work,
                                MatStorageT<Packed>& compressed,
                                hwy::ThreadPool& pool) {
-  Compress(raw, num, work,
-           MakeSpan(compressed.data(), compressed.NumElements()),
+  Compress(raw, num, work, compressed.Span(),
            /*packed_ofs=*/0, pool);
 }
 
@@ -717,11 +717,9 @@ class Compressor {
   template <typename Packed>
   void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
                   const float* HWY_RESTRICT weights) {
-    size_t num_weights = compressed->NumElements();
-    if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
-      return;
-    size_t num_compressed = compressed->NumElements();
-    PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
+    size_t num_weights = compressed->Extents().Area();
+    if (num_weights == 0 || weights == nullptr || !compressed->HasPtr()) return;
+    PackedSpan<Packed> packed = compressed->Span();
     fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
             num_weights / (1000 * 1000));
     Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,
@@ -17,6 +17,6 @@
 
 namespace gcpp {
 
-MatPtr::~MatPtr() {}
+// TODO: move ScaleWeights here.
 
 }  // namespace gcpp
@@ -41,7 +41,8 @@
 // IWYU pragma: end_exports
 #include "gemma/configs.h"
 #include "util/allocator.h"
-#include "hwy/per_target.h"
+#include "util/mat.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
 #if COMPRESS_STATS
 #include "compression/distortion.h"
 #include "hwy/stats.h"
@@ -49,322 +50,6 @@
 
 namespace gcpp {
 
-// Base class for rank-1 or 2 tensors (vector or matrix).
-// Supports both dynamic and compile-time sizing.
-// Holds metadata and a non-owning pointer to the data, owned by the derived
-// MatStorageT class.
-// This class also provides easy conversion from/to a table of contents for a
-// BlobStore file, and a templated (compile-time) accessor for a 2-d array of
-// fixed inner dimension and type.
-// It is designed to be put in a vector, and has default copy and operator=, so
-// it is easy to read/write a blob_store file.
-class MatPtr : public IFields {
- public:
-  // Full constructor for dynamic sizing.
-  MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
-         size_t cols)
-      : name_(name),
-        type_(type),
-        element_size_(element_size),
-        num_elements_(rows * cols),
-        rows_(rows),
-        cols_(cols),
-        ptr_(nullptr) {
-    stride_ = cols;
-  }
-  // Default is to leave all fields default-initialized.
-  MatPtr() = default;
-  virtual ~MatPtr();
-
-  // Compatibility interface for CompressedArray.
-  // TODO: remove.
-  template <typename T>
-  T* data() {
-    return HWY_RCAST_ALIGNED(T*, ptr_);
-  }
-  template <typename T>
-  const T* data() const {
-    return HWY_RCAST_ALIGNED(const T*, ptr_);
-  }
-
-  const void* Ptr() const { return ptr_; }
-  void* Ptr() { return ptr_; }
-  // Sets the pointer from another MatPtr.
-  void SetPtr(const MatPtr& other) { ptr_ = other.ptr_; }
-
-  // Copying allowed as the metadata is small.
-  MatPtr(const MatPtr& other) = default;
-  MatPtr& operator=(const MatPtr& other) = default;
-
-  // Returns the name of the blob.
-  const char* Name() const override { return name_.c_str(); }
-  void SetName(const std::string& name) { name_ = name; }
-
-  // Returns the type of the blob.
-  Type GetType() const { return type_; }
-
-  // Returns the size of each element in bytes.
-  size_t ElementSize() const { return element_size_; }
-
-  // Returns the number of elements in the array.
-  size_t NumElements() const { return num_elements_; }
-
-  // Returns the number of bytes in the array.
-  size_t SizeBytes() const {
-    if (this->GetType() == TypeEnum<NuqStream>()) {
-      return NuqStream::PackedEnd(num_elements_);
-    }
-    return num_elements_ * element_size_;
-  }
-
-  // Returns the number of rows in the 2-d array (outer dimension).
-  size_t Rows() const { return rows_; }
-
-  // Returns the number of columns in the 2-d array (inner dimension).
-  size_t Cols() const { return cols_; }
-
-  Extents2D Extents() const { return Extents2D(rows_, cols_); }
-
-  // Currently same as cols, but may differ in the future. This is the offset by
-  // which to advance pointers to the next row.
-  size_t Stride() const { return stride_; }
-
-  // Decoded elements should be multiplied by this to restore their original
-  // range. This is required because SfpStream can only encode a limited range
-  // of magnitudes.
-  float scale() const { return scale_; }
-  void set_scale(float scale) { scale_ = scale; }
-
-  std::string LayerName(int layer) const {
-    std::string name = name_ + std::to_string(layer);
-    HWY_ASSERT(name.size() <= sizeof(hwy::uint128_t));
-    return name;
-  }
-
-  // Sets all data to zero.
-  void ZeroInit() {
-    if (ptr_ == nullptr)
-      HWY_ABORT("ptr_ is null on tensor %s\n", name_.c_str());
-    hwy::ZeroBytes(ptr_, SizeBytes());
-  }
-
-  void VisitFields(IFieldsVisitor& visitor) override {
-    visitor(name_);
-    visitor(type_);
-    visitor(element_size_);
-    visitor(num_elements_);
-    visitor(rows_);
-    visitor(cols_);
-    visitor(scale_);
-    visitor(stride_);
-  }
-
-  // Calls func on the upcasted type. Since MatPtr by design is not templated,
-  // here we provide a way to get to the derived type, provided that `Type()`
-  // is one of the strings returned by `TypeName()`.
-  template <class FuncT, typename... TArgs>
-  decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args);
-
- protected:
-  // Arbitrary name for the array of preferably <= 16 characters.
-  std::string name_;
-  // Should be the result of TypeEnum<T> for CallUpcasted() to work.
-  Type type_;
-  // sizeof(T)
-  uint32_t element_size_ = 0;
-  // Number of elements in the array.
-  uint32_t num_elements_ = 0;  // In element_size units.
-  // Number of rows in the 2-d array (outer dimension).
-  uint32_t rows_ = 0;
-  // Number of columns in the 2-d array (inner dimension).
-  uint32_t cols_ = 0;
-  // Scaling to apply to each element.
-  float scale_ = 1.0f;
-  // Aligned data array. This is always a borrowed pointer. It should never be
-  // freed. The underlying memory is owned by a subclass or some external class
-  // and must outlive this object.
-  void* ptr_ = nullptr;
-
-  uint32_t stride_;
-};
-
-// MatPtrT adds a single template argument to MatPtr for an explicit type.
-// Use this class as a function argument where the type needs to be known.
-// Use MatPtr where the type does not need to be known.
-template <typename MatT>
-class MatPtrT : public MatPtr {
- public:
-  // Full constructor for dynamic sizing.
-  MatPtrT(const std::string& name, size_t rows, size_t cols)
-      : MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
-  // Construction from TensorIndex entry to remove duplication of sizes.
-  MatPtrT(const std::string& name, const TensorIndex& tensor_index)
-      : MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
-  MatPtrT(const std::string& name, const TensorInfo* tensor)
-      : MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
-    if (tensor == nullptr) {
-      cols_ = 0;
-      rows_ = 0;
-    } else {
-      cols_ = tensor->shape.back();
-      rows_ = 1;
-      if (tensor->cols_take_extra_dims) {
-        // The columns eat the extra dimensions.
-        rows_ = tensor->shape[0];
-        for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
-          cols_ *= tensor->shape[i];
-        }
-      } else {
-        // The rows eat the extra dimensions.
-        for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
-          rows_ *= tensor->shape[i];
-        }
-      }
-    }
-    stride_ = cols_;
-    num_elements_ = rows_ * cols_;
-  }
-
-  // Copying allowed as the metadata is small.
-  MatPtrT(const MatPtr& other) : MatPtr(other) {}
-  MatPtrT& operator=(const MatPtr& other) {
-    MatPtr::operator=(other);
-    return *this;
-  }
-  MatPtrT(const MatPtrT& other) = default;
-  MatPtrT& operator=(const MatPtrT& other) = default;
-
-  std::string CacheName(int layer = -1, char separator = ' ',
-                        int index = -1) const {
-    // Already used/retired: s, S, n, 1
-    const char prefix = hwy::IsSame<MatT, float>()       ? 'F'
-                        : hwy::IsSame<MatT, BF16>()      ? 'B'
-                        : hwy::IsSame<MatT, SfpStream>() ? '$'
-                        : hwy::IsSame<MatT, NuqStream>() ? '2'
-                                                         : '?';
-    std::string name = std::string(1, prefix) + name_;
-    if (layer >= 0 || index >= 0) {
-      name += '_';
-      if (layer >= 0) name += std::to_string(layer);
-      if (index >= 0) {
-        name += separator + std::to_string(index);
-      }
-    }
-    return name;
-  }
-
-  // Sets the number of elements in the array. For use when the number of
-  // elements is != rows * cols ONLY.
-  void SetNumElements(size_t num_elements) {
-    num_elements_ = CompressedArrayElements<MatT>(num_elements);
-  }
-
-  // 2-d Accessor for a specific type but with a dynamic inner dimension.
-  template <typename T = MatT>
-  const T& At(size_t row, size_t col) const {
-    size_t index = row * cols_ + col;
-    HWY_DASSERT(index < num_elements_);
-    return HWY_RCAST_ALIGNED(const T*, ptr_)[index];
-  }
-
-  // 1-d Accessor for a specific type.
-  // TODO: replace this with a Foreach(), or at least a ForEachRow().
-  const MatT& At(size_t index) const {
-    HWY_DASSERT(index < num_elements_);
-    return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index];
-  }
-  MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; }
-
-  // Compatibility interface for CompressedArray.
-  // TODO: remove
-  template <typename T = MatT>
-  T* data() {
-    return HWY_RCAST_ALIGNED(T*, ptr_);
-  }
-  template <typename T = MatT>
-  const T* data() const {
-    return HWY_RCAST_ALIGNED(const T*, ptr_);
-  }
-  // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
-  // calling it means "I am sure the scale is 1 and therefore ignore the scale".
-  // A scale of 0 indicates that the scale has likely never been set, so is
-  // "implicitly 1".
-  const MatT* data_scale1() const {
-    HWY_ASSERT(scale() == 1.f);
-    return HWY_RCAST_ALIGNED(const MatT*, ptr_);
-  }
-};
-
-template <class FuncT, typename... TArgs>
-decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
-  if (type_ == TypeEnum<float>()) {
-    return func(dynamic_cast<MatPtrT<float>*>(this),
-                std::forward<TArgs>(args)...);
-  } else if (type_ == TypeEnum<BF16>()) {
-    return func(dynamic_cast<MatPtrT<BF16>*>(this),
-                std::forward<TArgs>(args)...);
-  } else if (type_ == TypeEnum<SfpStream>()) {
-    return func(dynamic_cast<MatPtrT<SfpStream>*>(this),
-                std::forward<TArgs>(args)...);
-  } else if (type_ == TypeEnum<NuqStream>()) {
-    return func(dynamic_cast<MatPtrT<NuqStream>*>(this),
-                std::forward<TArgs>(args)...);
-  } else {
-    HWY_ABORT("Type %d unknown.", type_);
-  }
-}
-
-// MatStorageT adds the actual data storage to MatPtrT.
-// TODO: use Extents2D instead of rows and cols.
-template <typename MatT>
-class MatStorageT : public MatPtrT<MatT> {
- public:
-  // Full constructor for dynamic sizing.
-  MatStorageT(const std::string& name, size_t rows, size_t cols)
-      : MatPtrT<MatT>(name, rows, cols) {
-    Allocate();
-  }
-  // Can copy the metadata, from a MatPtr, and allocate later.
-  MatStorageT(const MatPtr& other) : MatPtrT<MatT>(other) {}
-  ~MatStorageT() = default;
-
-  // Move-only because this contains a unique_ptr.
-  MatStorageT(const MatStorageT& other) = delete;
-  MatStorageT& operator=(const MatStorageT& other) = delete;
-  MatStorageT(MatStorageT&& other) = default;
-  MatStorageT& operator=(MatStorageT&& other) = default;
-
-  // Allocate the memory and copy the pointer to the MatPtr.
-  // num_elements is in elements. In the default (zero) case, it is computed
-  // from the current num_elements_ which was set by the constructor from the
-  // rows and cols.
-  void Allocate(size_t num_elements = 0) {
-    if (num_elements == 0) {
-      num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT));
-    } else {
-      this->num_elements_ = num_elements;
-    }
-    // Pad to allow overrunning the last row by 2 BF16 vectors, hence at most
-    // `2 * VectorBytes / sizeof(BF16)` elements of MatT.
-    const size_t padding = hwy::VectorBytes();
-    data_ = Allocator::Alloc<MatT>(num_elements + padding);
-    hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT));
-    this->ptr_ = data_.get();
-  }
-
-  // Zeros the content.
-  void ZeroInit() {
-    HWY_ASSERT(data_ != nullptr);
-    hwy::ZeroBytes(data_.get(), this->SizeBytes());
-  }
-
- private:
-  AlignedPtr<MatT> data_;
-};
-
-// MatStorage allows heterogeneous tensors to be stored in a single vector.
-using MatStorage = MatStorageT<hwy::uint128_t>;
-
 // Table of contents for a blob store file. Full metadata, but not actual data.
 class BlobToc {
  public:
@@ -389,7 +74,7 @@ class BlobToc {
       blob.Read(hwy::Span<const uint32_t>(toc), consumed);
       prev_consumed = consumed;
       consumed = result.pos;
-      if (blob.NumElements() > 0) {
+      if (!blob.IsEmpty()) {
         AddToToc(blob);
       }
     }
 
@@ -503,10 +188,11 @@ class WriteToBlobStore {
   explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
 
   template <typename Packed>
-  void operator()(MatPtrT<Packed>* compressed, const char* decorated_name) {
-    if (compressed->Ptr() == nullptr) return;
-    writer_.Add(MakeKey(decorated_name), compressed->Ptr(),
-                compressed->SizeBytes());
+  void operator()(MatPtrT<Packed>* compressed,
+                  const char* decorated_name) const {
+    if (!compressed->HasPtr()) return;
+    writer_.Add(MakeKey(decorated_name), compressed->Packed(),
+                compressed->PackedBytes());
     MatPtr renamed_tensor(*compressed);
     renamed_tensor.SetName(decorated_name);
     renamed_tensor.AppendTo(toc_);
 
@@ -519,9 +205,8 @@ class WriteToBlobStore {
 
   void AddScales(const float* scales, size_t len) {
     if (len) {
-      MatPtrT<float> scales_ptr("scales", 0, 1);
-      writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
-                  len * sizeof(scales[0]));
+      MatPtrT<float> scales_ptr("scales", Extents2D(0, 1));
+      writer_.Add(MakeKey(scales_ptr.Name()), scales, len * sizeof(scales[0]));
     }
   }
 
@@ -554,9 +239,9 @@ class WriteToBlobStore {
   hwy::ThreadPool& pool_;
 
  private:
-  std::vector<uint32_t> toc_;
-  BlobWriter writer_;
-  std::vector<uint32_t> config_buffer_;
+  mutable std::vector<uint32_t> toc_;
+  mutable BlobWriter writer_;
+  mutable std::vector<uint32_t> config_buffer_;
 };
 
 // Functor called for each tensor, which loads them and their scaling factors
 
@@ -613,6 +298,7 @@ class ReadFromBlobStore {
   // Called for each tensor, enqueues read requests.
   void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
     if (file_toc_.Empty() || file_toc_.Contains(name)) {
+      HWY_ASSERT(tensors[0]);
      model_toc_.push_back(tensors[0]);
      file_keys_.push_back(name);
    }
 
@@ -622,15 +308,15 @@ class ReadFromBlobStore {
     for (size_t i = 0; i < len; ++i) {
       scales[i] = 1.0f;
     }
-    MatPtrT<float> scales_ptr("scales", 0, 1);
-    auto key = MakeKey(scales_ptr.CacheName().c_str());
+    MatPtrT<float> scales_ptr("scales", Extents2D(0, 1));
+    auto key = MakeKey(scales_ptr.Name());
     if (reader_.BlobSize(key) == 0) return 0;
     return reader_.Enqueue(key, scales, len * sizeof(scales[0]));
   }
 
   // Returns whether all tensors are successfully loaded from cache.
   BlobError ReadAll(hwy::ThreadPool& pool,
-                    std::vector<MatStorage>& model_memory) {
+                    std::vector<MatOwner>& model_memory) {
     // reader_ invalid or any Enqueue failed
     if (err_ != 0) return err_;
     // Setup the model_memory.
 
@@ -650,26 +336,27 @@ class ReadFromBlobStore {
        }
        std::string name = blob->Name();
        *blob = *toc_blob;
-        blob->SetName(name);
+        blob->SetName(name.c_str());
      }
-      model_memory.emplace_back(*blob);
-      model_memory.back().SetName(file_key);
+      model_memory.push_back(MatOwner());
    }
    // Allocate in parallel using the pool.
    pool.Run(0, model_memory.size(),
             [this, &model_memory](uint64_t task, size_t /*thread*/) {
-               model_memory[task].Allocate();
-               model_toc_[task]->SetPtr(model_memory[task]);
+               model_memory[task].AllocateFor(*model_toc_[task],
+                                              MatPadding::kPacked);
             });
    // Enqueue the read requests.
-    for (auto& blob : model_memory) {
-      err_ =
-          reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes());
+    for (size_t b = 0; b < model_toc_.size(); ++b) {
+      err_ = reader_.Enqueue(MakeKey(file_keys_[b].c_str()),
+                             model_toc_[b]->RowT<uint8_t>(0),
+                             model_toc_[b]->PackedBytes());
      if (err_ != 0) {
-        fprintf(stderr,
-                "Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
-                blob.Name(), err_, blob.Rows(), blob.Cols(),
-                blob.ElementSize());
+        fprintf(
+            stderr,
+            "Failed to read blob %s (error %d) of size %zu x %zu, type %d\n",
+            file_keys_[b].c_str(), err_, model_toc_[b]->Rows(),
+            model_toc_[b]->Cols(), static_cast<int>(model_toc_[b]->GetType()));
        return err_;
      }
    }
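The ReadAll() change above illustrates the new ownership split: MatPtr/MatPtrT keep only metadata and a borrowed pointer, while MatOwner holds the allocation. A rough sketch of that pattern follows; it assumes AllocateFor() both allocates storage and points the tensor at it (which the read loop above relies on), and the tensor name and helper are hypothetical.

#include <vector>

#include "util/mat.h"  // MatPtrT, MatOwner, MatPadding, Extents2D (per this commit)

// Sketch: create a float tensor whose memory is owned by `owners`.
gcpp::MatPtrT<float> MakeOwnedTensor(std::vector<gcpp::MatOwner>& owners,
                                     size_t rows, size_t cols) {
  gcpp::MatPtrT<float> tensor("w", gcpp::Extents2D(rows, cols));
  owners.push_back(gcpp::MatOwner());
  owners.back().AllocateFor(tensor, gcpp::MatPadding::kPacked);
  return tensor;  // remains valid only while `owners` outlives it
}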
@ -1,286 +0,0 @@
|
||||||
// Copyright 2024 Google LLC
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
// you may not use this file except in compliance with the License.
|
|
||||||
// You may obtain a copy of the License at
|
|
||||||
//
|
|
||||||
// https://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
//
|
|
||||||
// Unless required by applicable law or agreed to in writing, software
|
|
||||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
// See the License for the specific language governing permissions and
|
|
||||||
// limitations under the License.
|
|
||||||
|
|
||||||
// Command line tool to create compressed weights.
|
|
||||||
|
|
||||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
|
||||||
// which we pass the filename via macro 'argument'.
|
|
||||||
#undef HWY_TARGET_INCLUDE
|
|
||||||
#define HWY_TARGET_INCLUDE \
|
|
||||||
"compression/compress_weights.cc" // NOLINT
|
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
|
||||||
#include "hwy/highway.h"
|
|
||||||
// After highway.h
|
|
||||||
#include "compression/compress-inl.h"
|
|
||||||
#include "gemma/configs.h"
|
|
||||||
#include "gemma/tokenizer.h"
|
|
||||||
|
|
||||||
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
|
|
||||||
#define GEMMA_COMPRESS_WEIGHTS_ONCE
|
|
||||||
|
|
||||||
#include <stddef.h>
|
|
||||||
#include <stdio.h>
|
|
||||||
|
|
||||||
#include <algorithm> // std::clamp
|
|
||||||
#include <cstdlib>
|
|
||||||
#include <iostream>
|
|
||||||
#include <string>
|
|
||||||
#include <thread> // NOLINT
|
|
||||||
#include <vector>
|
|
||||||
|
|
||||||
#include "compression/compress.h"
|
|
||||||
#include "compression/io.h" // Path
|
|
||||||
#include "compression/shared.h" // PromptWrapping
|
|
||||||
#include "gemma/common.h" // Model
|
|
||||||
#include "gemma/weights.h"
|
|
||||||
#include "util/allocator.h"
|
|
||||||
#include "util/args.h"
|
|
||||||
#include "hwy/base.h"
|
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
struct Args : public ArgsBase<Args> {
|
|
||||||
static constexpr size_t kDefaultNumThreads = ~size_t{0};
|
|
||||||
|
|
||||||
void ChooseNumThreads() {
|
|
||||||
if (num_threads == kDefaultNumThreads) {
|
|
||||||
// This is a rough heuristic, replace with something better in the future.
|
|
||||||
num_threads = static_cast<size_t>(std::clamp(
|
|
||||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public:
|
|
||||||
Args(int argc, char* argv[]) {
|
|
||||||
InitAndParse(argc, argv);
|
|
||||||
ChooseNumThreads();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns error string or nullptr if OK.
|
|
||||||
const char* Validate() {
|
|
||||||
if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
|
|
||||||
prompt_wrapping_)) {
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
if (const char* err = ParseType(weight_type_str, weight_type_)) {
|
|
||||||
return err;
|
|
||||||
}
|
|
||||||
if (weights.path.empty()) {
|
|
||||||
return "Missing --weights flag, a file for the uncompressed model.";
|
|
||||||
}
|
|
||||||
if (compressed_weights.path.empty()) {
|
|
||||||
return "Missing --compressed_weights flag, a file for the compressed "
|
|
||||||
"model.";
|
|
||||||
}
|
|
||||||
if (!weights.Exists()) {
|
|
||||||
return "Can't open file specified with --weights flag.";
|
|
||||||
}
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
Path weights; // uncompressed weights file location
|
|
||||||
Path compressed_weights; // compressed weights file location
|
|
||||||
std::string model_type_str;
|
|
||||||
std::string weight_type_str;
|
|
||||||
size_t num_threads;
|
|
||||||
// If non-empty, whether to include the config and TOC in the output file, as
|
|
||||||
// well as the tokenizer.
|
|
||||||
Path tokenizer;
|
|
||||||
|
|
||||||
template <class Visitor>
|
|
||||||
void ForEach(const Visitor& visitor) {
|
|
||||||
visitor(weights, "weights", Path(),
|
|
||||||
"Path to model weights (.bin) file.\n"
|
|
||||||
" Required argument.");
|
|
||||||
visitor(model_type_str, "model", std::string(),
|
|
||||||
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
|
||||||
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
|
||||||
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
|
|
||||||
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
|
|
||||||
"gr2b-pt = griffin 2B parameters, pretrained\n "
|
|
||||||
" Required argument.");
|
|
||||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
|
||||||
"Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
|
|
||||||
" Required argument.");
|
|
||||||
visitor(compressed_weights, "compressed_weights", Path(),
|
|
||||||
"Path name where compressed weights (.sbs) file will be written.\n"
|
|
||||||
" Required argument.");
|
|
||||||
visitor(num_threads, "num_threads",
|
|
||||||
kDefaultNumThreads, // see ChooseNumThreads
|
|
||||||
"Number of threads to use.\n Default = Estimate of the "
|
|
||||||
"number of supported concurrent threads.",
|
|
||||||
2);
|
|
||||||
visitor(tokenizer, "tokenizer", Path(),
|
|
||||||
"Path to tokenizer file. If given, the config and TOC are also "
|
|
||||||
"added to the output file.");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Uninitialized before Validate, must call after that.
|
|
||||||
gcpp::Model ModelType() const { return model_type_; }
|
|
||||||
gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
|
|
||||||
gcpp::Type WeightType() const { return weight_type_; }
|
|
||||||
|
|
||||||
private:
|
|
||||||
Model model_type_;
|
|
||||||
PromptWrapping prompt_wrapping_;
|
|
||||||
Type weight_type_;
|
|
||||||
};
|
|
||||||
|
|
||||||
void ShowHelp(gcpp::Args& args) {
|
|
||||||
std::cerr
|
|
||||||
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
|
|
||||||
" --model <model type> --compressed_weights <output path>\n";
|
|
||||||
std::cerr << "\n*Arguments*\n\n";
|
|
||||||
args.Help();
|
|
||||||
std::cerr << "\n";
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gcpp
|
|
||||||
#endif // GEMMA_COMPRESS_WEIGHTS_ONCE
|
|
||||||
|
|
||||||
// SIMD code, compiled once per target.
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
|
||||||
namespace gcpp {
|
|
||||||
namespace HWY_NAMESPACE {
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void CompressWeights(const Path& weights_path,
|
|
||||||
const Path& compressed_weights_path, Model model_type,
|
|
||||||
Type weight_type, PromptWrapping wrapping,
|
|
||||||
const Path& tokenizer_path, hwy::ThreadPool& pool) {
|
|
||||||
if (!weights_path.Exists()) {
|
|
||||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
|
||||||
weights_path.path.c_str());
|
|
||||||
}
|
|
||||||
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
|
|
||||||
compressed_weights_path.path.c_str());
|
|
||||||
ModelConfig config = ConfigFromModel(model_type);
|
|
||||||
config.weight = weight_type;
|
|
||||||
config.wrapping = wrapping;
|
|
||||||
std::vector<MatStorage> model_storage;
|
|
||||||
ModelWeightsPtrs<T> c_weights(config);
|
|
||||||
c_weights.Allocate(model_storage, pool);
|
|
||||||
ModelWeightsPtrs<float> uc_weights(config);
|
|
||||||
uc_weights.Allocate(model_storage, pool);
|
|
||||||
// Get uncompressed weights, compress, and store.
|
|
||||||
FILE* fptr = fopen(weights_path.path.c_str(), "rb");
|
|
||||||
if (fptr == nullptr) {
|
|
||||||
HWY_ABORT("Failed to open model file %s - does it exist?",
|
|
||||||
weights_path.path.c_str());
|
|
||||||
}
|
|
||||||
bool ok = true;
|
|
||||||
uint64_t total_size = 0;
|
|
||||||
ModelWeightsPtrs<float>::ForEachTensor(
|
|
||||||
{&uc_weights}, ForEachType::kLoadNoToc,
|
|
||||||
[&](const char* name, hwy::Span<MatPtr*> tensors) {
|
|
||||||
fprintf(stderr, "Loading Parameters (size %zu): %s\n",
|
|
||||||
tensors[0]->SizeBytes(), name);
|
|
||||||
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
|
|
||||||
total_size += tensors[0]->SizeBytes();
|
|
||||||
});
|
|
||||||
if (!tokenizer_path.path.empty()) {
|
|
||||||
uc_weights.AllocAndCopyWithTranspose(pool, model_storage);
|
|
||||||
}
|
|
||||||
const bool scale_for_compression = config.num_tensor_scales > 0;
|
|
||||||
std::vector<float> scales;
|
|
||||||
if (scale_for_compression) {
|
|
||||||
uc_weights.GetOrApplyScales(scales);
|
|
||||||
}
|
|
||||||
Compressor compressor(pool);
|
|
||||||
ModelWeightsPtrs<T>::ForEachTensor(
|
|
||||||
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
|
|
||||||
tokenizer_path.path.empty() ? ForEachType::kLoadNoToc
|
|
||||||
: ForEachType::kLoadWithToc,
|
|
||||||
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
|
|
||||||
tensors[1]->CallUpcasted(
|
|
||||||
compressor, name,
|
|
||||||
reinterpret_cast<const float*>(tensors[0]->Ptr()));
|
|
||||||
});
|
|
||||||
if (!tokenizer_path.path.empty()) {
|
|
||||||
std::string tokenizer_proto = ReadFileToString(tokenizer_path);
|
|
||||||
compressor.AddTokenizer(tokenizer_proto);
|
|
||||||
} else {
|
|
||||||
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
|
|
||||||
}
|
|
||||||
compressor.WriteAll(compressed_weights_path,
|
|
||||||
tokenizer_path.path.empty() ? nullptr : &config);
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace HWY_NAMESPACE
|
|
||||||
} // namespace gcpp
|
|
||||||
HWY_AFTER_NAMESPACE();
|
|
||||||
|
|
||||||
#if HWY_ONCE
|
|
||||||
namespace gcpp {
|
|
||||||
|
|
||||||
void Run(Args& args) {
|
|
||||||
hwy::ThreadPool pool(args.num_threads);
|
|
||||||
if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
|
|
||||||
HWY_ABORT("PaliGemma is not supported in compress_weights.");
|
|
||||||
}
|
|
||||||
const Model model_type = args.ModelType();
|
|
||||||
const Type weight_type = args.WeightType();
|
|
||||||
switch (weight_type) {
|
|
||||||
case Type::kF32:
|
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
|
|
||||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
|
||||||
args.PromptWrappingType(), args.tokenizer, pool);
|
|
||||||
break;
|
|
||||||
case Type::kBF16:
|
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
|
|
||||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
|
||||||
args.PromptWrappingType(), args.tokenizer, pool);
|
|
||||||
break;
|
|
||||||
case Type::kSFP:
|
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
|
|
||||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
|
||||||
args.PromptWrappingType(), args.tokenizer, pool);
|
|
||||||
break;
|
|
||||||
case Type::kNUQ:
|
|
||||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
|
|
||||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
|
||||||
args.PromptWrappingType(), args.tokenizer, pool);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace gcpp
|
|
||||||
|
|
||||||
int main(int argc, char** argv) {
|
|
||||||
gcpp::Args args(argc, argv);
|
|
||||||
|
|
||||||
if (gcpp::HasHelp(argc, argv)) {
|
|
||||||
gcpp::ShowHelp(args);
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (const char* error = args.Validate()) {
|
|
||||||
gcpp::ShowHelp(args);
|
|
||||||
HWY_ABORT("\nInvalid args: %s", error);
|
|
||||||
}
|
|
||||||
|
|
||||||
gcpp::Run(args);
|
|
||||||
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
#endif // HWY_ONCE
|
|
||||||
|
|
@@ -57,6 +57,6 @@ int main(int argc, char** argv) {
   }
   gcpp::GemmaEnv env(argc, argv);
   hwy::ThreadPool pool(0);
-  env.GetModel()->Save(args.output_weights, pool);
+  env.GetGemma()->Save(args.output_weights, pool);
   return 0;
 }
@@ -16,7 +16,9 @@ cc_library(
     deps = [
         "@abseil-cpp//absl/types:span",
         "//:common",
+        "//:mat",
         "//:tokenizer",
+        "//:weights",
         "//compression:compress",
         "//compression:io",
         "@highway//:hwy",
 
@@ -30,7 +32,6 @@ pybind_extension(
     deps = [
         ":compression_clif_aux",
         "@abseil-cpp//absl/types:span",
-        "//:common",
         "//compression:sfp",
     ],
 )
@@ -22,7 +22,8 @@
 
 #include "compression/compress.h"
 #include "compression/shared.h"
-#include "hwy/aligned_allocator.h"
+#include "gemma/weights.h"
+#include "util/mat.h"
 
 #undef HWY_TARGET_INCLUDE
 #define HWY_TARGET_INCLUDE \
 
@@ -81,30 +82,23 @@ class SbsWriterImpl : public WriterInterface {
   template <typename Packed>
   void AllocateAndCompress(const std::string& name,
                            absl::Span<const float> weights) {
-    MatPtrT<Packed> storage(name, 1, weights.size());
-    model_memory_.push_back(storage);
-    model_memory_.back().Allocate();
-    storage.SetPtr(model_memory_.back());
-    std::string decorated_name = storage.CacheName();
+    MatPtrT<Packed> storage(name.c_str(), Extents2D(1, weights.size()));
+    model_memory_.push_back(MatOwner());
+    model_memory_.back().AllocateFor(storage, MatPadding::kPacked);
+    std::string decorated_name = CacheName(storage);
     compressor_(&storage, decorated_name.c_str(), weights.data());
   }
   template <typename Packed>
   void AllocateWithShape(const std::string& name,
                          absl::Span<const float> weights,
                          const TensorInfo& tensor_info, float scale) {
-    MatPtrT<Packed> storage(name, &tensor_info);
-    storage.set_scale(scale);
+    MatPtrT<Packed> storage(name.c_str(), &tensor_info);
+    storage.SetScale(scale);
 
-    // Don't reset num_elements for NUQ.
-    if (!hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
-      storage.SetNumElements(CompressedArrayElements<Packed>(weights.size()));
-    }
-
-    model_memory_.push_back(storage);
+    model_memory_.push_back(MatOwner());
     if (mode_ == CompressorMode::kTEST_ONLY) return;
-    model_memory_.back().Allocate();
-    storage.SetPtr(model_memory_.back());
-    std::string decorated_name = storage.CacheName();
+    model_memory_.back().AllocateFor(storage, MatPadding::kPacked);
+    std::string decorated_name = CacheName(storage);
     compressor_(&storage, decorated_name.c_str(), weights.data());
   }
 
@@ -176,7 +170,7 @@ class SbsWriterImpl : public WriterInterface {
   hwy::ThreadPool pool_;
   Compressor compressor_;
   CompressWorkingSet working_set_;
-  std::vector<MatStorage> model_memory_;
+  std::vector<MatOwner> model_memory_;
   std::vector<float> scales_;
   CompressorMode mode_;
 };
@ -201,15 +201,24 @@ inline bool EnumValid(PromptWrapping type) {
|
||||||
|
|
||||||
// Tensor types for loading weights. Note that not all types are supported as
|
// Tensor types for loading weights. Note that not all types are supported as
|
||||||
// weights for a model, but can be used for other purposes, such as types for
|
// weights for a model, but can be used for other purposes, such as types for
|
||||||
// ModelWeightsPtrs. When adding a new type that is supported, also
|
// `WeightsPtrs`. When adding a new type that is supported, also
|
||||||
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
|
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
|
||||||
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
|
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
|
||||||
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
|
static constexpr const char* kTypeStrings[] = {
|
||||||
"nuq", "f64", "c64", "u128"};
|
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"};
|
||||||
|
static constexpr size_t kNumTypes =
|
||||||
|
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
|
||||||
|
static constexpr size_t kTypeBits[] = {0,
|
||||||
|
8 * sizeof(float),
|
||||||
|
8 * sizeof(BF16),
|
||||||
|
8 * sizeof(SfpStream),
|
||||||
|
4 /* NuqStream, actually 4.5 */,
|
||||||
|
8 * sizeof(double),
|
||||||
|
8 * sizeof(std::complex<double>),
|
||||||
|
8 * sizeof(hwy::uint128_t)};
|
||||||
|
|
||||||
inline bool EnumValid(Type type) {
|
static inline bool EnumValid(Type type) {
|
||||||
return static_cast<int>(type) >= 0 &&
|
return static_cast<size_t>(type) < kNumTypes;
|
||||||
static_cast<int>(type) <= static_cast<int>(Type::kU128);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Returns a Type enum for the type of the template parameter.
|
// Returns a Type enum for the type of the template parameter.
|
||||||
|
|
@@ -236,10 +245,16 @@ Type TypeEnum() {
   }
 }
 
-// Returns a string name for the type of the template parameter.
+static inline size_t TypeBits(Type type) {
+  return kTypeBits[static_cast<int>(type)];
+}
+
+static inline const char* TypeName(Type type) {
+  return kTypeStrings[static_cast<int>(type)];
+}
+
 template <typename PackedT>
 const char* TypeName() {
-  return kTypeStrings[static_cast<int>(TypeEnum<PackedT>())];
+  return TypeName(TypeEnum<PackedT>());
 }
 
 template <typename Packed>
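With `TypeBits` and `TypeName` available as plain functions of `Type`, callers can reason about storage without a template parameter. A small sketch of a possible call site; the shape is arbitrary, and NUQ's 4 bits is the table's approximation of 4.5:

```cpp
#include <cstddef>
#include <cstdio>

#include "compression/shared.h"  // Type, TypeBits, TypeName

// Bytes needed for a packed tensor of rows * cols elements of `type`.
size_t PackedBytes(gcpp::Type type, size_t rows, size_t cols) {
  return (gcpp::TypeBits(type) * rows * cols + 7) / 8;
}

void PrintExample() {
  std::fprintf(stderr, "%s: %zu bytes\n", gcpp::TypeName(gcpp::Type::kSFP),
               PackedBytes(gcpp::Type::kSFP, 2048, 256));
}
```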
|
||||||
|
|
@@ -248,7 +263,9 @@ constexpr bool IsCompressed() {
 }
 
 // Returns the number of `MatT` elements required to store `capacity` values,
-// which must not be zero.
+// which must not be zero. This is only intended to support the extra tables
+// required for NUQ. `capacity` includes any padding and is `rows * stride`.
+// Deprecated, replaced by fixup within `MatPtr`. Only used by tests.
 template <typename Packed>
 constexpr size_t CompressedArrayElements(size_t capacity) {
   if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
|
||||||
|
|
|
||||||
|
|
@@ -18,10 +18,13 @@
 #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
 
 // IWYU pragma: begin_exports
-#include "compression/compress.h"
 #include "compression/distortion.h"
+#include "util/mat.h"
 // IWYU pragma: end_exports
 
+#include "compression/compress.h"
+#include "hwy/contrib/thread_pool/thread_pool.h"
+
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
 
 // Include guard for (potentially) SIMD code.
|
||||||
|
|
@@ -62,6 +65,52 @@ void ForeachPackedAndRawType() {
   ForeachRawType<NuqStream, TestT>();
 }
 
+// Generates inputs: deterministic, within max SfpStream range.
+template <typename MatT>
+MatStorageT<MatT> GenerateMat(const Extents2D& extents, hwy::ThreadPool& pool) {
+  gcpp::CompressWorkingSet ws;
+  MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
+  MatStorageT<MatT> compressed("mat", extents, MatPadding::kPacked);
+  const float scale = SfpStream::kMax / extents.Area();
+  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
+    float* HWY_RESTRICT row = raw.Row(r);
+    for (size_t c = 0; c < extents.cols; c++) {
+      float f = static_cast<float>(r * extents.cols + c) * scale;
+      if ((r + c) & 1) f = -f;  // Also generate some negative values.
+      row[c] = f;
+    }
+  });
+
+  Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
+           /*packed_ofs=*/0, pool);
+  compressed.SetScale(0.6f);  // Arbitrary value, different from 1.
+  return compressed;
+}
+
+// `extents` describes the transposed matrix.
+template <typename MatT>
+MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
+                                        hwy::ThreadPool& pool) {
+  gcpp::CompressWorkingSet ws;
+  MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
+  MatStorageT<MatT> compressed("trans", extents, MatPadding::kPacked);
+  const float scale = SfpStream::kMax / extents.Area();
+  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
+    float* HWY_RESTRICT row = raw.Row(r);
+    for (size_t c = 0; c < extents.cols; c++) {
+      float f = static_cast<float>(c * extents.rows + r) * scale;
+      if ((r + c) & 1) f = -f;  // Also generate some negative values.
+      row[c] = f;
+    }
+  });
+
+  Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
+           /*packed_ofs=*/0, pool);
+  // Arbitrary value, different from 1, must match `GenerateMat`.
+  compressed.SetScale(0.6f);
+  return compressed;
+}
+
 // NOLINTNEXTLINE(google-readability-namespace-comments)
 }  // namespace HWY_NAMESPACE
 }  // namespace gcpp
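These generators were moved here so that the MatMul test and benchmark can share them. A sketch of a call site, written as if inside the same `gcpp` / `HWY_NAMESPACE` namespaces; the extents and pool size are arbitrary:

```cpp
hwy::ThreadPool pool(4);
// A is rows x cols; B is given as its transpose, per the comment above.
const Extents2D A_extents(128, 64);
const Extents2D B_trans_extents(32, 64);
MatStorageT<SfpStream> a = GenerateMat<SfpStream>(A_extents, pool);
MatStorageT<SfpStream> bt =
    GenerateTransposedMat<SfpStream>(B_trans_extents, pool);
// Both carry SetScale(0.6f), so golden MatMul outputs must fold in 0.6f * 0.6f.
```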
|
||||||
|
|
|
||||||
|
|
@ -128,10 +128,10 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
||||||
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
||||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||||
prompt.begin() + pos + num_tokens);
|
prompt.begin() + pos + num_tokens);
|
||||||
KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(),
|
KVCache kv_cache = KVCache::Create(env.GetGemma()->GetModelConfig(),
|
||||||
env.MutableConfig().prefill_tbatch_size);
|
env.MutableConfig().prefill_tbatch_size);
|
||||||
float entropy = ComputeCrossEntropy(
|
float entropy = ComputeCrossEntropy(
|
||||||
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
|
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
|
||||||
total_entropy += entropy;
|
total_entropy += entropy;
|
||||||
LogSpeedStats(time_start, pos + num_tokens);
|
LogSpeedStats(time_start, pos + num_tokens);
|
||||||
std::string text_slice = env.StringFromTokens(prompt_slice);
|
std::string text_slice = env.StringFromTokens(prompt_slice);
|
||||||
|
|
@ -186,8 +186,8 @@ int main(int argc, char** argv) {
|
||||||
if (!benchmark_args.goldens.Empty()) {
|
if (!benchmark_args.goldens.Empty()) {
|
||||||
const std::string golden_path =
|
const std::string golden_path =
|
||||||
benchmark_args.goldens.path + "/" +
|
benchmark_args.goldens.path + "/" +
|
||||||
gcpp::ModelString(env.GetModel()->Info().model,
|
gcpp::ModelString(env.GetGemma()->Info().model,
|
||||||
env.GetModel()->Info().wrapping) +
|
env.GetGemma()->Info().wrapping) +
|
||||||
".txt";
|
".txt";
|
||||||
return BenchmarkGoldens(env, golden_path);
|
return BenchmarkGoldens(env, golden_path);
|
||||||
} else if (!benchmark_args.summarize_text.Empty()) {
|
} else if (!benchmark_args.summarize_text.Empty()) {
|
||||||
|
|
|
||||||
|
|
@ -18,27 +18,20 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
#include <time.h>
|
#include <time.h>
|
||||||
|
|
||||||
#include <cstdio>
|
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <memory>
|
|
||||||
#include <ostream>
|
#include <ostream>
|
||||||
#include <random>
|
#include <random>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
// Placeholder for internal header, do not modify.
|
#include "compression/shared.h" // TypeName
|
||||||
#include "compression/compress.h" // TypeName
|
|
||||||
#include "evals/cross_entropy.h"
|
#include "evals/cross_entropy.h"
|
||||||
#include "gemma/common.h" // StringFromType
|
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
#include "gemma/kv_cache.h"
|
#include "gemma/gemma_args.h"
|
||||||
#include "util/app.h"
|
#include "ops/matmul.h" // MatMulEnv
|
||||||
#include "util/args.h"
|
#include "util/threading_context.h"
|
||||||
#include "util/threading.h"
|
|
||||||
#include "hwy/base.h"
|
|
||||||
#include "hwy/contrib/thread_pool/topology.h"
|
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
#include "hwy/per_target.h" // VectorBytes
|
#include "hwy/per_target.h" // DispatchedTarget
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@@ -54,11 +47,9 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
   }
 }
 
-GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
-                   const AppArgs& app)
-    : topology_(CreateTopology(app)),
-      pools_(CreatePools(topology_, app)),
-      env_(topology_, pools_) {
+GemmaEnv::GemmaEnv(const ThreadingArgs& threading_args,
+                   const LoaderArgs& loader, const InferenceArgs& inference)
+    : env_(MakeMatMulEnv(threading_args)) {
   InferenceArgs mutable_inference = inference;
   AbortIfInvalidArgs(mutable_inference);
   LoaderArgs mutable_loader = loader;
@@ -67,10 +58,10 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
     fprintf(stderr, "Skipping model load because: %s\n", err);
   } else {
     fprintf(stderr, "Loading model...\n");
-    model_ = AllocateGemma(mutable_loader, env_);
+    gemma_ = AllocateGemma(mutable_loader, env_);
     // Only allocate one for starters because GenerateBatch might not be called.
     kv_caches_.resize(1);
-    kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
+    kv_caches_[0] = KVCache::Create(gemma_->GetModelConfig(),
                                     inference.prefill_tbatch_size);
   }
   InitGenerator(inference, gen_);
|
||||||
|
|
@@ -78,24 +69,13 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
       .max_generated_tokens = inference.max_generated_tokens,
       .temperature = inference.temperature,
       .gen = &gen_,
-      .verbosity = app.verbosity,
+      .verbosity = inference.verbosity,
   };
 }
 
-// Internal init must run before the GemmaEnv ctor above, hence it cannot occur
-// in the argv ctor below because its body runs *after* the delegating ctor.
-// This helper function takes care of the init, and could be applied to any of
-// the *Args classes, it does not matter which.
-static AppArgs MakeAppArgs(int argc, char** argv) {
-  {  // So that indentation matches expectations.
-    // Placeholder for internal init, do not modify.
-  }
-  return AppArgs(argc, argv);
-}
-
 GemmaEnv::GemmaEnv(int argc, char** argv)
-    : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
-               MakeAppArgs(argc, argv)) {}
+    : GemmaEnv(ThreadingArgs(argc, argv), LoaderArgs(argc, argv),
+               InferenceArgs(argc, argv)) {}
 
 QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
   QueryResult result;
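For callers that do not parse argv, the three-argument form can be used directly. A sketch with placeholder paths; the path-taking `LoaderArgs` constructor mirrors the one mentioned in the simplified_gemma comments later in this diff:

```cpp
gcpp::ThreadingArgs threading;  // replaces the former AppArgs
gcpp::LoaderArgs loader("/path/to/tokenizer.spm", "/path/to/weights.sbs",
                        "2b-it");
gcpp::InferenceArgs inference;
gcpp::GemmaEnv env(threading, loader, inference);
if (env.GetGemma() == nullptr) {
  fprintf(stderr, "Model not loaded\n");
}
```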
|
||||||
|
|
@ -117,7 +97,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
||||||
}
|
}
|
||||||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||||
runtime_config_.batch_stream_token = batch_stream_token;
|
runtime_config_.batch_stream_token = batch_stream_token;
|
||||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||||
timing_info);
|
timing_info);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
@ -127,7 +107,7 @@ void GemmaEnv::QueryModel(
|
||||||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||||
const StreamFunc previous_stream_token = runtime_config_.stream_token;
|
const StreamFunc previous_stream_token = runtime_config_.stream_token;
|
||||||
runtime_config_.stream_token = stream_token;
|
runtime_config_.stream_token = stream_token;
|
||||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||||
timing_info);
|
timing_info);
|
||||||
runtime_config_.stream_token = previous_stream_token;
|
runtime_config_.stream_token = previous_stream_token;
|
||||||
}
|
}
|
||||||
|
|
@ -142,7 +122,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
int token, float) {
|
int token, float) {
|
||||||
std::string token_text;
|
std::string token_text;
|
||||||
HWY_ASSERT(
|
HWY_ASSERT(
|
||||||
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
gemma_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||||
res[query_index].response.append(token_text);
|
res[query_index].response.append(token_text);
|
||||||
res[query_index].tokens_generated += 1;
|
res[query_index].tokens_generated += 1;
|
||||||
if (res[query_index].tokens_generated ==
|
if (res[query_index].tokens_generated ==
|
||||||
|
|
@ -164,7 +144,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
}
|
}
|
||||||
for (size_t i = 1; i < num_queries; ++i) {
|
for (size_t i = 1; i < num_queries; ++i) {
|
||||||
if (kv_caches_[i].seq_len == 0) {
|
if (kv_caches_[i].seq_len == 0) {
|
||||||
kv_caches_[i] = KVCache::Create(model_->GetModelConfig(),
|
kv_caches_[i] = KVCache::Create(gemma_->GetModelConfig(),
|
||||||
runtime_config_.prefill_tbatch_size);
|
runtime_config_.prefill_tbatch_size);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -172,7 +152,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||||
runtime_config_.batch_stream_token = batch_stream_token;
|
runtime_config_.batch_stream_token = batch_stream_token;
|
||||||
std::vector<size_t> queries_pos(num_queries, 0);
|
std::vector<size_t> queries_pos(num_queries, 0);
|
||||||
model_->GenerateBatch(runtime_config_, queries_prompt,
|
gemma_->GenerateBatch(runtime_config_, queries_prompt,
|
||||||
QueriesPos(queries_pos.data(), num_queries),
|
QueriesPos(queries_pos.data(), num_queries),
|
||||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
||||||
return res;
|
return res;
|
||||||
|
|
@ -203,7 +183,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
||||||
float GemmaEnv::CrossEntropy(const std::string& input) {
|
float GemmaEnv::CrossEntropy(const std::string& input) {
|
||||||
std::vector<int> prompt = Tokenize(input);
|
std::vector<int> prompt = Tokenize(input);
|
||||||
prompt.insert(prompt.begin(), BOS_ID);
|
prompt.insert(prompt.begin(), BOS_ID);
|
||||||
return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
|
return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
|
||||||
MutableKVCache(),
|
MutableKVCache(),
|
||||||
/*verbosity=*/0) /
|
/*verbosity=*/0) /
|
||||||
static_cast<int>(input.size());
|
static_cast<int>(input.size());
|
||||||
|
|
@@ -236,17 +216,36 @@ std::string CacheString() {
   return buf;
 }
 
-void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
-                const BoundedTopology& topology, NestedPools& pools) {
-  loader.Print(app.verbosity);
-  inference.Print(app.verbosity);
-  app.Print(app.verbosity);
+static constexpr const char* CompiledConfig() {
+  if constexpr (HWY_IS_ASAN) {
+    return "asan";
+  } else if constexpr (HWY_IS_MSAN) {
+    return "msan";
+  } else if constexpr (HWY_IS_TSAN) {
+    return "tsan";
+  } else if constexpr (HWY_IS_HWASAN) {
+    return "hwasan";
+  } else if constexpr (HWY_IS_UBSAN) {
+    return "ubsan";
+  } else if constexpr (HWY_IS_DEBUG_BUILD) {
+    return "dbg";
+  } else {
+    return "opt";
+  }
+}
 
-  if (app.verbosity >= 2) {
+void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
+                InferenceArgs& inference) {
+  threading.Print(inference.verbosity);
+  loader.Print(inference.verbosity);
+  inference.Print(inference.verbosity);
+
+  if (inference.verbosity >= 2) {
     time_t now = time(nullptr);
     char* dt = ctime(&now);  // NOLINT
     char cpu100[100] = "unknown";
     (void)hwy::platform::GetCpuString(cpu100);
+    const ThreadingContext2& ctx = ThreadingContext2::Get();
+
     fprintf(stderr,
             "Date & Time : %s"  // dt includes \n
||||||
|
|
@@ -254,16 +253,18 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
             "CPU topology : %s, %s, %s\n"
             "Instruction set : %s (%zu bits)\n"
             "Compiled config : %s\n"
-            "Weight Type : %s\n"
-            "EmbedderInput Type : %s\n",
-            dt, cpu100, topology.TopologyString(), pools.PinString(),
+            "Memory MiB : %4zu, %4zu free\n"
+            "Weight Type : %s\n",
+            dt, cpu100, ctx.topology.TopologyString(), ctx.pools.PinString(),
             CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
-            hwy::VectorBytes() * 8, CompiledConfig(),
-            StringFromType(loader.Info().weight), TypeName<EmbedderInputT>());
+            ctx.allocator.VectorBytes() * 8, CompiledConfig(),
+            ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB(),
+            StringFromType(loader.Info().weight));
   }
 }
 
-void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader,
+              InferenceArgs& inference) {
   std::cerr
       << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
         "==========================================================\n\n"
|
||||||
|
|
@@ -272,16 +273,16 @@ void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
         " --tokenizer\n"
         " --weights\n"
         " --model,\n"
-        " or with the newer weights format, specify just:\n"
+        " or with the single-file weights format, specify just:\n"
         " --weights\n";
   std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
                "--weights 2b-it-sfp.sbs --model 2b-it\n";
+  std::cerr << "\n*Threading Arguments*\n\n";
+  threading.Help();
   std::cerr << "\n*Model Loading Arguments*\n\n";
   loader.Help();
   std::cerr << "\n*Inference Arguments*\n\n";
   inference.Help();
-  std::cerr << "\n*Application Arguments*\n\n";
-  app.Help();
   std::cerr << "\n";
 }
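A frontend's help/config path after this change looks roughly like the sketch below; `argc`/`argv` come from `main`, and `HasHelp` is the existing helper already used by the examples in this repository:

```cpp
gcpp::ThreadingArgs threading(argc, argv);
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
  gcpp::ShowHelp(threading, loader, inference);
  return 0;
}
// Verbosity now lives on InferenceArgs rather than the removed AppArgs.
if (inference.verbosity >= 1) gcpp::ShowConfig(threading, loader, inference);
```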
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -24,9 +24,9 @@
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "gemma/gemma.h"
|
#include "gemma/gemma.h"
|
||||||
|
#include "gemma/gemma_args.h"
|
||||||
#include "ops/matmul.h"
|
#include "ops/matmul.h"
|
||||||
#include "util/app.h"
|
#include "util/threading_context.h"
|
||||||
#include "util/threading.h"
|
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -46,8 +46,10 @@ class GemmaEnv {
|
||||||
public:
|
public:
|
||||||
// Calls the other constructor with *Args arguments initialized from argv.
|
// Calls the other constructor with *Args arguments initialized from argv.
|
||||||
GemmaEnv(int argc, char** argv);
|
GemmaEnv(int argc, char** argv);
|
||||||
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
|
||||||
const AppArgs& app);
|
const InferenceArgs& inference);
|
||||||
|
|
||||||
|
MatMulEnv& Env() { return env_; }
|
||||||
|
|
||||||
size_t MaxGeneratedTokens() const {
|
size_t MaxGeneratedTokens() const {
|
||||||
return runtime_config_.max_generated_tokens;
|
return runtime_config_.max_generated_tokens;
|
||||||
|
|
@ -58,7 +60,7 @@ class GemmaEnv {
|
||||||
|
|
||||||
std::vector<int> Tokenize(const std::string& input) const {
|
std::vector<int> Tokenize(const std::string& input) const {
|
||||||
std::vector<int> tokens;
|
std::vector<int> tokens;
|
||||||
HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
|
HWY_ASSERT(gemma_->Tokenizer().Encode(input, &tokens));
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -69,13 +71,13 @@ class GemmaEnv {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<int> WrapAndTokenize(std::string& input) const {
|
std::vector<int> WrapAndTokenize(std::string& input) const {
|
||||||
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(),
|
return gcpp::WrapAndTokenize(gemma_->Tokenizer(), gemma_->ChatTemplate(),
|
||||||
model_->Info(), 0, input);
|
gemma_->Info(), 0, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string StringFromTokens(const std::vector<int>& tokens) const {
|
std::string StringFromTokens(const std::vector<int>& tokens) const {
|
||||||
std::string string;
|
std::string string;
|
||||||
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
|
HWY_ASSERT(gemma_->Tokenizer().Decode(tokens, &string));
|
||||||
return string;
|
return string;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -99,7 +101,7 @@ class GemmaEnv {
|
||||||
float CrossEntropy(const std::string& input);
|
float CrossEntropy(const std::string& input);
|
||||||
|
|
||||||
// Returns nullptr if the model failed to load.
|
// Returns nullptr if the model failed to load.
|
||||||
Gemma* GetModel() const { return model_.get(); }
|
Gemma* GetGemma() const { return gemma_.get(); }
|
||||||
|
|
||||||
int Verbosity() const { return runtime_config_.verbosity; }
|
int Verbosity() const { return runtime_config_.verbosity; }
|
||||||
RuntimeConfig& MutableConfig() { return runtime_config_; }
|
RuntimeConfig& MutableConfig() { return runtime_config_; }
|
||||||
|
|
@ -107,11 +109,9 @@ class GemmaEnv {
|
||||||
KVCache& MutableKVCache() { return kv_caches_[0]; }
|
KVCache& MutableKVCache() { return kv_caches_[0]; }
|
||||||
|
|
||||||
private:
|
private:
|
||||||
BoundedTopology topology_;
|
|
||||||
NestedPools pools_; // Thread pool.
|
|
||||||
MatMulEnv env_;
|
MatMulEnv env_;
|
||||||
std::mt19937 gen_; // Random number generator.
|
std::mt19937 gen_; // Random number generator.
|
||||||
std::unique_ptr<Gemma> model_;
|
std::unique_ptr<Gemma> gemma_;
|
||||||
std::vector<KVCache> kv_caches_; // Same number as query batch.
|
std::vector<KVCache> kv_caches_; // Same number as query batch.
|
||||||
RuntimeConfig runtime_config_;
|
RuntimeConfig runtime_config_;
|
||||||
};
|
};
|
||||||
|
|
@ -119,9 +119,10 @@ class GemmaEnv {
|
||||||
// Logs the inference speed in tokens/sec.
|
// Logs the inference speed in tokens/sec.
|
||||||
void LogSpeedStats(double time_start, size_t total_tokens);
|
void LogSpeedStats(double time_start, size_t total_tokens);
|
||||||
|
|
||||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
|
||||||
const BoundedTopology& topology, NestedPools& pools);
|
InferenceArgs& inference);
|
||||||
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
|
void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader,
|
||||||
|
InferenceArgs& inference);
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,8 +51,8 @@ class GemmaTest : public ::testing::Test {
|
||||||
// Using the turn structure worsens results sometimes.
|
// Using the turn structure worsens results sometimes.
|
||||||
// However, some models need the turn structure to work.
|
// However, some models need the turn structure to work.
|
||||||
// It would be good to make these tests more consistent.
|
// It would be good to make these tests more consistent.
|
||||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
|
||||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
|
||||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||||
replies.push_back(result.response);
|
replies.push_back(result.response);
|
||||||
}
|
}
|
||||||
|
|
@ -76,7 +76,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
|
|
||||||
void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
|
void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
|
||||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||||
|
|
||||||
std::vector<std::string> inputs;
|
std::vector<std::string> inputs;
|
||||||
for (size_t i = 0; i < num_questions; ++i) {
|
for (size_t i = 0; i < num_questions; ++i) {
|
||||||
|
|
|
||||||
|
|
@ -50,8 +50,8 @@ class GemmaTest : public ::testing::Test {
|
||||||
// Using the turn structure worsens results sometimes.
|
// Using the turn structure worsens results sometimes.
|
||||||
// However, some models need the turn structure to work.
|
// However, some models need the turn structure to work.
|
||||||
// It would be good to make these tests more consistent.
|
// It would be good to make these tests more consistent.
|
||||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
|
||||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
|
||||||
std::string mutable_prompt = prompt;
|
std::string mutable_prompt = prompt;
|
||||||
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
|
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
|
||||||
return result.response;
|
return result.response;
|
||||||
|
|
@ -71,8 +71,8 @@ class GemmaTest : public ::testing::Test {
|
||||||
// Using the turn structure worsens results sometimes.
|
// Using the turn structure worsens results sometimes.
|
||||||
// However, some models need the turn structure to work.
|
// However, some models need the turn structure to work.
|
||||||
// It would be good to make these tests more consistent.
|
// It would be good to make these tests more consistent.
|
||||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
|
||||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
|
||||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||||
replies.push_back(result.response);
|
replies.push_back(result.response);
|
||||||
}
|
}
|
||||||
|
|
@ -96,7 +96,7 @@ class GemmaTest : public ::testing::Test {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
|
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
|
||||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||||
if (batch) {
|
if (batch) {
|
||||||
std::vector<std::string> inputs;
|
std::vector<std::string> inputs;
|
||||||
for (size_t i = 0; i < num_questions; ++i) {
|
for (size_t i = 0; i < num_questions; ++i) {
|
||||||
|
|
@ -155,8 +155,8 @@ TEST_F(GemmaTest, Arithmetic) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GemmaTest, Multiturn) {
|
TEST_F(GemmaTest, Multiturn) {
|
||||||
Gemma* model = s_env->GetModel();
|
Gemma* model = s_env->GetGemma();
|
||||||
ASSERT_NE(model, nullptr);
|
HWY_ASSERT(model != nullptr);
|
||||||
size_t abs_pos = 0;
|
size_t abs_pos = 0;
|
||||||
std::string response;
|
std::string response;
|
||||||
auto stream_token = [&](int token, float) {
|
auto stream_token = [&](int token, float) {
|
||||||
|
|
@ -239,12 +239,12 @@ static const char kGettysburg[] = {
|
||||||
"people, for the people, shall not perish from the earth.\n"};
|
"people, for the people, shall not perish from the earth.\n"};
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||||
static const char kSmall[] =
|
static const char kSmall[] =
|
||||||
"The capital of Hungary is Budapest which is located in Europe.";
|
"The capital of Hungary is Budapest which is located in Europe.";
|
||||||
float entropy = s_env->CrossEntropy(kSmall);
|
float entropy = s_env->CrossEntropy(kSmall);
|
||||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||||
switch (s_env->GetModel()->Info().model) {
|
switch (s_env->GetGemma()->Info().model) {
|
||||||
case gcpp::Model::GEMMA_2B:
|
case gcpp::Model::GEMMA_2B:
|
||||||
// 2B v.1 and v.1.1 produce slightly different results.
|
// 2B v.1 and v.1.1 produce slightly different results.
|
||||||
EXPECT_NEAR(entropy, 2.6f, 0.2f);
|
EXPECT_NEAR(entropy, 2.6f, 0.2f);
|
||||||
|
|
@ -272,10 +272,10 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||||
float entropy = s_env->CrossEntropy(kJingleBells);
|
float entropy = s_env->CrossEntropy(kJingleBells);
|
||||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||||
switch (s_env->GetModel()->Info().model) {
|
switch (s_env->GetGemma()->Info().model) {
|
||||||
case gcpp::Model::GEMMA_2B:
|
case gcpp::Model::GEMMA_2B:
|
||||||
// 2B v.1 and v.1.1 produce slightly different results.
|
// 2B v.1 and v.1.1 produce slightly different results.
|
||||||
EXPECT_NEAR(entropy, 1.9f, 0.2f);
|
EXPECT_NEAR(entropy, 1.9f, 0.2f);
|
||||||
|
|
@ -303,10 +303,10 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||||
}
|
}
|
||||||
|
|
||||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||||
float entropy = s_env->CrossEntropy(kGettysburg);
|
float entropy = s_env->CrossEntropy(kGettysburg);
|
||||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||||
switch (s_env->GetModel()->Info().model) {
|
switch (s_env->GetGemma()->Info().model) {
|
||||||
case gcpp::Model::GEMMA_2B:
|
case gcpp::Model::GEMMA_2B:
|
||||||
// 2B v.1 and v.1.1 produce slightly different results.
|
// 2B v.1 and v.1.1 produce slightly different results.
|
||||||
EXPECT_NEAR(entropy, 1.1f, 0.1f);
|
EXPECT_NEAR(entropy, 1.1f, 0.1f);
|
||||||
|
|
|
||||||
|
|
@ -89,7 +89,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
|
||||||
"A", "B", "C", "D", //
|
"A", "B", "C", "D", //
|
||||||
" A", " B", " C", " D", //
|
" A", " B", " C", " D", //
|
||||||
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
|
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
|
||||||
const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
|
const TokenSet accept_set(env.GetGemma()->Tokenizer(), accept_strings);
|
||||||
|
|
||||||
for (auto sample : json_data["samples"]) {
|
for (auto sample : json_data["samples"]) {
|
||||||
const int id = sample["i"];
|
const int id = sample["i"];
|
||||||
|
|
@ -131,7 +131,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
|
||||||
.verbosity = env.Verbosity(),
|
.verbosity = env.Verbosity(),
|
||||||
.stream_token = stream_token,
|
.stream_token = stream_token,
|
||||||
};
|
};
|
||||||
env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
|
env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
|
||||||
env.MutableKVCache(), timing_info);
|
env.MutableKVCache(), timing_info);
|
||||||
|
|
||||||
std::string output_string = env.StringFromTokens(predicted_token_ids);
|
std::string output_string = env.StringFromTokens(predicted_token_ids);
|
||||||
|
|
|
||||||
|
|
@ -10,13 +10,11 @@ cc_binary(
|
||||||
name = "hello_world",
|
name = "hello_world",
|
||||||
srcs = ["run.cc"],
|
srcs = ["run.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
# Placeholder for internal dep, do not remove.,
|
|
||||||
"//:app",
|
|
||||||
"//:args",
|
"//:args",
|
||||||
|
"//:gemma_args",
|
||||||
"//:gemma_lib",
|
"//:gemma_lib",
|
||||||
"//:threading",
|
"//:threading_context",
|
||||||
"//:tokenizer",
|
"//:tokenizer",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
"@highway//:thread_pool",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
 example:
 
 ```sh
-./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
+./hello_world --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
 ```
 
 Should print a greeting to the terminal:
|
||||||
|
|
|
||||||
|
|
@@ -23,23 +23,17 @@
 #include <string>
 #include <vector>
 
-// Placeholder for internal header, do not modify.
 #include "gemma/gemma.h"
+#include "gemma/gemma_args.h"  // LoaderArgs
 #include "gemma/tokenizer.h"
-#include "util/app.h"  // LoaderArgs
 #include "util/args.h"
-#include "util/threading.h"
+#include "util/threading_context.h"
 #include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 
 int main(int argc, char** argv) {
-  {
-    // Placeholder for internal init, do not modify.
-  }
-
+  gcpp::ThreadingArgs threading(argc, argv);
   gcpp::LoaderArgs loader(argc, argv);
   gcpp::InferenceArgs inference(argc, argv);
-  gcpp::AppArgs app(argc, argv);
   if (gcpp::HasHelp(argc, argv)) {
     loader.Help();
     return 0;
@@ -53,14 +47,14 @@ int main(int argc, char** argv) {
   for (int arg = 0; arg < argc; ++arg) {
     // Find a --reject flag and consume everything after it.
     if (strcmp(argv[arg], "--reject") == 0) {
-      while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
+      while (++arg < argc) {
+        reject_tokens.insert(atoi(argv[arg]));  // NOLINT
+      }
     }
   }
 
   // Instantiate model and KV Cache
-  gcpp::BoundedTopology topology(gcpp::CreateTopology(app));
-  gcpp::NestedPools pools = gcpp::CreatePools(topology, app);
-  gcpp::MatMulEnv env(topology, pools);
+  gcpp::MatMulEnv env(MakeMatMulEnv(threading));
   gcpp::Gemma model = gcpp::CreateGemma(loader, env);
   gcpp::KVCache kv_cache =
       gcpp::KVCache::Create(model.GetModelConfig(),
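The net effect on frontend boilerplate: one `ThreadingArgs` plus `MakeMatMulEnv` replaces the old topology/pools/env triple. A condensed sketch of the new setup, with placeholder file names and generation omitted:

```cpp
gcpp::ThreadingArgs threading;  // defaults; or ThreadingArgs(argc, argv)
gcpp::LoaderArgs loader("tokenizer.spm", "2b-it-sfp.sbs", "2b-it");
gcpp::InferenceArgs inference;
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::Gemma model = gcpp::CreateGemma(loader, env);
gcpp::KVCache kv_cache = gcpp::KVCache::Create(model.GetModelConfig(),
                                               inference.prefill_tbatch_size);
```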
|
||||||
|
|
|
||||||
|
|
@ -10,10 +10,10 @@ cc_library(
|
||||||
name = "gemma",
|
name = "gemma",
|
||||||
hdrs = ["gemma.hpp"],
|
hdrs = ["gemma.hpp"],
|
||||||
deps = [
|
deps = [
|
||||||
"//:app",
|
"//:gemma_args",
|
||||||
"//:gemma_lib",
|
"//:gemma_lib",
|
||||||
"//:ops",
|
"//:ops",
|
||||||
"//:threading",
|
"//:threading_context",
|
||||||
"//:tokenizer",
|
"//:tokenizer",
|
||||||
"@highway//:hwy",
|
"@highway//:hwy",
|
||||||
],
|
],
|
||||||
|
|
@ -24,15 +24,6 @@ cc_binary(
|
||||||
srcs = ["run.cc"],
|
srcs = ["run.cc"],
|
||||||
deps = [
|
deps = [
|
||||||
":gemma",
|
":gemma",
|
||||||
# Placeholder for internal dep, do not remove.,
|
"//:gemma_args",
|
||||||
"//:app",
|
|
||||||
"//:args",
|
|
||||||
"//:common",
|
|
||||||
"//:gemma_lib",
|
|
||||||
"//:ops",
|
|
||||||
"//:threading",
|
|
||||||
"//:tokenizer",
|
|
||||||
"@highway//:hwy",
|
|
||||||
"@highway//:thread_pool",
|
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
 example:
 
 ```sh
-./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
+./simplified_gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
 ```
 
 Should print a greeting to the terminal:
|
||||||
|
|
|
||||||
|
|
@@ -24,39 +24,22 @@
 #include <vector>
 
 #include "third_party/gemma_cpp/gemma/gemma.h"
+#include "third_party/gemma_cpp/gemma/gemma_args.h"  // LoaderArgs
 #include "third_party/gemma_cpp/gemma/tokenizer.h"
 #include "third_party/gemma_cpp/ops/matmul.h"
-#include "third_party/gemma_cpp/util/app.h"  // LoaderArgs
-#include "third_party/gemma_cpp/util/threading.h"
+#include "third_party/gemma_cpp/util/threading_context.h"
 #include "third_party/highway/hwy/base.h"
 
 class SimplifiedGemma {
  public:
   SimplifiedGemma(const gcpp::LoaderArgs& loader,
-                  const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(),
-                  const gcpp::AppArgs& app = gcpp::AppArgs())
+                  const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
+                  const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
       : loader_(loader),
+        threading_(threading),
         inference_(inference),
-        app_(app),
-        topology_(gcpp::CreateTopology(app_)),
-        pools_(gcpp::CreatePools(topology_, app_)),
-        env_(topology_, pools_),
+        env_(MakeMatMulEnv(threading_)),
         model_(gcpp::CreateGemma(loader_, env_)) {
-    Init();
-  }
-
-  SimplifiedGemma(int argc, char** argv)
-      : loader_(argc, argv, /*validate=*/true),
-        inference_(argc, argv),
-        app_(argc, argv),
-        topology_(gcpp::CreateTopology(app_)),
-        pools_(gcpp::CreatePools(topology_, app_)),
-        env_(topology_, pools_),
-        model_(gcpp::CreateGemma(loader_, env_)) {
-    Init();
-  }
-
-  void Init() {
     // Instantiate model and KV Cache
     kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(),
                                       inference_.prefill_tbatch_size);
@@ -66,6 +49,11 @@ class SimplifiedGemma {
     gen_.seed(rd());
   }
 
+  SimplifiedGemma(int argc, char** argv)
+      : SimplifiedGemma(gcpp::LoaderArgs(argc, argv, /*validate=*/true),
+                        gcpp::ThreadingArgs(argc, argv),
+                        gcpp::InferenceArgs(argc, argv)) {}
+
   void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
                 float temperature = 0.7,
                 const std::set<int>& reject_tokens = {}) {
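Call sites keep working with either constructor; only the argument order changed. A sketch matching the comments in run.cc further down, with placeholder paths:

```cpp
gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights", "2b-it");
gcpp::ThreadingArgs threading;
gcpp::InferenceArgs inference;
SimplifiedGemma gemma(loader, threading, inference);  // or SimplifiedGemma(argc, argv)
std::string prompt = "Write a greeting to the world.";
gemma.Generate(prompt, /*max_generated_tokens=*/128, /*temperature=*/0.7f);
```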
|
||||||
|
|
@ -107,10 +95,8 @@ class SimplifiedGemma {
|
||||||
|
|
||||||
private:
|
private:
|
||||||
gcpp::LoaderArgs loader_;
|
gcpp::LoaderArgs loader_;
|
||||||
|
gcpp::ThreadingArgs threading_;
|
||||||
gcpp::InferenceArgs inference_;
|
gcpp::InferenceArgs inference_;
|
||||||
gcpp::AppArgs app_;
|
|
||||||
gcpp::BoundedTopology topology_;
|
|
||||||
gcpp::NestedPools pools_;
|
|
||||||
gcpp::MatMulEnv env_;
|
gcpp::MatMulEnv env_;
|
||||||
gcpp::Gemma model_;
|
gcpp::Gemma model_;
|
||||||
gcpp::KVCache kv_cache_;
|
gcpp::KVCache kv_cache_;
|
||||||
|
|
|
||||||
|
|
@@ -17,15 +17,10 @@
 
 #include <string>
 
-// Placeholder for internal header, do not modify.
 #include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
-#include "util/app.h"  // LoaderArgs
+#include "gemma/gemma_args.h"  // LoaderArgs
 
 int main(int argc, char** argv) {
-  {
-    // Placeholder for internal init, do not modify.
-  }
-
   // Standard usage: LoaderArgs takes argc and argv as input, then parses
   // necessary flags.
   gcpp::LoaderArgs loader(argc, argv, /*validate=*/true);
@@ -35,12 +30,12 @@ int main(int argc, char** argv) {
   // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
   //                         "model_identifier");
 
-  // Optional: InferenceArgs and AppArgs can be passed in as well. If not
+  // Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
   // specified, default values will be used.
   //
   // gcpp::InferenceArgs inference(argc, argv);
-  // gcpp::AppArgs app(argc, argv);
-  // SimplifiedGemma gemma(loader, inference, app);
+  // gcpp::ThreadingArgs threading(argc, argv);
+  // SimplifiedGemma gemma(loader, threading, inference);
 
   SimplifiedGemma gemma(loader);
   std::string prompt = "Write a greeting to the world.";
|
||||||
|
|
|
||||||
|
|
@ -18,14 +18,12 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "compression/shared.h" // BF16
|
#include "gemma/configs.h" // ModelConfig
|
||||||
#include "gemma/configs.h"
|
|
||||||
#include "ops/matmul.h" // MatMulEnv
|
#include "ops/matmul.h" // MatMulEnv
|
||||||
#include "ops/ops.h" // CreateInvTimescale
|
#include "ops/ops.h" // CreateInvTimescale
|
||||||
#include "util/allocator.h" // RowVectorBatch
|
#include "util/allocator.h" // Allocator
|
||||||
#include "util/threading.h"
|
#include "util/basics.h" // BF16
|
||||||
#include "hwy/base.h" // HWY_DASSERT
|
#include "util/mat.h" // RowVectorBatch
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@@ -74,6 +72,8 @@ struct Activations {
   size_t cache_pos_size = 0;
 
   void Allocate(size_t batch_size, MatMulEnv* env) {
+    const Allocator2& allocator = env->ctx.allocator;
+
     post_qk = layer_config.post_qk;
     const size_t model_dim = weights_config.model_dim;
     const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
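Every `RowVectorBatch` in the body below now takes that `allocator` explicitly instead of consulting global state. A minimal sketch of the new pattern; the dimensions are made up, and this is not part of the diff itself:

```cpp
// Sketch: any code with a MatMulEnv can allocate batch scratch like this.
void AllocScratch(gcpp::MatMulEnv& env, size_t batch_size, size_t dim) {
  const gcpp::Allocator2& allocator = env.ctx.allocator;  // per-context allocator
  gcpp::RowVectorBatch<float> scratch(allocator,
                                      gcpp::Extents2D(batch_size, dim));
  float* row0 = scratch.Batch(0);  // same Batch() accessor as before
  row0[0] = 0.0f;
}
```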
|
||||||
|
|
@ -81,36 +81,45 @@ struct Activations {
|
||||||
const size_t qkv_dim = layer_config.qkv_dim;
|
const size_t qkv_dim = layer_config.qkv_dim;
|
||||||
const size_t heads = layer_config.heads;
|
const size_t heads = layer_config.heads;
|
||||||
|
|
||||||
x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
x = RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||||
q = RowVectorBatch<float>(
|
q = RowVectorBatch<float>(
|
||||||
Extents2D(batch_size, heads * layer_config.QStride()));
|
allocator, Extents2D(batch_size, heads * layer_config.QStride()));
|
||||||
if (vocab_size > 0) {
|
if (vocab_size > 0) {
|
||||||
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
|
logits =
|
||||||
|
RowVectorBatch<float>(allocator, Extents2D(batch_size, vocab_size));
|
||||||
}
|
}
|
||||||
|
|
||||||
pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
pre_att_rms_out =
|
||||||
|
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||||
att = RowVectorBatch<float>(
|
att = RowVectorBatch<float>(
|
||||||
Extents2D(batch_size, heads * weights_config.seq_len));
|
allocator, Extents2D(batch_size, heads * weights_config.seq_len));
|
||||||
att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
|
att_out = RowVectorBatch<float>(allocator,
|
||||||
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
Extents2D(batch_size, heads * qkv_dim));
|
||||||
|
att_sums =
|
||||||
|
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||||
|
|
||||||
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
|
bf_pre_ffw_rms_out =
|
||||||
C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
RowVectorBatch<BF16>(allocator, Extents2D(batch_size, model_dim));
|
||||||
C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
C1 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
|
||||||
ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
C2 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
|
||||||
|
ffw_out =
|
||||||
|
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||||
|
|
||||||
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
|
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
|
||||||
griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
griffin_x =
|
||||||
griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||||
griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
griffin_y =
|
||||||
|
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||||
|
griffin_gate_x =
|
||||||
|
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||||
griffin_multiplier =
|
griffin_multiplier =
|
||||||
RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||||
}
|
}
|
||||||
|
|
||||||
inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
|
inv_timescale = CreateInvTimescale(allocator, layer_config.qkv_dim,
|
||||||
post_qk == PostQKType::HalfRope);
|
post_qk == PostQKType::HalfRope);
|
||||||
inv_timescale_global =
|
inv_timescale_global = CreateInvTimescale(
|
||||||
CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
|
allocator, qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
|
||||||
|
|
||||||
this->env = env;
|
this->env = env;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -17,13 +17,13 @@
|
||||||
|
|
||||||
#include <math.h> // sqrtf
|
#include <math.h> // sqrtf
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <algorithm> // std::min
|
#include <algorithm> // std::min
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
|
||||||
#include "gemma/activations.h"
|
#include "gemma/activations.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
|
|
@ -32,11 +32,9 @@
|
||||||
#include "gemma/weights.h"
|
#include "gemma/weights.h"
|
||||||
#include "paligemma/image.h"
|
#include "paligemma/image.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "util/basics.h"
|
#include "util/mat.h"
|
||||||
#include "util/threading.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h" // Span
|
||||||
#include "hwy/base.h"
|
|
||||||
#include "hwy/bit_set.h"
|
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
@ -82,7 +80,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
const KVCaches& kv_caches) {
|
const KVCaches& kv_caches) {
|
||||||
PROFILER_ZONE("Gen.Griffin");
|
PROFILER_ZONE("Gen.Griffin");
|
||||||
KVCache& kv_cache = kv_caches[0];
|
KVCache& kv_cache = kv_caches[0];
|
||||||
hwy::ThreadPool& pool = activations.env->parallel.Pools().Pool(0);
|
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
using D = hn::ScalableTag<float>;
|
using D = hn::ScalableTag<float>;
|
||||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||||
|
|
@ -96,8 +94,8 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
|
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
|
||||||
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
|
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
|
||||||
activations.pre_att_rms_out.Batch(batch_idx),
|
activations.pre_att_rms_out.Batch(batch_idx),
|
||||||
/*add0=*/layer_weights->griffin.linear_x_biases.data_scale1(),
|
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
|
||||||
/*add1=*/layer_weights->griffin.linear_y_biases.data_scale1(),
|
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
|
||||||
/*out0=*/x, /*out1=*/y, pool);
|
/*out0=*/x, /*out1=*/y, pool);
|
||||||
Gelu(y, model_dim);
|
Gelu(y, model_dim);
|
||||||
}
|
}
|
||||||
|
|
@ -121,15 +119,15 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
||||||
auto xv = hn::Load(df, x + i);
|
auto xv = hn::Load(df, x + i);
|
||||||
auto accum0 =
|
auto accum0 =
|
||||||
hn::Load(df, layer_weights->griffin.conv_biases.data_scale1() + i);
|
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
|
||||||
auto accum1 = hn::Zero(df);
|
auto accum1 = hn::Zero(df);
|
||||||
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
|
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
|
||||||
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
|
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
|
||||||
auto wv0 =
|
auto wv0 =
|
||||||
hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
|
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
|
||||||
(conv_1d_width - 1 - 2 * l) * model_dim + i);
|
(conv_1d_width - 1 - 2 * l) * model_dim + i);
|
||||||
auto wv1 =
|
auto wv1 =
|
||||||
hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
|
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
|
||||||
(conv_1d_width - 2 - 2 * l) * model_dim + i);
|
(conv_1d_width - 2 - 2 * l) * model_dim + i);
|
||||||
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
|
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
|
||||||
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
|
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
|
||||||
|
|
@ -156,9 +154,9 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
TwoOfsMatVecAddLoop(
|
TwoOfsMatVecAddLoop(
|
||||||
layer_weights->griffin.gate_w, kMatrixSize * head,
|
layer_weights->griffin.gate_w, kMatrixSize * head,
|
||||||
kMatrixSize * (heads + head), kHeadDim, kHeadDim, x + head_offset,
|
kMatrixSize * (heads + head), kHeadDim, kHeadDim, x + head_offset,
|
||||||
/*add0=*/layer_weights->griffin.gate_biases.data_scale1() +
|
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
|
||||||
head_offset,
|
head_offset,
|
||||||
/*add1=*/layer_weights->griffin.gate_biases.data_scale1() +
|
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
|
||||||
model_dim + head_offset,
|
model_dim + head_offset,
|
||||||
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
|
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
|
||||||
Sigmoid(gate_x + head_offset, kHeadDim);
|
Sigmoid(gate_x + head_offset, kHeadDim);
|
||||||
|
|
@ -166,7 +164,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
|
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
|
||||||
HWY_ATTR { return hn::Mul(x, gate_x); };
|
HWY_ATTR { return hn::Mul(x, gate_x); };
|
||||||
hn::Transform1(D(), a + head_offset, kHeadDim,
|
hn::Transform1(D(), a + head_offset, kHeadDim,
|
||||||
layer_weights->griffin.a.data_scale1() + head_offset,
|
layer_weights->griffin.a.PackedScale1() + head_offset,
|
||||||
fn_mul);
|
fn_mul);
|
||||||
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
|
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
|
||||||
fn_mul);
|
fn_mul);
|
||||||
|
|
@ -198,7 +196,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
||||||
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
|
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
|
||||||
float* out_ptr = activations.att_sums.Batch(batch_idx);
|
float* out_ptr = activations.att_sums.Batch(batch_idx);
|
||||||
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
|
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
|
||||||
layer_weights->griffin.linear_out_biases.data_scale1(), out_ptr,
|
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
|
||||||
pool);
|
pool);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@@ -253,7 +251,7 @@ class GemmaAttention {
    const auto pre_att_rms_out =
        ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
-   auto w_q1 = layer_weights_.qkv_einsum_w.data()
+   auto w_q1 = layer_weights_.qkv_einsum_w.HasPtr()
                    ? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
                    : ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
    // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
@@ -265,15 +263,20 @@ class GemmaAttention {
    const size_t w1_rows = heads * layer_config_.QStride();
    w_q1.ShrinkRows(w1_rows);
    MatMul(pre_att_rms_out, w_q1,
-          /*add=*/nullptr, *activations_.env, RowPtrFromBatch(activations_.q));
+          /*add=*/nullptr, *activations_.env,
+          RowPtrFromBatch(allocator_, activations_.q));

    if (is_mha_) {
      // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
    } else {
-     auto w_q2 = layer_weights_.qkv_einsum_w.data()
-                     ? ConstMatFromWeights(layer_weights_.qkv_einsum_w,
-                                           w1_rows * model_dim)
-                     : ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
+     decltype(w_q1) w_q2;
+     if (layer_weights_.qkv_einsum_w.HasPtr()) {
+       w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w);
+       // Skip first half of the matrix.
+       w_q2.ofs = w_q2.Row(w1_rows);
+     } else {
+       w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
+     }
      // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
      const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
      w_q2.ShrinkRows(w_rows_kv_cols);
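The hunk above now addresses the KV half of the packed qkv_einsum_w by row offset instead of an element offset. A minimal, self-contained sketch of that row-offset idea, using a hypothetical RowView type rather than the repo's MatPtr/ConstMat API:

    // Sketch: the KV half of a row-major [q_rows + kv_rows, cols] matrix is
    // just base + q_rows * stride; no copy is needed.
    #include <cstddef>
    #include <vector>

    struct RowView {
      const float* base;
      std::size_t rows, cols, stride;
      const float* Row(std::size_t r) const { return base + r * stride; }
    };

    int main() {
      const std::size_t q_rows = 8, kv_rows = 4, cols = 16;
      std::vector<float> w((q_rows + kv_rows) * cols, 0.0f);
      RowView w_q1{w.data(), q_rows, cols, cols};                   // Q rows [0, q_rows)
      RowView w_q2{w.data() + q_rows * cols, kv_rows, cols, cols};  // KV rows [q_rows, ...)
      // w_q2's first row aliases w_q1's row q_rows: same storage, different view.
      return (w_q2.Row(0) == w_q1.Row(q_rows)) ? 0 : 1;
    }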
@@ -285,7 +288,7 @@ class GemmaAttention {
        const size_t kv_ofs =
            queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
        float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
-       RowPtrF kv_rows(kv, w_rows_kv_cols);
+       RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols);
        kv_rows.SetStride(cache_pos_size_);
        MatMul(pre_att_rms_out, w_q2,
               /*add=*/nullptr, *activations_.env, kv_rows);

@@ -302,7 +305,7 @@ class GemmaAttention {
          const size_t kv_offset =
              cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
          float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
-         if (layer_weights_.qkv_einsum_w.data()) {
+         if (layer_weights_.qkv_einsum_w.HasPtr()) {
            MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim,
                   w_rows_kv_cols, model_dim, x, kv, pool_);
          } else {

@@ -336,8 +339,8 @@ class GemmaAttention {
        }

        // Apply further processing to K.
-       if (layer_weights_.key_norm_scale.data()) {
-         RMSNormInplace(layer_weights_.key_norm_scale.data(), kv,
+       if (layer_weights_.key_norm_scale.HasPtr()) {
+         RMSNormInplace(layer_weights_.key_norm_scale.Row(0), kv,
                         qkv_dim);
        }
        PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);

@@ -427,8 +430,8 @@ class GemmaAttention {

        // Apply rope and scaling to Q.
        const size_t pos = queries_pos_[query_idx] + batch_idx;
-       if (layer_weights_.query_norm_scale.data()) {
-         RMSNormInplace(layer_weights_.query_norm_scale.data(), q,
+       if (layer_weights_.query_norm_scale.HasPtr()) {
+         RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q,
                         qkv_dim);
        }
        PositionalEncodingQK(q, pos, layer_, query_scale);
@@ -473,17 +476,18 @@ class GemmaAttention {
    HWY_DASSERT(layer_config_.model_dim > 0);
    HWY_DASSERT(layer_config_.heads > 0);
    HWY_DASSERT(layer_config_.qkv_dim > 0);
-   HWY_DASSERT(layer_weights_.att_weights.data() != nullptr);
+   HWY_DASSERT(layer_weights_.att_weights.HasPtr());
    HWY_DASSERT(activations_.att_out.All() != nullptr);
    HWY_DASSERT(activations_.att_sums.All() != nullptr);

    const float* add =
        layer_weights_.layer_config.softmax_attn_output_biases
-           ? layer_weights_.attention_output_biases.data_scale1()
+           ? layer_weights_.attention_output_biases.PackedScale1()
            : nullptr;
    MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
           ConstMatFromWeights(layer_weights_.att_weights), add,
-          *activations_.env, RowPtrFromBatch(activations_.att_sums));
+          *activations_.env,
+          RowPtrFromBatch(allocator_, activations_.att_sums));
  }

 public:
@@ -533,7 +537,8 @@ class GemmaAttention {
        layer_weights_(*layer_weights),
        div_seq_len_(div_seq_len),
        kv_caches_(kv_caches),
-       pool_(activations.env->parallel.Pools().Pool(0)) {
+       allocator_(activations.env->ctx.allocator),
+       pool_(activations.env->ctx.pools.Pool(0)) {
    HWY_DASSERT(num_queries_ <= kv_caches_.size());
    HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
                  "query heads must be a multiple of key-value heads");

@@ -562,6 +567,7 @@ class GemmaAttention {
  const LayerWeightsPtrs<T>& layer_weights_;
  const hwy::Divisor& div_seq_len_;
  const KVCaches& kv_caches_;
+ const Allocator2& allocator_;
  hwy::ThreadPool& pool_;
};
@@ -606,8 +612,8 @@ class VitAttention {
    HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
    MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
           ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
-          layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env,
-          RowPtrFromBatch(qkv));
+          layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
+          RowPtrFromBatch(allocator_, qkv));
  }

  // TODO(philculliton): transition fully to MatMul.

@@ -621,10 +627,10 @@ class VitAttention {

    // Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents)
    RowVectorBatch<float> Q =
-       AllocateAlignedRows<float>(Extents2D(num_tokens_, qkv_dim));
+       AllocateAlignedRows<float>(allocator_, Extents2D(num_tokens_, qkv_dim));
    RowVectorBatch<float> K =
-       AllocateAlignedRows<float>(Extents2D(seq_len, qkv_dim));
-   RowVectorBatch<float> C(Extents2D(num_tokens_, seq_len));
+       AllocateAlignedRows<float>(allocator_, Extents2D(seq_len, qkv_dim));
+   RowVectorBatch<float> C(allocator_, Extents2D(num_tokens_, seq_len));

    // Initialize att_out to zero prior to head loop.
    hwy::ZeroBytes(activations_.att_out.All(),

@@ -650,7 +656,7 @@ class VitAttention {
    // this produces C, a (num_tokens_, seq_len) matrix of dot products
    MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
           ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
-          RowPtrFromBatch(C));
+          RowPtrFromBatch(allocator_, C));

    pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
      float* HWY_RESTRICT c = C.Batch(task);

@@ -712,13 +718,13 @@ class VitAttention {
  // head_dim (`qkv_dim`) into output (`att_sums`).
  HWY_NOINLINE void SumHeads() {
    PROFILER_ZONE("Gen.VitAttention.SumHeads");
-   auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
+   auto* bias = layer_weights_.vit.attn_out_b.PackedScale1();
    // att_weights and att_out are concatenated heads, each of length
    // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
    // matmul output is the sum over heads.
    auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
    auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
-   auto att_sums = RowPtrFromBatch(activations_.att_sums);
+   auto att_sums = RowPtrFromBatch(allocator_, activations_.att_sums);
    MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
  }
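The SumHeads comment above relies on the identity that a single matmul over the concatenated heads equals the sum of per-head matmuls. A small self-contained check of that identity (plain C++, hypothetical dimensions):

    // [1, heads*qkv] x [heads*qkv, dim] splits block-wise by head, so the
    // fused product and the per-head sum give the same result.
    #include <cstdio>
    #include <vector>

    int main() {
      const int heads = 2, qkv = 3, dim = 4;
      std::vector<float> att(heads * qkv), w(heads * qkv * dim);
      for (int i = 0; i < heads * qkv; ++i) att[i] = 0.1f * (i + 1);
      for (int i = 0; i < heads * qkv * dim; ++i) w[i] = 0.01f * (i + 1);
      std::vector<float> fused(dim, 0.0f), summed(dim, 0.0f);
      for (int k = 0; k < heads * qkv; ++k)
        for (int j = 0; j < dim; ++j) fused[j] += att[k] * w[k * dim + j];
      for (int h = 0; h < heads; ++h)
        for (int k = 0; k < qkv; ++k)
          for (int j = 0; j < dim; ++j)
            summed[j] += att[h * qkv + k] * w[(h * qkv + k) * dim + j];
      // Both columns match (up to floating-point rounding).
      for (int j = 0; j < dim; ++j) std::printf("%g %g\n", fused[j], summed[j]);
      return 0;
    }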
@@ -730,7 +736,8 @@ class VitAttention {
        activations_(activations),
        layer_weights_(*layer_weights),
        layer_config_(layer_weights->layer_config),
-       pool_(activations.env->parallel.Pools().Pool(0)) {}
+       allocator_(activations.env->ctx.allocator),
+       pool_(activations.env->ctx.pools.Pool(0)) {}

  HWY_INLINE void operator()() {
    ComputeQKV();

@@ -748,6 +755,7 @@ class VitAttention {
  Activations& activations_;
  const LayerWeightsPtrs<T>& layer_weights_;
  const LayerConfig& layer_config_;
+ const Allocator2& allocator_;
  hwy::ThreadPool& pool_;
};
@@ -779,32 +787,35 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,

  const bool add_bias = layer_weights->layer_config.ff_biases;
  const float* bias1 =
-     add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr;
+     add_bias ? layer_weights->ffw_gating_biases.PackedScale1() : nullptr;
  const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
  const float* output_bias =
-     add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr;
+     add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;

  // Define slightly more readable names for the weights and activations.
  const auto x =
      ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);

- auto hidden_activations = RowPtrFromBatch(activations.C1);
- auto multiplier = RowPtrFromBatch(activations.C2);
- auto ffw_out = RowPtrFromBatch(activations.ffw_out);
+ const Allocator2& allocator = activations.env->ctx.allocator;
+ auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
+ auto multiplier = RowPtrFromBatch(allocator, activations.C2);
+ auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);

  // gating_einsum_w holds two half-matrices. We plan to change the importer to
  // avoid this confusion by splitting into gating_einsum_w1 and
  // gating_einsum_w2.
- const bool split = !!layer_weights->gating_einsum_w.data();
+ const bool split = layer_weights->gating_einsum_w.HasPtr();
  auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
                  : ConstMatFromWeights(layer_weights->gating_einsum_w1);
- auto w2 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w,
-                                       model_dim * ffh_hidden_dim)
-                 : ConstMatFromWeights(layer_weights->gating_einsum_w2);
+ decltype(w1) w2;
  if (split) {
+   w2 = ConstMatFromWeights(layer_weights->gating_einsum_w);
+   w2.ofs = w2.Row(ffh_hidden_dim);
    // Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
    w1.ShrinkRows(ffh_hidden_dim);
    w2.ShrinkRows(ffh_hidden_dim);
+ } else {
+   w2 = ConstMatFromWeights(layer_weights->gating_einsum_w2);
  }
  auto w_output = ConstMatFromWeights(layer_weights->linear_w);
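The comment above notes that gating_einsum_w stores two half-matrices stacked along the row dimension. A self-contained sketch of the gated-FFW arithmetic this layout serves, with hypothetical dimensions and a tanh-approximation GELU (not the repo's vectorized Activation):

    // hidden[i] = gelu(dot(W1[i], x)) * dot(W2[i], x), where W1 is rows
    // [0, ff) and W2 is rows [ff, 2*ff) of the combined [2*ff, dim] matrix.
    #include <cmath>
    #include <cstddef>
    #include <vector>

    std::vector<float> GatedHidden(const std::vector<float>& w,  // [2*ff, dim]
                                   const std::vector<float>& x,
                                   std::size_t ff, std::size_t dim) {
      std::vector<float> hidden(ff);
      for (std::size_t i = 0; i < ff; ++i) {
        float w1x = 0.0f, w2x = 0.0f;
        for (std::size_t d = 0; d < dim; ++d) {
          w1x += w[i * dim + d] * x[d];         // first half: rows [0, ff)
          w2x += w[(ff + i) * dim + d] * x[d];  // second half: rows [ff, 2*ff)
        }
        const float gelu = 0.5f * w1x *
            (1.0f + std::tanh(0.7978845608f * (w1x + 0.044715f * w1x * w1x * w1x)));
        hidden[i] = gelu * w2x;
      }
      return hidden;
    }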
@@ -835,16 +846,17 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,

  const bool add_bias = layer_weights->layer_config.ff_biases;
  const float* bias1 =
-     add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr;
+     add_bias ? layer_weights->vit.linear_0_b.PackedScale1() : nullptr;
  const float* output_bias =
-     add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr;
+     add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;

  // Define slightly more readable names for the weights and activations.
  const auto x =
      ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);

- auto hidden_activations = RowPtrFromBatch(activations.C1);
- auto ffw_out = RowPtrFromBatch(activations.ffw_out);
+ const Allocator2& allocator = activations.env->ctx.allocator;
+ auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
+ auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);

  auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
  auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);

@@ -853,7 +865,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
  MatMul(x, w1, bias1, *activations.env, hidden_activations);

  // Activation (Gelu), store in act.
- RowPtrF multiplier = RowPtrF(nullptr, 0);
+ RowPtrF multiplier = RowPtrF(allocator, nullptr, 0);
  Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
             multiplier.Row(0), ff_hidden_dim * num_interleaved);
@@ -905,11 +917,9 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
  HWY_DASSERT(token < static_cast<int>(vocab_size));

  const hn::ScalableTag<float> df;
- DecompressAndZeroPad(
-     df,
-     MakeSpan(weights.embedder_input_embedding.data(), vocab_size * model_dim),
-     token * model_dim, x.Batch(batch_idx), model_dim);
- MulByConst(emb_scaling * weights.embedder_input_embedding.scale(),
+ DecompressAndZeroPad(df, weights.embedder_input_embedding.Span(),
+                      token * model_dim, x.Batch(batch_idx), model_dim);
+ MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
             x.Batch(batch_idx), model_dim);
  if (weights.weights_config.absolute_pe) {
    AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos);
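The embedding path above decompresses row `token` of the embedding table and rescales it. A plain-float sketch of that lookup-and-scale step (the real table is compressed, so this shows only the arithmetic, not the repo's DecompressAndZeroPad):

    // Embedding lookup: copy row `token` of the [vocab, model_dim] table,
    // then multiply by emb_scaling * the tensor's stored scale.
    #include <cstddef>
    #include <vector>

    std::vector<float> EmbedToken(const std::vector<float>& table,
                                  std::size_t model_dim, int token,
                                  float emb_scaling, float tensor_scale) {
      std::vector<float> x(table.begin() + token * model_dim,
                           table.begin() + (token + 1) * model_dim);
      for (float& v : x) v *= emb_scaling * tensor_scale;
      return x;
    }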
@@ -943,9 +953,10 @@ HWY_NOINLINE void ResidualConnection(
template <typename WeightT, typename InOutT>
void PostNorm(PostNormType post_norm, size_t num_interleaved,
              const WeightT& weights, InOutT* inout) {
+ HWY_DASSERT(weights.Rows() == 1);
  if (post_norm == PostNormType::Scale) {
-   RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout,
-                         weights.NumElements());
+   RMSNormInplaceBatched(num_interleaved, weights.PackedScale1(), inout,
+                         weights.Cols());
  }
}
@@ -962,7 +973,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
  auto type = layer_weights->layer_config.type;

  RMSNormBatched(num_interleaved, activations.x.All(),
-                layer_weights->pre_attention_norm_scale.data_scale1(),
+                layer_weights->pre_attention_norm_scale.PackedScale1(),
                 activations.pre_att_rms_out.All(), model_dim);

  Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx,

@@ -976,7 +987,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
                     activations.x.All(), layer_weights, /*is_attention=*/true);

  RMSNormBatched(num_interleaved, activations.x.All(),
-                layer_weights->pre_ffw_norm_scale.data_scale1(),
+                layer_weights->pre_ffw_norm_scale.PackedScale1(),
                 activations.bf_pre_ffw_rms_out.All(), model_dim);

  if (layer_weights->layer_config.type == LayerAttentionType::kVit) {

@@ -1014,8 +1025,8 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
  // y = nn.LayerNorm()(x)
  // y ~ pre_att_rms_out
  LayerNormBatched(num_tokens, x.All(),
-                  layer_weights->vit.layer_norm_0_scale.data_scale1(),
-                  layer_weights->vit.layer_norm_0_bias.data_scale1(),
+                  layer_weights->vit.layer_norm_0_scale.PackedScale1(),
+                  layer_weights->vit.layer_norm_0_bias.PackedScale1(),
                   activations.pre_att_rms_out.All(), model_dim);

  // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)

@@ -1028,8 +1039,8 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
  // y = nn.LayerNorm()(x)
  // y ~ bf_pre_ffw_rms_out
  LayerNormBatched(num_tokens, x.All(),
-                  layer_weights->vit.layer_norm_1_scale.data_scale1(),
-                  layer_weights->vit.layer_norm_1_bias.data_scale1(),
+                  layer_weights->vit.layer_norm_1_scale.PackedScale1(),
+                  layer_weights->vit.layer_norm_1_bias.PackedScale1(),
                   activations.bf_pre_ffw_rms_out.All(), model_dim);

  // y = out["mlp"] = MlpBlock(...)(y)
@@ -1161,8 +1172,8 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
  const size_t patch_width = weights.weights_config.vit_config.patch_width;
  const size_t seq_len = weights.weights_config.vit_config.seq_len;
  const size_t patch_size = patch_width * patch_width * 3;
- HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
-             patch_size * model_dim);
+ HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
+ HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
  HWY_DASSERT(activations.x.Cols() == model_dim);
  std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
  for (size_t i = 0; i < seq_len; ++i) {
@@ -1178,20 +1189,20 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
  // MatMul(
  //     MatFromBatch(kVitSeqLen, image_patches),
  //     MatFromWeights(weights.vit_img_embedding_kernel),
- //     weights.vit_img_embedding_bias.data_scale1(), *activations.env,
+ //     weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
  //     RowPtrF(activations.x.All(), kVitModelDim));
  // However, MatMul currently requires that
  //   A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
  // which is not the case here. We should relax that requirement on MatMul and
  // then use the above. For now, we rely on MatVecAdd instead.
  for (size_t i = 0; i < seq_len; ++i) {
-   MatVecAdd(
-       weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
-       image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(),
-       activations.x.Batch(i), activations.env->parallel.Pools().Pool(0));
+   MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
+             image_patches[i].get(),
+             weights.vit_img_embedding_bias.PackedScale1(),
+             activations.x.Batch(i), activations.env->ctx.pools.Pool(0));
  }
  // Add position embeddings.
- AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
+ AddFrom(weights.vit_img_pos_embedding.PackedScale1(), activations.x.All(),
          seq_len * model_dim);
}
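A quick arithmetic check of the MatMul requirement quoted above, assuming a 14x14 RGB patch (patch_size = 588); the remainder is nonzero for common float lane counts, which is why the loop falls back to MatVecAdd:

    // 588 is not a multiple of 2 * Lanes for typical float vector widths.
    #include <cstdio>

    int main() {
      const unsigned patch_size = 14 * 14 * 3;  // 588 (assumed patch width)
      const unsigned lanes[] = {4, 8, 16};      // e.g. NEON/SSE, AVX2, AVX-512
      for (unsigned l : lanes) {
        std::printf("lanes=%2u  588 %% %u = %u\n", l, 2 * l, patch_size % (2 * l));
      }
      return 0;  // remainders 4, 12, 12 -- all nonzero
    }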
@@ -1216,23 +1227,23 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
  }
  // Final Layernorm.
  LayerNormBatched(num_tokens, activations.x.All(),
-                  weights.vit_encoder_norm_scale.data_scale1(),
-                  weights.vit_encoder_norm_bias.data_scale1(),
+                  weights.vit_encoder_norm_scale.PackedScale1(),
+                  weights.vit_encoder_norm_bias.PackedScale1(),
                   activations.x.All(), vit_model_dim);

  if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
    activations.x = AvgPool4x4(activations.x);

    // Apply soft embedding norm before input projection.
-   RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
+   RMSNormInplace(weights.mm_embed_norm.PackedScale1(), activations.x.All(),
                   vit_model_dim);
  }

  // Apply head embedding into image_tokens of size of the LLM kModelDim.
  MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
         ConstMatFromWeights(weights.vit_img_head_kernel),
-        weights.vit_img_head_bias.data_scale1(), *activations.env,
-        RowPtrFromBatch(image_tokens));
+        weights.vit_img_head_bias.PackedScale1(), *activations.env,
+        RowPtrFromBatch(activations.env->ctx.allocator, image_tokens));
}

// Generates one token for each query. `queries_token` is the previous token
@@ -1274,7 +1285,7 @@ HWY_NOINLINE void Transformer(
    }
  }

- RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
+ RMSNormInplaceBatched(num_queries, weights.final_norm_scale.PackedScale1(),
                        activations.x.All(), model_dim);

  if (activations_observer) {
@@ -1374,7 +1385,7 @@ bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
    MatMul(ConstMatFromBatch(num_queries, activations.x),
           ConstMatFromWeights(weights.embedder_input_embedding),
           /*add=*/nullptr, *activations.env,
-          RowPtrFromBatch(activations.logits));
+          RowPtrFromBatch(activations.env->ctx.allocator, activations.logits));
  }
  PROFILER_ZONE("Gen.Softcap+Sample+Stream");
  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
@@ -27,22 +27,33 @@
#include <utility>  // std::move
#include <vector>

-#include "compression/io.h"  // Path
+// Placeholder for internal header, do not modify.
#include "compression/shared.h"
#include "gemma/common.h"
+#include "gemma/configs.h"
+#include "gemma/tokenizer.h"
#include "gemma/weights.h"
-#include "ops/ops-inl.h"
+#include "ops/matmul.h"
#include "paligemma/image.h"
-#include "util/threading.h"
+#include "util/threading_context.h"
-#include "hwy/highway.h"
+#include "hwy/base.h"

namespace gcpp {

+// Internal init must run before I/O; calling it from `GemmaEnv()` is too late.
+// This helper function takes care of the internal init plus calling `SetArgs`.
+MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
+  // Placeholder for internal init, do not modify.
+
+  ThreadingContext2::SetArgs(threading_args);
+  return MatMulEnv(ThreadingContext2::Get());
+}
+
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
             const ModelInfo& info, MatMulEnv& env)
    : env_(env), tokenizer_(tokenizer_path) {
  model_.Load(weights, info.model, info.weight, info.wrapping,
-             env_.parallel.Pools().Pool(0),
+             env_.ctx.pools.Pool(0),
              /*tokenizer_proto=*/nullptr);
  chat_template_.Init(tokenizer_, model_.Config().model);
}
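For context, a sketch of how a frontend is expected to use the new MakeMatMulEnv helper, mirroring gemma/run.cc later in this commit; the header locations and the CreateGemma overload are assumptions, not verified here:

    // Minimal frontend sketch against the APIs used elsewhere in this commit.
    #include "evals/benchmark_helper.h"  // CreateGemma (assumed location)
    #include "gemma/gemma.h"             // Gemma, MakeMatMulEnv
    #include "gemma/gemma_args.h"        // LoaderArgs (assumed location)

    int main(int argc, char** argv) {
      gcpp::ThreadingArgs threading(argc, argv);  // replaces the old AppArgs
      gcpp::LoaderArgs loader(argc, argv);
      // Runs internal init and ThreadingContext2::SetArgs before any weight I/O.
      gcpp::MatMulEnv env = gcpp::MakeMatMulEnv(threading);
      gcpp::Gemma model = gcpp::CreateGemma(loader, env);
      return 0;
    }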
@@ -50,7 +61,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
  std::string tokenizer_proto;
  model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
-             env_.parallel.Pools().Pool(0), &tokenizer_proto);
+             env_.ctx.pools.Pool(0), &tokenizer_proto);
  tokenizer_.Deserialize(tokenizer_proto);
  chat_template_.Init(tokenizer_, model_.Config().model);
}

@@ -60,7 +71,7 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
      tokenizer_(std::move(tokenizer)),
      chat_template_(tokenizer_, info.model) {
  HWY_ASSERT(info.weight == Type::kF32);
- model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
+ model_.Allocate(info.model, info.weight, env_.ctx.pools.Pool(0));
}

Gemma::~Gemma() {
@@ -130,12 +141,12 @@ struct GenerateImageTokensT {
void Gemma::Generate(const RuntimeConfig& runtime_config,
                     const PromptTokens& prompt, size_t pos, size_t prefix_end,
                     KVCache& kv_cache, TimingInfo& timing_info) {
- env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
+ env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

  model_.CallForModelWeight<GenerateSingleT>(
      runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info);

- env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
+ env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
@@ -152,23 +163,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
        QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
  }

- env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
+ env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

  model_.CallForModelWeight<GenerateBatchT>(
      runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
      kv_caches, &env_, timing_info);

- env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
+ env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
                                const Image& image, ImageTokens& image_tokens) {
- env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
+ env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);

  model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
                                                  image_tokens, &env_);

- env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
+ env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}

// Non-template functions moved from gemma-inl.h to avoid ODR violations.
@@ -16,6 +16,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

+#include <stdio.h>
+
#include <functional>
#include <random>
#include <string>

@@ -31,8 +33,9 @@
#include "gemma/weights.h"
#include "ops/matmul.h"  // MatMulEnv
#include "paligemma/image.h"
-#include "util/allocator.h"  // RowVectorBatch
#include "util/basics.h"  // TokenAndProb
+#include "util/mat.h"  // RowVectorBatch
+#include "util/threading_context.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h"  // Span
@@ -193,6 +196,10 @@ struct TimingInfo {
  size_t tokens_generated = 0;
};

+// Internal init must run before I/O; calling it from GemmaEnv() is too late.
+// This helper function takes care of the internal init plus calling `SetArgs`.
+MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
+
class Gemma {
 public:
  // Reads old format weights file and tokenizer file.
@@ -206,7 +213,9 @@ class Gemma {
  Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env);
  ~Gemma();

+ MatMulEnv& Env() const { return env_; }
  const ModelConfig& GetModelConfig() const { return model_.Config(); }
+ // DEPRECATED
  ModelInfo Info() const {
    return ModelInfo({.model = model_.Config().model,
                      .wrapping = model_.Config().wrapping,
@@ -15,8 +15,8 @@

// Shared between various frontends.

-#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
-#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
+#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
+#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

#include <stddef.h>
#include <stdio.h>
@@ -31,103 +31,10 @@
#include "ops/matmul.h"
#include "util/args.h"
#include "util/basics.h"  // Tristate
-#include "util/threading.h"
-#include "hwy/base.h"  // HWY_IS_ASAN
+#include "hwy/base.h"  // HWY_ABORT

namespace gcpp {

-static inline const char* CompiledConfig() {
-  if (HWY_IS_ASAN) {
-    return "asan";
-  } else if (HWY_IS_MSAN) {
-    return "msan";
-  } else if (HWY_IS_TSAN) {
-    return "tsan";
-  } else if (HWY_IS_HWASAN) {
-    return "hwasan";
-  } else if (HWY_IS_UBSAN) {
-    return "ubsan";
-  } else if (HWY_IS_DEBUG_BUILD) {
-    return "dbg";
-  } else {
-    return "opt";
-  }
-}
-
-class AppArgs : public ArgsBase<AppArgs> {
- public:
-  AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
-  AppArgs() { Init(); };
-
-  int verbosity;
-
-  size_t max_threads;  // divided among the detected clusters
-  Tristate pin;   // pin threads?
-  Tristate spin;  // use spin waits?
-
-  // For BoundedSlice:
-  size_t skip_packages;
-  size_t max_packages;
-  size_t skip_clusters;
-  size_t max_clusters;
-  size_t skip_lps;
-  size_t max_lps;
-
-  std::string eot_line;
-
-  template <class Visitor>
-  void ForEach(const Visitor& visitor) {
-    visitor(verbosity, "verbosity", 1,
-            "Show verbose developer information\n 0 = only print generation "
-            "output\n 1 = standard user-facing terminal ui\n 2 = show "
-            "developer/debug info).\n Default = 1.",
-            2);
-
-    // The exact meaning is more subtle: see the comment at NestedPools ctor.
-    visitor(max_threads, "num_threads", size_t{0},
-            "Maximum number of threads to use; default 0 = unlimited.", 2);
-    visitor(pin, "pin", Tristate::kDefault,
-            "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
-    visitor(spin, "spin", Tristate::kDefault,
-            "Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
-    // These can be used to partition CPU sockets/packages and their
-    // clusters/CCXs across several program instances. The default is to use
-    // all available resources.
-    visitor(skip_packages, "skip_packages", size_t{0},
-            "Index of the first socket to use; default 0 = unlimited.", 2);
-    visitor(max_packages, "max_packages", size_t{0},
-            "Maximum number of sockets to use; default 0 = unlimited.", 2);
-    visitor(skip_clusters, "skip_clusters", size_t{0},
-            "Index of the first CCX to use; default 0 = unlimited.", 2);
-    visitor(max_clusters, "max_clusters", size_t{0},
-            "Maximum number of CCXs to use; default 0 = unlimited.", 2);
-    // These are only used when CPU topology is unknown.
-    visitor(skip_lps, "skip_lps", size_t{0},
-            "Index of the first LP to use; default 0 = unlimited.", 2);
-    visitor(max_lps, "max_lps", size_t{0},
-            "Maximum number of LPs to use; default 0 = unlimited.", 2);
-
-    visitor(
-        eot_line, "eot_line", std::string(""),
-        "End of turn line. "
-        "When you specify this, the prompt will be all lines "
-        "before the line where only the given string appears.\n Default = "
-        "When a newline is encountered, that signals the end of the turn.",
-        2);
-  }
-};
-
-static inline BoundedTopology CreateTopology(const AppArgs& app) {
-  return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages),
-                         BoundedSlice(app.skip_clusters, app.max_clusters),
-                         BoundedSlice(app.skip_lps, app.max_lps));
-}
-static inline NestedPools CreatePools(const BoundedTopology& topology,
-                                      const AppArgs& app) {
-  Allocator::Init(topology);
-  return NestedPools(topology, app.max_threads, app.pin);
-}
-
struct LoaderArgs : public ArgsBase<LoaderArgs> {
  LoaderArgs(int argc, char* argv[], bool validate = true) {
    InitAndParse(argc, argv);
@@ -154,15 +61,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {

  // Returns error string or nullptr if OK.
  const char* Validate() {
-   if (!compressed_weights.path.empty()) {
-     if (weights.path.empty()) {
-       weights = compressed_weights;
-     } else {
-       return "Only one of --weights and --compressed_weights can be "
-              "specified. To create compressed weights use the "
-              "compress_weights tool.";
-     }
-   }
    if (weights.path.empty()) {
      return "Missing --weights flag, a file for the model weights.";
    }
@@ -250,6 +148,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
  InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
  InferenceArgs() { Init(); };

+ int verbosity;
+
  size_t max_generated_tokens;

  size_t prefill_tbatch_size;

@@ -261,6 +161,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
  bool multiturn;
  Path image_file;

+ std::string eot_line;
+
  // Returns error string or nullptr if OK.
  const char* Validate() const {
    if (max_generated_tokens > gcpp::kSeqLen) {
@@ -272,6 +174,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {

  template <class Visitor>
  void ForEach(const Visitor& visitor) {
+   visitor(verbosity, "verbosity", 1,
+           "Show verbose developer information\n 0 = only print generation "
+           "output\n 1 = standard user-facing terminal ui\n 2 = show "
+           "developer/debug info).\n Default = 1.",
+           2);
+
    visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
            "Maximum number of tokens to generate.");

@@ -291,6 +199,14 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
            " Default : 0 (conversation "
            "resets every turn)");
    visitor(image_file, "image_file", Path(), "Image file to load.");
+
+   visitor(
+       eot_line, "eot_line", std::string(""),
+       "End of turn line. "
+       "When you specify this, the prompt will be all lines "
+       "before the line where only the given string appears.\n Default = "
+       "When a newline is encountered, that signals the end of the turn.",
+       2);
  }

  void CopyTo(RuntimeConfig& runtime_config) const {
@@ -317,4 +233,4 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {

}  // namespace gcpp

-#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
+#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
84 gemma/run.cc
@@ -15,23 +15,23 @@

// Command line text interface to gemma.

+#include <stdio.h>
+
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <vector>

-// Placeholder for internal header, do not modify.
#include "compression/shared.h"  // PromptWrapping
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h"  // Gemma
+#include "gemma/gemma_args.h"  // LoaderArgs
#include "ops/matmul.h"  // MatMulEnv
#include "paligemma/image.h"
-#include "util/app.h"
#include "util/args.h"  // HasHelp
-#include "util/threading.h"
+#include "util/threading_context.h"
-#include "hwy/base.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
|
@ -78,35 +78,37 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
||||||
}
|
}
|
||||||
|
|
||||||
// The main Read-Eval-Print Loop.
|
// The main Read-Eval-Print Loop.
|
||||||
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||||
const InferenceArgs& args, const AcceptFunc& accept_token,
|
Gemma& model, KVCache& kv_cache) {
|
||||||
std::string& eot_line) {
|
|
||||||
PROFILER_ZONE("Gen.misc");
|
PROFILER_ZONE("Gen.misc");
|
||||||
size_t abs_pos = 0; // across turns
|
size_t abs_pos = 0; // across turns
|
||||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||||
size_t prompt_size = 0;
|
size_t prompt_size = 0;
|
||||||
|
|
||||||
std::mt19937 gen;
|
std::mt19937 gen;
|
||||||
InitGenerator(args, gen);
|
InitGenerator(inference, gen);
|
||||||
|
|
||||||
const bool have_image = !args.image_file.path.empty();
|
const bool have_image = !inference.image_file.path.empty();
|
||||||
Image image;
|
Image image;
|
||||||
ImageTokens image_tokens;
|
ImageTokens image_tokens;
|
||||||
if (have_image) {
|
if (have_image) {
|
||||||
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
||||||
image_tokens = ImageTokens(Extents2D(
|
image_tokens =
|
||||||
model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim),
|
ImageTokens(model.Env().ctx.allocator,
|
||||||
|
Extents2D(model.GetModelConfig().vit_config.seq_len /
|
||||||
|
(pool_dim * pool_dim),
|
||||||
model.GetModelConfig().model_dim));
|
model.GetModelConfig().model_dim));
|
||||||
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
|
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
|
||||||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
|
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
|
||||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
|
||||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
||||||
image.Resize(image_size, image_size);
|
image.Resize(image_size, image_size);
|
||||||
RuntimeConfig runtime_config = {
|
RuntimeConfig runtime_config = {.gen = &gen,
|
||||||
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
|
.verbosity = inference.verbosity,
|
||||||
|
.use_spinning = threading.spin};
|
||||||
double image_tokens_start = hwy::platform::Now();
|
double image_tokens_start = hwy::platform::Now();
|
||||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||||
if (app.verbosity >= 1) {
|
if (inference.verbosity >= 1) {
|
||||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||||
fprintf(stderr,
|
fprintf(stderr,
|
||||||
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
||||||
|
|
@@ -121,12 +123,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
    const bool first_response_token = tokens_generated_this_turn == prompt_size;
    ++tokens_generated_this_turn;
    if (in_prompt) {
-     if (app.verbosity >= 1) {
+     if (inference.verbosity >= 1) {
        std::cerr << "." << std::flush;
      }
      return true;
    } else if (model.GetModelConfig().IsEOS(token)) {
-     if (app.verbosity >= 2) {
+     if (inference.verbosity >= 2) {
        std::cout << "\n[ End ]\n";
      }
      return true;

@@ -135,7 +137,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
    HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
    if (first_response_token) {
      token_text.erase(0, token_text.find_first_not_of(" \t\n"));
-     if (app.verbosity >= 1) {
+     if (inference.verbosity >= 1) {
        std::cout << "\n\n";
      }
    }

@@ -147,7 +149,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
    tokens_generated_this_turn = 0;

    // Read prompt and handle special commands.
-   std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
+   std::string prompt_string =
+       GetPrompt(std::cin, inference.verbosity, inference.eot_line);
    if (!std::cin) return;
    // If !eot_line.empty(), we append \n, so only look at the first 2 chars.
    if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
@@ -163,13 +166,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
    }

    // Set up runtime config.
-   TimingInfo timing_info = {.verbosity = app.verbosity};
+   TimingInfo timing_info = {.verbosity = inference.verbosity};
    RuntimeConfig runtime_config = {.gen = &gen,
-                                   .verbosity = app.verbosity,
+                                   .verbosity = inference.verbosity,
                                    .stream_token = stream_token,
-                                   .accept_token = accept_token,
-                                   .use_spinning = app.spin};
-   args.CopyTo(runtime_config);
+                                   .use_spinning = threading.spin};
+   inference.CopyTo(runtime_config);
    size_t prefix_end = 0;

    std::vector<int> prompt;
@@ -197,7 +199,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
    }

    // Generate until EOS or max_generated_tokens.
-   if (app.verbosity >= 1) {
+   if (inference.verbosity >= 1) {
      std::cerr << "\n[ Reading prompt ] " << std::flush;
    }
    model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,

@@ -205,9 +207,10 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
    std::cout << "\n\n";

    // Prepare for the next turn. Works only for PaliGemma.
-   if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
+   if (!inference.multiturn ||
+       model.Info().wrapping == PromptWrapping::PALIGEMMA) {
      abs_pos = 0;  // Start a new turn at position 0.
-     InitGenerator(args, gen);
+     InitGenerator(inference, gen);
    } else {
      // The last token was either EOS, then it should be ignored because it is
      // never part of the dialog, see Table 5 in the Gemma-2 paper:
@@ -223,20 +226,19 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
  }
}

-void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
+void Run(ThreadingArgs& threading, LoaderArgs& loader,
+         InferenceArgs& inference) {
  PROFILER_ZONE("Run.misc");

  // Note that num_threads is an upper bound; we also limit to the number of
  // detected and enabled cores.
- const BoundedTopology topology = CreateTopology(app);
- NestedPools pools = CreatePools(topology, app);
- MatMulEnv env(topology, pools);
- if (app.verbosity >= 2) env.print_best = true;
+ MatMulEnv env(MakeMatMulEnv(threading));
+ if (inference.verbosity >= 2) env.print_best = true;
  Gemma model = CreateGemma(loader, env);
  KVCache kv_cache =
      KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);

- if (app.verbosity >= 1) {
+ if (inference.verbosity >= 1) {
    std::string instructions =
        "*Usage*\n"
        " Enter an instruction and press enter (%C resets conversation, "
@ -259,11 +261,11 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||||
|
|
||||||
std::cout << "\033[2J\033[1;1H" // clear screen
|
std::cout << "\033[2J\033[1;1H" // clear screen
|
||||||
<< kAsciiArtBanner << "\n\n";
|
<< kAsciiArtBanner << "\n\n";
|
||||||
ShowConfig(loader, inference, app, topology, pools);
|
ShowConfig(threading, loader, inference);
|
||||||
std::cout << "\n" << instructions << "\n";
|
std::cout << "\n" << instructions << "\n";
|
||||||
}
|
}
|
||||||
|
|
||||||
ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line);
|
ReplGemma(threading, inference, model, kv_cache);
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
@ -272,31 +274,29 @@ int main(int argc, char** argv) {
|
||||||
{
|
{
|
||||||
PROFILER_ZONE("Startup.misc");
|
PROFILER_ZONE("Startup.misc");
|
||||||
|
|
||||||
// Placeholder for internal init, do not modify.
|
gcpp::ThreadingArgs threading(argc, argv);
|
||||||
|
|
||||||
gcpp::LoaderArgs loader(argc, argv);
|
gcpp::LoaderArgs loader(argc, argv);
|
||||||
gcpp::InferenceArgs inference(argc, argv);
|
gcpp::InferenceArgs inference(argc, argv);
|
||||||
gcpp::AppArgs app(argc, argv);
|
|
||||||
|
|
||||||
if (gcpp::HasHelp(argc, argv)) {
|
if (gcpp::HasHelp(argc, argv)) {
|
||||||
std::cerr << gcpp::kAsciiArtBanner;
|
std::cerr << gcpp::kAsciiArtBanner;
|
||||||
gcpp::ShowHelp(loader, inference, app);
|
gcpp::ShowHelp(threading, loader, inference);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (const char* error = loader.Validate()) {
|
if (const char* error = loader.Validate()) {
|
||||||
std::cerr << gcpp::kAsciiArtBanner;
|
std::cerr << gcpp::kAsciiArtBanner;
|
||||||
gcpp::ShowHelp(loader, inference, app);
|
gcpp::ShowHelp(threading, loader, inference);
|
||||||
HWY_ABORT("\nInvalid args: %s", error);
|
HWY_ABORT("\nInvalid args: %s", error);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (const char* error = inference.Validate()) {
|
if (const char* error = inference.Validate()) {
|
||||||
std::cerr << gcpp::kAsciiArtBanner;
|
std::cerr << gcpp::kAsciiArtBanner;
|
||||||
gcpp::ShowHelp(loader, inference, app);
|
gcpp::ShowHelp(threading, loader, inference);
|
||||||
HWY_ABORT("\nInvalid args: %s", error);
|
HWY_ABORT("\nInvalid args: %s", error);
|
||||||
}
|
}
|
||||||
|
|
||||||
gcpp::Run(loader, inference, app);
|
gcpp::Run(threading, loader, inference);
|
||||||
}
|
}
|
||||||
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
|
||||||
return 0;
|
return 0;
|
||||||
|
|
|
||||||
|
|
@ -562,11 +562,12 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
|
||||||
if (llm_layer_idx < 0 && img_layer_idx < 0) {
|
if (llm_layer_idx < 0 && img_layer_idx < 0) {
|
||||||
tensors_ = ModelTensors(config);
|
tensors_ = ModelTensors(config);
|
||||||
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
|
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
|
||||||
img_layer_idx < config.vit_config.layer_configs.size()) {
|
img_layer_idx <
|
||||||
|
static_cast<int>(config.vit_config.layer_configs.size())) {
|
||||||
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
|
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
|
||||||
tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
|
tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
|
||||||
} else if (0 <= llm_layer_idx &&
|
} else if (0 <= llm_layer_idx &&
|
||||||
llm_layer_idx < config.layer_configs.size()) {
|
llm_layer_idx < static_cast<int>(config.layer_configs.size())) {
|
||||||
const auto& layer_config = config.layer_configs[llm_layer_idx];
|
const auto& layer_config = config.layer_configs[llm_layer_idx];
|
||||||
tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
|
tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,6 +54,28 @@ struct TensorInfo {
|
||||||
bool cols_take_extra_dims = false;
|
bool cols_take_extra_dims = false;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Collapses/expands the tensor dims into 2D extents, which may be 0, 0 for
|
||||||
|
// not-present tensors such as ViT in a text-only model.
|
||||||
|
static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) {
|
||||||
|
if (tensor == nullptr) return Extents2D(0, 0);
|
||||||
|
|
||||||
|
size_t cols = tensor->shape.back();
|
||||||
|
size_t rows = 1;
|
||||||
|
if (tensor->cols_take_extra_dims) {
|
||||||
|
rows = tensor->shape[0];
|
||||||
|
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
|
||||||
|
cols *= tensor->shape[i];
|
||||||
|
}
|
||||||
|
} else { // rows take extra dims
|
||||||
|
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
|
||||||
|
rows *= tensor->shape[i];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Sometimes only one of rows or cols is zero; set both for consistency.
|
||||||
|
if (rows == 0 || cols == 0) rows = cols = 0;
|
||||||
|
return Extents2D(rows, cols);
|
||||||
|
}
|
||||||
|
|
||||||
// Universal index of tensor information, which can be built for a specific
|
// Universal index of tensor information, which can be built for a specific
|
||||||
// layer_idx.
|
// layer_idx.
|
||||||
class TensorIndex {
|
class TensorIndex {
|
||||||
|
|
@ -96,6 +118,16 @@ class TensorIndex {
|
||||||
std::unordered_map<std::string, size_t> name_map_;
|
std::unordered_map<std::string, size_t> name_map_;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
static inline TensorIndex TensorIndexLLM(const ModelConfig& config,
|
||||||
|
size_t llm_layer_idx) {
|
||||||
|
return TensorIndex(config, static_cast<int>(llm_layer_idx), -1, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
static inline TensorIndex TensorIndexImg(const ModelConfig& config,
|
||||||
|
size_t img_layer_idx) {
|
||||||
|
return TensorIndex(config, -1, static_cast<int>(img_layer_idx), false);
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
||||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
|
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,7 @@
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
|
#include "util/mat.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h" // HWY_ABORT
|
#include "hwy/base.h" // HWY_ABORT
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -118,7 +119,7 @@ struct TensorSaver {
|
||||||
weights.ForEachTensor(
|
weights.ForEachTensor(
|
||||||
{&weights}, fet,
|
{&weights}, fet,
|
||||||
[&writer](const char* name, hwy::Span<MatPtr*> tensors) {
|
[&writer](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||||
tensors[0]->CallUpcasted(writer, name);
|
CallUpcasted(tensors[0]->GetType(), tensors[0], writer, name);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -155,11 +156,11 @@ class WeightInitializer {
|
||||||
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
|
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
|
||||||
|
|
||||||
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
||||||
float* data = tensors[0]->data<float>();
|
float* data = tensors[0]->RowT<float>(0);
|
||||||
for (size_t i = 0; i < tensors[0]->NumElements(); ++i) {
|
for (size_t i = 0; i < tensors[0]->Extents().Area(); ++i) {
|
||||||
data[i] = dist_(gen_);
|
data[i] = dist_(gen_);
|
||||||
}
|
}
|
||||||
tensors[0]->set_scale(1.0f);
|
tensors[0]->SetScale(1.0f);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
|
@ -226,11 +227,11 @@ void ModelWeightsStorage::LogWeightStats() {
|
||||||
{float_weights_.get()}, ForEachType::kInitNoToc,
|
{float_weights_.get()}, ForEachType::kInitNoToc,
|
||||||
[&total_weights](const char* name, hwy::Span<MatPtr*> tensors) {
|
[&total_weights](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||||
const MatPtr& tensor = *tensors[0];
|
const MatPtr& tensor = *tensors[0];
|
||||||
if (tensor.scale() != 1.0f) {
|
if (tensor.Scale() != 1.0f) {
|
||||||
printf("[scale=%f] ", tensor.scale());
|
printf("[scale=%f] ", tensor.Scale());
|
||||||
}
|
}
|
||||||
LogVec(name, tensor.data<float>(), tensor.NumElements());
|
LogVec(name, tensor.RowT<float>(0), tensor.Extents().Area());
|
||||||
total_weights += tensor.NumElements();
|
total_weights += tensor.Extents().Area();
|
||||||
});
|
});
|
||||||
printf("%-20s %12zu\n", "Total", total_weights);
|
printf("%-20s %12zu\n", "Total", total_weights);
|
||||||
}
|
}
|
||||||
|
|
@ -258,8 +259,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
|
||||||
}
|
}
|
||||||
|
|
||||||
template <>
|
template <>
|
||||||
void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
void LayerWeightsPtrs<NuqStream>::Reshape(MatOwner* storage) {
|
||||||
if (attn_vec_einsum_w.data() == nullptr) return;
|
if (!attn_vec_einsum_w.HasPtr()) return;
|
||||||
|
|
||||||
const size_t model_dim = layer_config.model_dim;
|
const size_t model_dim = layer_config.model_dim;
|
||||||
const size_t heads = layer_config.heads;
|
const size_t heads = layer_config.heads;
|
||||||
|
|
@ -267,8 +268,7 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
||||||
|
|
||||||
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
|
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
|
||||||
if (storage != nullptr) {
|
if (storage != nullptr) {
|
||||||
storage->Allocate();
|
storage->AllocateFor(att_weights, MatPadding::kPacked);
|
||||||
att_weights.SetPtr(*storage);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
|
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
|
||||||
|
|
@ -279,7 +279,7 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
||||||
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
|
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
|
||||||
|
|
||||||
HWY_NAMESPACE::DecompressAndZeroPad(
|
HWY_NAMESPACE::DecompressAndZeroPad(
|
||||||
df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0,
|
df, MakeSpan(attn_vec_einsum_w.Packed(), model_dim * heads * qkv_dim), 0,
|
||||||
attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
|
attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
|
||||||
|
|
||||||
for (size_t m = 0; m < model_dim; ++m) {
|
for (size_t m = 0; m < model_dim; ++m) {
|
||||||
|
|
@ -296,10 +296,10 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
||||||
|
|
||||||
HWY_NAMESPACE::Compress(
|
HWY_NAMESPACE::Compress(
|
||||||
att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
|
att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
|
||||||
MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
|
MakeSpan(att_weights.Packed(), model_dim * heads * qkv_dim),
|
||||||
/*packed_ofs=*/0, pool);
|
/*packed_ofs=*/0, pool);
|
||||||
|
|
||||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
att_weights.SetScale(attn_vec_einsum_w.Scale());
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -31,12 +31,32 @@
|
||||||
#include "gemma/common.h"
|
#include "gemma/common.h"
|
||||||
#include "gemma/configs.h"
|
#include "gemma/configs.h"
|
||||||
#include "gemma/tensor_index.h"
|
#include "gemma/tensor_index.h"
|
||||||
|
#include "util/mat.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
static inline std::string CacheName(const MatPtr& mat, int layer = -1,
|
||||||
|
char separator = ' ', int index = -1) {
|
||||||
|
// Already used/retired: s, S, n, 1
|
||||||
|
const char prefix = mat.GetType() == Type::kF32 ? 'F'
|
||||||
|
: mat.GetType() == Type::kBF16 ? 'B'
|
||||||
|
: mat.GetType() == Type::kSFP ? '$'
|
||||||
|
: mat.GetType() == Type::kNUQ ? '2'
|
||||||
|
: '?';
|
||||||
|
std::string name = std::string(1, prefix) + mat.Name();
|
||||||
|
if (layer >= 0 || index >= 0) {
|
||||||
|
name += '_';
|
||||||
|
if (layer >= 0) name += std::to_string(layer);
|
||||||
|
if (index >= 0) {
|
||||||
|
name += separator + std::to_string(index);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return name;
|
||||||
|
}
|
||||||
|
|
||||||
// Different tensors need to appear in a ForEachTensor, according to what is
|
// Different tensors need to appear in a ForEachTensor, according to what is
|
||||||
// happening.
|
// happening.
|
||||||
enum class ForEachType {
|
enum class ForEachType {
|
||||||
|
|
@ -181,10 +201,10 @@ struct LayerWeightsPtrs {
|
||||||
// Initializes att_weights from attn_vec_einsum_w, hence this must be called
|
// Initializes att_weights from attn_vec_einsum_w, hence this must be called
|
||||||
// after loading weights via ForEachTensor.
|
// after loading weights via ForEachTensor.
|
||||||
// TODO: update compression/convert_weights to bake this in.
|
// TODO: update compression/convert_weights to bake this in.
|
||||||
void Reshape(MatStorage* storage) {
|
void Reshape(MatOwner* storage) {
|
||||||
static_assert(!hwy::IsSame<Weight, NuqStream>());
|
static_assert(!hwy::IsSame<Weight, NuqStream>());
|
||||||
|
|
||||||
if (attn_vec_einsum_w.data() == nullptr) return;
|
if (!attn_vec_einsum_w.HasPtr()) return;
|
||||||
|
|
||||||
const size_t model_dim = layer_config.model_dim;
|
const size_t model_dim = layer_config.model_dim;
|
||||||
const size_t heads = layer_config.heads;
|
const size_t heads = layer_config.heads;
|
||||||
|
|
@ -192,18 +212,18 @@ struct LayerWeightsPtrs {
|
||||||
|
|
||||||
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
|
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
|
||||||
if (storage != nullptr) {
|
if (storage != nullptr) {
|
||||||
storage->Allocate();
|
storage->AllocateFor(att_weights, MatPadding::kPacked);
|
||||||
att_weights.SetPtr(*storage);
|
|
||||||
}
|
}
|
||||||
for (size_t m = 0; m < model_dim; ++m) {
|
for (size_t m = 0; m < model_dim; ++m) {
|
||||||
Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
|
Weight* HWY_RESTRICT out_row =
|
||||||
|
att_weights.template RowT<Weight>(0) + m * heads * qkv_dim;
|
||||||
for (size_t h = 0; h < heads; ++h) {
|
for (size_t h = 0; h < heads; ++h) {
|
||||||
hwy::CopyBytes(
|
hwy::CopyBytes(attn_vec_einsum_w.template RowT<Weight>(0) +
|
||||||
attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
|
h * model_dim * qkv_dim + m * qkv_dim,
|
||||||
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
|
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
att_weights.SetScale(attn_vec_einsum_w.Scale());
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayT<WeightF32OrBF16> key_norm_scale;
|
ArrayT<WeightF32OrBF16> key_norm_scale;
|
||||||
|
|
@ -215,8 +235,8 @@ struct LayerWeightsPtrs {
|
||||||
for (int i = 0; i < ptrs.size(); ++i) { \
|
for (int i = 0; i < ptrs.size(); ++i) { \
|
||||||
tensors[i] = &ptrs[i]->member; \
|
tensors[i] = &ptrs[i]->member; \
|
||||||
} \
|
} \
|
||||||
if (tensors[0]->Ptr() != nullptr || fet != ForEachType::kIgnoreNulls) { \
|
if (tensors[0]->HasPtr() || fet != ForEachType::kIgnoreNulls) { \
|
||||||
func(ptrs[0]->member.CacheName(layer_idx, sep, sep_index).c_str(), \
|
func(CacheName(ptrs[0]->member, layer_idx, sep, sep_index).c_str(), \
|
||||||
hwy::Span<MatPtr*>(tensors.data(), ptrs.size())); \
|
hwy::Span<MatPtr*>(tensors.data(), ptrs.size())); \
|
||||||
} \
|
} \
|
||||||
}
|
}
|
||||||
|
|
@ -307,19 +327,18 @@ struct LayerWeightsPtrs {
|
||||||
void ZeroInit(int layer_idx) {
|
void ZeroInit(int layer_idx) {
|
||||||
ForEachTensor({this}, layer_idx, ForEachType::kIgnoreNulls,
|
ForEachTensor({this}, layer_idx, ForEachType::kIgnoreNulls,
|
||||||
[](const char*, hwy::Span<MatPtr*> tensors) {
|
[](const char*, hwy::Span<MatPtr*> tensors) {
|
||||||
tensors[0]->ZeroInit();
|
gcpp::ZeroInit(*tensors[0]);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allocates memory for all the tensors in the layer.
|
// Allocates memory for all the tensors in the layer.
|
||||||
// Note that this is slow and only used for a stand-alone layer.
|
// Note that this is slow and only used for a stand-alone layer.
|
||||||
void Allocate(std::vector<MatStorage>& layer_storage) {
|
void Allocate(std::vector<MatOwner>& layer_storage) {
|
||||||
ForEachTensor(
|
ForEachTensor(
|
||||||
{this}, /*layer_idx=*/0, ForEachType::kInitNoToc,
|
{this}, /*layer_idx=*/0, ForEachType::kInitNoToc,
|
||||||
[&layer_storage](const char* name, hwy::Span<MatPtr*> tensors) {
|
[&layer_storage](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||||
layer_storage.emplace_back(*tensors[0]);
|
layer_storage.push_back(MatOwner());
|
||||||
layer_storage.back().Allocate();
|
layer_storage.back().AllocateFor(*tensors[0], MatPadding::kPacked);
|
||||||
tensors[0]->SetPtr(layer_storage.back());
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -393,11 +412,9 @@ struct ModelWeightsPtrs {
|
||||||
|
|
||||||
// Called by weights.cc after Loading, before att_w has been allocated.
|
// Called by weights.cc after Loading, before att_w has been allocated.
|
||||||
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool,
|
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool,
|
||||||
std::vector<MatStorage>& model_storage) {
|
std::vector<MatOwner>& model_storage) {
|
||||||
size_t storage_index = model_storage.size();
|
size_t storage_index = model_storage.size();
|
||||||
for (auto& layer : c_layers) {
|
model_storage.resize(model_storage.size() + c_layers.size());
|
||||||
model_storage.emplace_back(layer.att_weights);
|
|
||||||
}
|
|
||||||
pool.Run(0, c_layers.size(),
|
pool.Run(0, c_layers.size(),
|
||||||
[this, &model_storage, storage_index](uint64_t layer,
|
[this, &model_storage, storage_index](uint64_t layer,
|
||||||
size_t /*thread*/) {
|
size_t /*thread*/) {
|
||||||
|
|
@ -412,8 +429,8 @@ struct ModelWeightsPtrs {
|
||||||
}
|
}
|
||||||
|
|
||||||
void ZeroInit() {
|
void ZeroInit() {
|
||||||
embedder_input_embedding.ZeroInit();
|
gcpp::ZeroInit(embedder_input_embedding);
|
||||||
final_norm_scale.ZeroInit();
|
gcpp::ZeroInit(final_norm_scale);
|
||||||
for (size_t i = 0; i < c_layers.size(); ++i) {
|
for (size_t i = 0; i < c_layers.size(); ++i) {
|
||||||
c_layers[i].ZeroInit(i);
|
c_layers[i].ZeroInit(i);
|
||||||
}
|
}
|
||||||
|
|
@ -430,21 +447,21 @@ struct ModelWeightsPtrs {
|
||||||
return &vit_layers[layer];
|
return &vit_layers[layer];
|
||||||
}
|
}
|
||||||
|
|
||||||
void Allocate(std::vector<MatStorage>& model_storage, hwy::ThreadPool& pool) {
|
void Allocate(std::vector<MatOwner>& model_storage, hwy::ThreadPool& pool) {
|
||||||
std::vector<MatPtr*> model_toc;
|
std::vector<MatPtr*> model_toc;
|
||||||
ForEachTensor(
|
ForEachTensor(
|
||||||
{this}, ForEachType::kInitNoToc,
|
{this}, ForEachType::kInitNoToc,
|
||||||
[&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
|
[&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
|
||||||
model_toc.push_back(tensors[0]);
|
model_toc.push_back(tensors[0]);
|
||||||
model_storage.emplace_back(*tensors[0]);
|
model_storage.push_back(MatOwner());
|
||||||
});
|
});
|
||||||
// Allocate in parallel using the pool.
|
// Allocate in parallel using the pool.
|
||||||
pool.Run(0, model_toc.size(),
|
pool.Run(0, model_toc.size(),
|
||||||
[&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
|
[&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
|
||||||
// model_storage may have had content before we started.
|
// model_storage may have had content before we started.
|
||||||
size_t idx = task + model_storage.size() - model_toc.size();
|
size_t idx = task + model_storage.size() - model_toc.size();
|
||||||
model_storage[idx].Allocate();
|
model_storage[idx].AllocateFor(*model_toc[task],
|
||||||
model_toc[task]->SetPtr(model_storage[idx]);
|
MatPadding::kPacked);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -453,8 +470,7 @@ struct ModelWeightsPtrs {
|
||||||
ForEachTensor({this, const_cast<ModelWeightsPtrs<Weight>*>(&other)},
|
ForEachTensor({this, const_cast<ModelWeightsPtrs<Weight>*>(&other)},
|
||||||
ForEachType::kIgnoreNulls,
|
ForEachType::kIgnoreNulls,
|
||||||
[](const char*, hwy::Span<MatPtr*> tensors) {
|
[](const char*, hwy::Span<MatPtr*> tensors) {
|
||||||
hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(),
|
CopyMat(*tensors[1], *tensors[0]);
|
||||||
tensors[1]->SizeBytes());
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -467,10 +483,10 @@ struct ModelWeightsPtrs {
|
||||||
[&scales, &scale_pos, this](const char*, hwy::Span<MatPtr*> tensors) {
|
[&scales, &scale_pos, this](const char*, hwy::Span<MatPtr*> tensors) {
|
||||||
if (this->scale_names.count(tensors[0]->Name())) {
|
if (this->scale_names.count(tensors[0]->Name())) {
|
||||||
if (scale_pos < scales.size()) {
|
if (scale_pos < scales.size()) {
|
||||||
tensors[0]->set_scale(scales[scale_pos]);
|
tensors[0]->SetScale(scales[scale_pos]);
|
||||||
} else {
|
} else {
|
||||||
float scale = ScaleWeights(tensors[0]->data<float>(),
|
float scale = ScaleWeights(tensors[0]->RowT<float>(0),
|
||||||
tensors[0]->NumElements());
|
tensors[0]->Extents().Area());
|
||||||
scales.push_back(scale);
|
scales.push_back(scale);
|
||||||
}
|
}
|
||||||
++scale_pos;
|
++scale_pos;
|
||||||
|
|
@ -615,7 +631,7 @@ class ModelWeightsStorage {
|
||||||
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
|
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
|
||||||
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
|
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
|
||||||
// Storage for all the matrices and vectors.
|
// Storage for all the matrices and vectors.
|
||||||
std::vector<MatStorage> model_storage_;
|
std::vector<MatOwner> model_storage_;
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
|
|
@ -31,15 +31,12 @@
|
||||||
#include <stdio.h>
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "ops/matmul.h"
|
#include "ops/matmul.h"
|
||||||
#include "util/allocator.h"
|
|
||||||
#include "util/basics.h"
|
#include "util/basics.h"
|
||||||
#include "util/threading.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/nanobenchmark.h"
|
#include "hwy/nanobenchmark.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
@ -53,8 +50,8 @@
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
|
#include "compression/test_util-inl.h"
|
||||||
#include "ops/matmul-inl.h"
|
#include "ops/matmul-inl.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
@ -63,59 +60,6 @@ extern int64_t first_target;
|
||||||
|
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
|
|
||||||
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
|
||||||
|
|
||||||
template <typename MatT>
|
|
||||||
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
|
|
||||||
|
|
||||||
// Generates inputs: deterministic, within max SfpStream range.
|
|
||||||
template <typename MatT>
|
|
||||||
MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
|
|
||||||
hwy::ThreadPool& pool) {
|
|
||||||
gcpp::CompressWorkingSet ws;
|
|
||||||
auto mat =
|
|
||||||
std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
|
|
||||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
|
||||||
HWY_ASSERT(content);
|
|
||||||
const float scale =
|
|
||||||
SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
|
|
||||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
|
||||||
for (size_t c = 0; c < extents.cols; c++) {
|
|
||||||
float f = static_cast<float>(r * extents.cols + c) * scale;
|
|
||||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
|
||||||
content[r * extents.cols + c] = f;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
|
||||||
mat->set_scale(0.6f); // Arbitrary value, different from 1.
|
|
||||||
return mat;
|
|
||||||
}
|
|
||||||
|
|
||||||
// extents describes the transposed matrix.
|
|
||||||
template <typename MatT>
|
|
||||||
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
|
|
||||||
hwy::ThreadPool& pool) {
|
|
||||||
gcpp::CompressWorkingSet ws;
|
|
||||||
auto mat =
|
|
||||||
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
|
|
||||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
|
||||||
const float scale =
|
|
||||||
SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
|
|
||||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
|
||||||
for (size_t c = 0; c < extents.cols; c++) {
|
|
||||||
float f = static_cast<float>(c * extents.rows + r) * scale;
|
|
||||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
|
||||||
content[r * extents.cols + c] = f;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
|
||||||
// Arbitrary value, different from 1, must match GenerateMat.
|
|
||||||
mat->set_scale(0.6f);
|
|
||||||
return mat;
|
|
||||||
}
|
|
||||||
|
|
||||||
void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
|
void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
|
||||||
std::vector<double>& times, MMPerKey* per_key) {
|
std::vector<double>& times, MMPerKey* per_key) {
|
||||||
std::sort(times.begin(), times.end());
|
std::sort(times.begin(), times.end());
|
||||||
|
|
@ -135,7 +79,8 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
|
||||||
// M = A rows, K = A cols, N = C cols.
|
// M = A rows, K = A cols, N = C cols.
|
||||||
template <typename TA, typename TB = TA, typename TC = float>
|
template <typename TA, typename TB = TA, typename TC = float>
|
||||||
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
||||||
hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
|
const Allocator2& allocator = env.ctx.allocator;
|
||||||
|
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
|
||||||
if (env.print_config || env.print_measurement) {
|
if (env.print_config || env.print_measurement) {
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
}
|
}
|
||||||
|
|
@ -147,24 +92,23 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
||||||
const Extents2D B_extents(N, K); // already transposed
|
const Extents2D B_extents(N, K); // already transposed
|
||||||
const Extents2D C_extents(M, N);
|
const Extents2D C_extents(M, N);
|
||||||
|
|
||||||
RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
|
RowVectorBatch<TC> c_slow_batch =
|
||||||
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
|
AllocateAlignedRows<TC>(allocator, C_extents);
|
||||||
|
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
|
||||||
|
|
||||||
std::unique_ptr<MatStorageT<float>> add_storage;
|
MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
|
||||||
if (add) {
|
if (add) {
|
||||||
add_storage = GenerateMat<float>(Extents2D(1, N), pool);
|
add_storage = GenerateMat<float>(Extents2D(1, N), pool);
|
||||||
HWY_ASSERT(add_storage);
|
add_storage.SetScale(1.0f);
|
||||||
add_storage->set_scale(1.0f);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
|
MatStorageT<TA> a = GenerateMat<TA>(A_extents, pool);
|
||||||
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
|
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
|
||||||
HWY_ASSERT(a && b_trans);
|
const auto A = ConstMatFromWeights(a);
|
||||||
const auto A = ConstMatFromWeights(*a);
|
const auto B = ConstMatFromWeights(b_trans);
|
||||||
const auto B = ConstMatFromWeights(*b_trans);
|
|
||||||
|
|
||||||
const float* add_row = add ? add_storage->data_scale1() : nullptr;
|
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
|
||||||
const RowPtr<TC> C = RowPtrFromBatch(c_batch);
|
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
|
||||||
|
|
||||||
// Fewer reps for large batch sizes, which take longer.
|
// Fewer reps for large batch sizes, which take longer.
|
||||||
const size_t num_samples = M < 32 ? 20 : 12;
|
const size_t num_samples = M < 32 ? 20 : 12;
|
||||||
|
|
@ -174,11 +118,11 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
||||||
// Ensure usage conditions are set before autotuning. Both binding and
|
// Ensure usage conditions are set before autotuning. Both binding and
|
||||||
// spinning may materially affect the choice of config. No harm in calling
|
// spinning may materially affect the choice of config. No harm in calling
|
||||||
// BindB/C if there is a single package: they will be a no-op.
|
// BindB/C if there is a single package: they will be a no-op.
|
||||||
BindB(B_extents.rows, sizeof(TC), B, env.parallel);
|
BindB(allocator, B_extents.rows, sizeof(TC), B, env.parallel);
|
||||||
BindC(A_extents.rows, C, env.parallel);
|
BindC(allocator, A_extents.rows, C, env.parallel);
|
||||||
|
|
||||||
Tristate use_spinning = Tristate::kDefault;
|
Tristate use_spinning = Tristate::kDefault;
|
||||||
env.parallel.Pools().MaybeStartSpinning(use_spinning);
|
env.ctx.pools.MaybeStartSpinning(use_spinning);
|
||||||
|
|
||||||
// env.print_config = true;
|
// env.print_config = true;
|
||||||
// env.print_measurement = true;
|
// env.print_measurement = true;
|
||||||
|
|
@ -198,7 +142,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
||||||
if (per_key->autotune.Best()) times.push_back(elapsed);
|
if (per_key->autotune.Best()) times.push_back(elapsed);
|
||||||
}
|
}
|
||||||
hwy::PreventElision(keep);
|
hwy::PreventElision(keep);
|
||||||
env.parallel.Pools().MaybeStopSpinning(use_spinning);
|
env.ctx.pools.MaybeStopSpinning(use_spinning);
|
||||||
PrintSpeed(A_extents, B_extents, times, per_key);
|
PrintSpeed(A_extents, B_extents, times, per_key);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -216,17 +160,11 @@ void BenchAllMatMul() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
const size_t max_threads = 0; // no limit
|
ThreadingContext2& ctx = ThreadingContext2::Get();
|
||||||
const BoundedSlice package_slice; // all packages/sockets
|
fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(),
|
||||||
const BoundedSlice cluster_slice; // all clusters/CCX
|
ctx.pools.PinString());
|
||||||
const BoundedSlice lp_slice; // default to all cores (per package).
|
|
||||||
const BoundedTopology topology(package_slice, cluster_slice, lp_slice);
|
|
||||||
Allocator::Init(topology, /*enable_bind=*/true);
|
|
||||||
NestedPools pools(topology, max_threads, Tristate::kDefault);
|
|
||||||
fprintf(stderr, "BenchAllMatMul %s %s\n", topology.TopologyString(),
|
|
||||||
pools.PinString());
|
|
||||||
|
|
||||||
MatMulEnv env(topology, pools);
|
MatMulEnv env(ctx);
|
||||||
|
|
||||||
for (size_t batch_size : {1, 4, 128, 512}) {
|
for (size_t batch_size : {1, 4, 128, 512}) {
|
||||||
constexpr bool kAdd = false;
|
constexpr bool kAdd = false;
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,7 @@
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "compression/compress.h"
|
||||||
|
#include "util/mat.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
||||||
|
|
@ -379,10 +380,7 @@ template <typename MatT, typename VT>
|
||||||
HWY_INLINE float Dot(const MatPtrT<MatT>& w, size_t w_ofs,
|
HWY_INLINE float Dot(const MatPtrT<MatT>& w, size_t w_ofs,
|
||||||
const VT* vec_aligned, size_t num) {
|
const VT* vec_aligned, size_t num) {
|
||||||
const hn::ScalableTag<VT> d;
|
const hn::ScalableTag<VT> d;
|
||||||
return w.scale() * Dot(d,
|
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
|
||||||
MakeConstSpan(reinterpret_cast<const MatT*>(w.Ptr()),
|
|
||||||
w.NumElements()),
|
|
||||||
w_ofs, vec_aligned, num);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
|
|
||||||
|
|
@ -28,9 +28,8 @@
|
||||||
|
|
||||||
#include "compression/shared.h"
|
#include "compression/shared.h"
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "util/app.h"
|
|
||||||
#include "util/test_util.h"
|
#include "util/test_util.h"
|
||||||
#include "util/threading.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
|
|
@ -805,7 +804,7 @@ class DotStats {
|
||||||
ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_l1s[kAddTwoProd].Mean(), 8E-4);
|
ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_l1s[kAddTwoProd].Mean(), 8E-4);
|
||||||
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.1E-3f);
|
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.1E-3f);
|
||||||
// Updating Kahan's FastTwoSums to TwoSums does help a bit.
|
// Updating Kahan's FastTwoSums to TwoSums does help a bit.
|
||||||
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.2E-4);
|
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.8E-4);
|
||||||
|
|
||||||
ASSERT_INSIDE(kPairwise, 4.5E-4, s_l1s[kPairwise].Mean(), 4E-3);
|
ASSERT_INSIDE(kPairwise, 4.5E-4, s_l1s[kPairwise].Mean(), 4E-3);
|
||||||
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_l1s[kPairwise].Max(), 1E-2f);
|
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_l1s[kPairwise].Max(), 1E-2f);
|
||||||
|
|
@ -1000,9 +999,7 @@ struct TestShortDotsT {
|
||||||
const size_t N = hn::Lanes(d);
|
const size_t N = hn::Lanes(d);
|
||||||
const hn::ScalableTag<float> df; // for CallDot
|
const hn::ScalableTag<float> df; // for CallDot
|
||||||
|
|
||||||
const AppArgs app;
|
const Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
|
||||||
BoundedTopology topology(CreateTopology(app));
|
|
||||||
NestedPools pools = CreatePools(topology, app);
|
|
||||||
CompressWorkingSet work;
|
CompressWorkingSet work;
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
rng.seed(12345);
|
rng.seed(12345);
|
||||||
|
|
@ -1014,14 +1011,14 @@ struct TestShortDotsT {
|
||||||
// hence they require padding to one vector.
|
// hence they require padding to one vector.
|
||||||
const size_t padded_num = hwy::RoundUpTo(num, N);
|
const size_t padded_num = hwy::RoundUpTo(num, N);
|
||||||
const size_t packed_num = CompressedArrayElements<Packed>(num);
|
const size_t packed_num = CompressedArrayElements<Packed>(num);
|
||||||
RowVectorBatch<float> raw_w(Extents2D(1, padded_num));
|
RowVectorBatch<float> raw_w(allocator, Extents2D(1, padded_num));
|
||||||
RowVectorBatch<float> raw_v(Extents2D(1, padded_num));
|
RowVectorBatch<float> raw_v(allocator, Extents2D(1, padded_num));
|
||||||
RowVectorBatch<Packed> weights(Extents2D(1, packed_num));
|
RowVectorBatch<Packed> weights(allocator, Extents2D(1, packed_num));
|
||||||
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
|
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
|
||||||
RowVectorBatch<T> vectors(Extents2D(1, num));
|
RowVectorBatch<T> vectors(allocator, Extents2D(1, num));
|
||||||
const PackedSpan<T> v(vectors.Batch(0), num);
|
const PackedSpan<T> v(vectors.Batch(0), num);
|
||||||
|
|
||||||
RowVectorBatch<double> bufs(Extents2D(1, num));
|
RowVectorBatch<double> bufs(allocator, Extents2D(1, num));
|
||||||
double* HWY_RESTRICT buf = bufs.Batch(0);
|
double* HWY_RESTRICT buf = bufs.Batch(0);
|
||||||
|
|
||||||
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
||||||
|
|
@ -1099,10 +1096,21 @@ void TestAllDot() {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
constexpr size_t kMaxWorkers = 15;
|
||||||
|
|
||||||
|
// Reset with cap on workers because we only support `kMaxWorkers`.
|
||||||
|
ThreadingContext2::ThreadHostileInvalidate();
|
||||||
|
ThreadingArgs threading_args;
|
||||||
|
threading_args.max_packages = 1;
|
||||||
|
threading_args.max_clusters = 1;
|
||||||
|
threading_args.max_lps = kMaxWorkers - 1;
|
||||||
|
ThreadingContext2::SetArgs(threading_args);
|
||||||
|
ThreadingContext2& ctx = ThreadingContext2::Get();
|
||||||
|
const Allocator2& allocator = ctx.allocator;
|
||||||
|
|
||||||
{ // ensure no profiler zones are active
|
{ // ensure no profiler zones are active
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
|
|
||||||
constexpr size_t kMaxWorkers = 15;
|
|
||||||
std::mt19937 rngs[kMaxWorkers];
|
std::mt19937 rngs[kMaxWorkers];
|
||||||
for (size_t i = 0; i < kMaxWorkers; ++i) {
|
for (size_t i = 0; i < kMaxWorkers; ++i) {
|
||||||
rngs[i].seed(12345 + 65537 * i);
|
rngs[i].seed(12345 + 65537 * i);
|
||||||
|
|
@ -1110,15 +1118,13 @@ void TestAllDot() {
|
||||||
|
|
||||||
constexpr size_t kReps = hn::AdjustedReps(40);
|
constexpr size_t kReps = hn::AdjustedReps(40);
|
||||||
const size_t num = 24 * 1024;
|
const size_t num = 24 * 1024;
|
||||||
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1),
|
RowVectorBatch<float> a(allocator, Extents2D(kMaxWorkers, num));
|
||||||
BoundedSlice());
|
RowVectorBatch<float> b(allocator, Extents2D(kMaxWorkers, num));
|
||||||
NestedPools pools(topology, kMaxWorkers - 1, /*pin=*/Tristate::kDefault);
|
RowVectorBatch<double> bufs(allocator, Extents2D(kMaxWorkers, num));
|
||||||
RowVectorBatch<float> a(Extents2D(kMaxWorkers, num));
|
|
||||||
RowVectorBatch<float> b(Extents2D(kMaxWorkers, num));
|
|
||||||
RowVectorBatch<double> bufs(Extents2D(kMaxWorkers, num));
|
|
||||||
std::array<DotStats, kMaxWorkers> all_stats;
|
std::array<DotStats, kMaxWorkers> all_stats;
|
||||||
|
|
||||||
pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
|
ctx.pools.Cluster(0, 0).Run(
|
||||||
|
0, kReps, [&](const uint32_t rep, size_t thread) {
|
||||||
float* HWY_RESTRICT pa = a.Batch(thread);
|
float* HWY_RESTRICT pa = a.Batch(thread);
|
||||||
float* HWY_RESTRICT pb = b.Batch(thread);
|
float* HWY_RESTRICT pb = b.Batch(thread);
|
||||||
double* HWY_RESTRICT buf = bufs.Batch(thread);
|
double* HWY_RESTRICT buf = bufs.Batch(thread);
|
||||||
|
|
@ -1136,7 +1142,8 @@ void TestAllDot() {
|
||||||
std::array<double, kTimeReps> elapsed;
|
std::array<double, kTimeReps> elapsed;
|
||||||
for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
|
for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
|
||||||
const double start = hwy::platform::Now();
|
const double start = hwy::platform::Now();
|
||||||
dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
|
dots[variant] +=
|
||||||
|
CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
|
||||||
hwy::PreventElision(*pa);
|
hwy::PreventElision(*pa);
|
||||||
elapsed[time_rep] = hwy::platform::Now() - start;
|
elapsed[time_rep] = hwy::platform::Now() - start;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@
|
||||||
#include <cmath> // std::abs
|
#include <cmath> // std::abs
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#include "compression/compress.h"
|
#include "util/mat.h"
|
||||||
#include "hwy/aligned_allocator.h"
|
#include "hwy/aligned_allocator.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
|
|
@ -37,6 +37,7 @@
|
||||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||||
#include "hwy/highway.h"
|
#include "hwy/highway.h"
|
||||||
// After highway.h
|
// After highway.h
|
||||||
|
#include "compression/compress-inl.h"
|
||||||
#include "ops/matvec-inl.h"
|
#include "ops/matvec-inl.h"
|
||||||
#include "hwy/tests/test_util-inl.h"
|
#include "hwy/tests/test_util-inl.h"
|
||||||
|
|
||||||
|
|
@ -48,18 +49,18 @@ using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
||||||
|
|
||||||
FloatPtr SimpleMatVecAdd(const MatStorageT<float>& mat, const FloatPtr& vec,
|
FloatPtr SimpleMatVecAdd(const MatStorageT<float>& mat, const FloatPtr& vec,
|
||||||
const FloatPtr& add) {
|
const FloatPtr& add) {
|
||||||
FloatPtr raw_mat = hwy::AllocateAligned<float>(mat.NumElements());
|
const size_t num = mat.Rows() * mat.Cols();
|
||||||
|
FloatPtr raw_mat = hwy::AllocateAligned<float>(num);
|
||||||
FloatPtr out = hwy::AllocateAligned<float>(mat.Rows());
|
FloatPtr out = hwy::AllocateAligned<float>(mat.Rows());
|
||||||
HWY_ASSERT(raw_mat && out);
|
HWY_ASSERT(raw_mat && out);
|
||||||
const hn::ScalableTag<float> df;
|
const hn::ScalableTag<float> df;
|
||||||
DecompressAndZeroPad(df, MakeSpan(mat.data(), mat.NumElements()), 0,
|
DecompressAndZeroPad(df, mat.Span(), 0, raw_mat.get(), num);
|
||||||
raw_mat.get(), mat.NumElements());
|
|
||||||
for (size_t idx_row = 0; idx_row < mat.Rows(); idx_row++) {
|
for (size_t idx_row = 0; idx_row < mat.Rows(); idx_row++) {
|
||||||
out[idx_row] = 0.0f;
|
out[idx_row] = 0.0f;
|
||||||
for (size_t idx_col = 0; idx_col < mat.Cols(); idx_col++) {
|
for (size_t idx_col = 0; idx_col < mat.Cols(); idx_col++) {
|
||||||
out[idx_row] += raw_mat[mat.Cols() * idx_row + idx_col] * vec[idx_col];
|
out[idx_row] += raw_mat[mat.Cols() * idx_row + idx_col] * vec[idx_col];
|
||||||
}
|
}
|
||||||
out[idx_row] *= mat.scale();
|
out[idx_row] *= mat.Scale();
|
||||||
out[idx_row] += add[idx_row];
|
out[idx_row] += add[idx_row];
|
||||||
}
|
}
|
||||||
return out;
|
return out;
|
||||||
|
|
@ -69,8 +70,10 @@ template <typename MatT, size_t kOuter, size_t kInner>
|
||||||
std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
|
std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
|
||||||
hwy::ThreadPool& pool) {
|
hwy::ThreadPool& pool) {
|
||||||
gcpp::CompressWorkingSet ws;
|
gcpp::CompressWorkingSet ws;
|
||||||
auto mat = std::make_unique<MatStorageT<float>>("TestMat", kOuter, kInner);
|
const Extents2D extents(kOuter, kInner);
|
||||||
FloatPtr raw_mat = hwy::AllocateAligned<float>(mat->NumElements());
|
auto mat = std::make_unique<MatStorageT<float>>("TestMat", extents,
|
||||||
|
MatPadding::kPacked);
|
||||||
|
FloatPtr raw_mat = hwy::AllocateAligned<float>(extents.Area());
|
||||||
HWY_ASSERT(raw_mat);
|
HWY_ASSERT(raw_mat);
|
||||||
const float scale = 1.0f / kInner;
|
const float scale = 1.0f / kInner;
|
||||||
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
|
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
|
||||||
|
|
@ -80,8 +83,8 @@ std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
CompressScaled(raw_mat.get(), mat->NumElements(), ws, *mat, pool);
|
CompressScaled(raw_mat.get(), extents.Area(), ws, *mat, pool);
|
||||||
mat->set_scale(1.9f); // Arbitrary value, different from 1.
|
mat->SetScale(1.9f); // Arbitrary value, different from 1.
|
||||||
return mat;
|
return mat;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
#include <stdint.h>
|
#include <stdint.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
|
@ -22,9 +23,8 @@
|
||||||
#include "ops/matmul.h" // IWYU pragma: export
|
#include "ops/matmul.h" // IWYU pragma: export
|
||||||
#include "util/allocator.h"
|
#include "util/allocator.h"
|
||||||
#include "util/basics.h"
|
#include "util/basics.h"
|
||||||
#include "util/threading.h"
|
#include "util/threading_context.h"
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
|
||||||
#include "hwy/profiler.h"
|
#include "hwy/profiler.h"
|
||||||
#include "hwy/timer.h"
|
#include "hwy/timer.h"
|
||||||
|
|
||||||
|
|
@ -866,6 +866,8 @@ class MMPerPackage {
|
||||||
const IndexRange& range_np)
|
const IndexRange& range_np)
|
||||||
: args_(args),
|
: args_(args),
|
||||||
pkg_idx_(pkg_idx),
|
pkg_idx_(pkg_idx),
|
||||||
|
// May be overwritten with a view of A, if already BF16.
|
||||||
|
A_(args_.env->storage.A(args.env->ctx.allocator, pkg_idx, A.Extents())),
|
||||||
range_np_(range_np),
|
range_np_(range_np),
|
||||||
mr_(config.MR()),
|
mr_(config.MR()),
|
||||||
ranges_mc_(config.RangesOfMC(A.Extents().rows)),
|
ranges_mc_(config.RangesOfMC(A.Extents().rows)),
|
||||||
|
|
@ -873,15 +875,12 @@ class MMPerPackage {
|
||||||
ranges_nc_(config.RangesOfNC(range_np)),
|
ranges_nc_(config.RangesOfNC(range_np)),
|
||||||
order_(config.Order()),
|
order_(config.Order()),
|
||||||
inner_tasks_(config.InnerTasks()),
|
inner_tasks_(config.InnerTasks()),
|
||||||
out_(config.Out()) {
|
out_(config.Out()),
|
||||||
// May be overwritten with a view of A, if already BF16.
|
line_bytes_(args.env->ctx.allocator.LineBytes()) {
|
||||||
A_ = args_.env->storage.A(pkg_idx, A.Extents());
|
|
||||||
{
|
|
||||||
MMZone zone;
|
MMZone zone;
|
||||||
zone.MaybeEnter("MM.DecompressA", args_);
|
zone.MaybeEnter("MM.DecompressA", args_);
|
||||||
A_ = DecompressA(A);
|
A_ = DecompressA(A);
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// B is decompressed several call layers lower, but not all member functions
|
// B is decompressed several call layers lower, but not all member functions
|
||||||
// depend on TB, so pass it as an argument instead of templating the class.
|
// depend on TB, so pass it as an argument instead of templating the class.
|
||||||
|
|
@ -909,14 +908,14 @@ class MMPerPackage {
|
||||||
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
|
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
|
||||||
// allocation avoids passing a worker index.
|
// allocation avoids passing a worker index.
|
||||||
static constexpr size_t B_stride_max_ =
|
static constexpr size_t B_stride_max_ =
|
||||||
StrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
|
MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
|
||||||
static constexpr size_t B_storage_max_ =
|
static constexpr size_t B_storage_max_ =
|
||||||
kNR * B_stride_max_ + Allocator::MaxQuantumBytes() / sizeof(BF16);
|
kNR * B_stride_max_ + Allocator2::MaxQuantum<BF16>();
|
||||||
|
|
||||||
// Granularity of `ForNP`. B rows produce C columns, so we
|
// Granularity of `ForNP`. B rows produce C columns, so we
|
||||||
// want a multiple of the line size to prevent false sharing.
|
// want a multiple of the line size to prevent false sharing.
|
||||||
static size_t MultipleNP(size_t sizeof_TC) {
|
size_t MultipleNP(size_t sizeof_TC) const {
|
||||||
return HWY_MAX(kNR, Allocator::LineBytes() / sizeof_TC);
|
return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Single M and K, parallel N. Fills all of C directly.
|
// Single M and K, parallel N. Fills all of C directly.
|
||||||
|
|
@ -931,14 +930,16 @@ class MMPerPackage {
|
||||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
const IndexRange& range_K = ranges_kc_.Range(0);
|
||||||
const size_t K = range_K.Num();
|
const size_t K = range_K.Num();
|
||||||
const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K);
|
const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K);
|
||||||
const size_t B_stride = StrideForCyclicOffsets<BF16>(K);
|
const size_t B_stride =
|
||||||
|
StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
|
||||||
|
|
||||||
// Similar to `loop_nc` below, but here we hoisted `A_view`.
|
// Similar to `loop_nc` below, but here we hoisted `A_view`.
|
||||||
args_.env->parallel.ForNP(
|
args_.env->parallel.ForNP(
|
||||||
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
||||||
[&](const IndexRange& range_nc) HWY_ATTR {
|
[&](const IndexRange& range_nc) HWY_ATTR {
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||||
const RowPtrBF B_view(B_storage, K, B_stride);
|
const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
|
||||||
|
B_stride);
|
||||||
|
|
||||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||||
row_b += kNR) {
|
row_b += kNR) {
|
||||||
|
|
@ -972,7 +973,9 @@ class MMPerPackage {
|
||||||
auto out_tag) HWY_ATTR {
|
auto out_tag) HWY_ATTR {
|
||||||
const size_t kc = range_kc.Num();
|
const size_t kc = range_kc.Num();
|
||||||
const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
|
const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
|
||||||
const RowPtrBF B_view(B_storage, kc, StrideForCyclicOffsets<BF16>(kc));
|
const RowPtrBF B_view(
|
||||||
|
args_.env->ctx.allocator, B_storage, kc,
|
||||||
|
StrideForCyclicOffsets(kc, args_.env->ctx.allocator.Quantum<BF16>()));
|
||||||
|
|
||||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||||
row_b += kNR) {
|
row_b += kNR) {
|
||||||
|
|
@ -1027,7 +1030,8 @@ class MMPerPackage {
|
||||||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
||||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
const IndexRange& range_K = ranges_kc_.Range(0);
|
||||||
const size_t K = range_K.Num();
|
const size_t K = range_K.Num();
|
||||||
const size_t B_stride = StrideForCyclicOffsets<BF16>(K);
|
const size_t B_stride =
|
||||||
|
StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
|
||||||
|
|
||||||
// Sequential loop over NC/MC/KC, similar to `loop_nc` below
|
// Sequential loop over NC/MC/KC, similar to `loop_nc` below
|
||||||
// except for the profiler strings and `out_tag`.
|
// except for the profiler strings and `out_tag`.
|
||||||
|
|
@ -1036,7 +1040,8 @@ class MMPerPackage {
|
||||||
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
|
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
|
||||||
const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
|
const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
|
||||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||||
const RowPtrBF B_view(B_storage, K, B_stride);
|
+          const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
+                                B_stride);
           for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
                row_b += kNR) {

@@ -1062,7 +1067,8 @@ class MMPerPackage {
    zone.MaybeEnter("MM.NT_MT_K", args_);
    const size_t kc_max = ranges_kc_.TaskSize();
    HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
-   const size_t B_stride = StrideForCyclicOffsets<BF16>(kc_max);
+   const size_t B_stride = StrideForCyclicOffsets(
+       kc_max, args_.env->ctx.allocator.Quantum<BF16>());
    // Sequential loop over NC/MC/KC, for when the M/N loops are
    // already parallel. This is B3A2C0 in MOMMS terminology: we read
    // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.

@@ -1088,7 +1094,8 @@ class MMPerPackage {
        ranges_mc_, ranges_nc_, pkg_idx_,
        [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
          HWY_ALIGN BF16 B_storage[B_storage_max_];  // TLS
-         const RowPtrBF B_view(B_storage, kc_max, B_stride);
+         const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, kc_max,
+                               B_stride);

          // Peel off the first iteration of the kc loop: avoid
          // zero-initializing `partial` by writing into it.

@@ -1151,8 +1158,7 @@ class MMPerPackage {
    // At least one vector, otherwise DecompressAndZeroPad will add
    // padding, which might overwrite neighboring tasks. Also a whole cache
    // line to avoid false sharing.
-   const size_t multiple_K =
-       HWY_MAX(NBF, Allocator::LineBytes() / sizeof(BF16));
+   const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));

    args_.env->parallel.ForNP(
        all_K, multiple_K, inner_tasks, pkg_idx_,

@@ -1170,6 +1176,7 @@ class MMPerPackage {
  // Autotuning wrapper for `DoDecompressA`.
  template <typename TA>
  HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const {
+   const Allocator2& allocator = args_.env->ctx.allocator;
    MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
    // If already BF16, maybe return a view:
    if constexpr (hwy::IsSame<TA, BF16>()) {

@@ -1177,7 +1184,8 @@ class MMPerPackage {
      const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
      if (HWY_LIKELY(A.extents.cols % NBF == 0)) {
        const BF16* pos = A.ptr + A.Row(0);
-       return RowPtrBF(const_cast<BF16*>(pos), A.extents.cols, A.Stride());
+       return RowPtrBF(allocator, const_cast<BF16*>(pos), A.extents.cols,
+                       A.Stride());
      }
    }

@@ -1251,6 +1259,7 @@ class MMPerPackage {
  const MMOrder order_;
  const size_t inner_tasks_;
  const MMOut out_;
+ const size_t line_bytes_;
};  // MMPerPackage

// Stateless, wraps member functions.

@@ -1308,6 +1317,7 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
                              const float* HWY_RESTRICT add, MatMulEnv& env,
                              const RowPtr<TC>& C) {
+ const Allocator2& allocator = env.ctx.allocator;
  const size_t M = A.Extents().rows;
  const size_t K = A.Extents().cols;
  const size_t N = B.Extents().rows;

@@ -1315,11 +1325,11 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
  intptr_t index = MMImpl::IndexOfKey(key, env.keys);
  // First time we see this shape/key.
  if (HWY_UNLIKELY(index < 0)) {
-   env.keys.Append(key);
+   env.keys.Append(key, allocator);

    size_t max_packages = MMParallel::kMaxPackages;
    // For low-batch, multiple sockets only help if binding is enabled.
-   if (!Allocator::ShouldBind() && M <= 4) {
+   if (!allocator.ShouldBind() && M <= 4) {
      max_packages = 1;
    }

@@ -1351,8 +1361,9 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
    HWY_ASSERT(N % kNR == 0);

    // Negligible CPU time.
-   tuner.SetCandidates(MMCandidates(M, K, N, sizeof(TC), MMKernel::kMaxMR, kNR,
-                                    per_key.ranges_np, env.print_config));
+   tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC),
+                                    MMKernel::kMaxMR, kNR, per_key.ranges_np,
+                                    env.print_config));
  }

  const MMConfig& cfg = tuner.NextConfig();
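Note on the stride call sites above: the per-type stride is now derived from the allocator's per-type quantum rather than a static helper. A rough, self-contained sketch of the idea; the helper names and the plain round-up here are assumptions, and the real StrideForCyclicOffsets may additionally stagger rows across cache sets.

// Illustrative sketch only, not the gemma.cpp implementation.
#include <cstddef>

static size_t RoundUpTo(size_t value, size_t multiple) {
  return (value + multiple - 1) / multiple * multiple;
}

// `quantum` would come from allocator.Quantum<BF16>() in the new API.
static size_t PaddedStride(size_t cols, size_t quantum) {
  return RoundUpTo(cols, quantum);  // stride >= cols, a multiple of quantum
}

int main() {
  // e.g. 1536 BF16 columns with a quantum of 1024 elements -> stride 2048.
  return PaddedStride(1536, 1024) == 2048 ? 0 : 1;
}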
@@ -60,10 +60,11 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
// and holds most of their arguments in member variables.
class GenerateCandidates {
 public:
- GenerateCandidates(size_t M, size_t K, size_t N, size_t sizeof_TC,
-                    size_t max_mr, size_t nr,
+ GenerateCandidates(const Allocator2& allocator, size_t M, size_t K, size_t N,
+                    size_t sizeof_TC, size_t max_mr, size_t nr,
                     const IndexRangePartition& ranges_np, bool print_config)
-     : M_(M),
+     : allocator_(allocator),
+       M_(M),
        K_(K),
        N_(N),
        sizeof_TC_(sizeof_TC),

@@ -73,8 +74,8 @@ class GenerateCandidates {
        // `RangesOf*`. Must be a vector multiple. The previous/next cache line
        // is likely still in L1, but we expect K > 1000 and might as well round
        // up to the line size.
-       kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))),
+       kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
-       nc_multiple_(Allocator::StepBytes() / sizeof_TC),
+       nc_multiple_(allocator.StepBytes() / sizeof_TC),
        ranges_np_(ranges_np),
        print_config_(print_config) {}

@@ -172,7 +173,7 @@ class GenerateCandidates {
    // subtract the output and buf, and allow using more than the actual L1
    // size. This results in an overestimate, and the loop below will propose
    // the next few smaller values for the autotuner to evaluate.
-   const size_t bytes_ab = Allocator::L1Bytes() * 3;
+   const size_t bytes_ab = allocator_.L1Bytes() * 3;
    const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
    size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
    kc_max =

@@ -220,8 +221,8 @@ class GenerateCandidates {
    // packed B. We want `mc * kc` elements of A to fit in L2, alongside
    // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
    // partial.
-   const size_t bytes_per_mc = kc * sizeof(BF16) + Allocator::LineBytes();
+   const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
-   size_t mc_max = hwy::DivCeil(Allocator::L2Bytes() - bytes_b, bytes_per_mc);
+   size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
    mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
    HWY_DASSERT(mc_max != 0);
    mc_max = HWY_MIN(mc_max, M_);

@@ -264,7 +265,7 @@ class GenerateCandidates {
    // Otherwise, leave it unbounded.
    if (M_ > mr) {
      const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes);
-     nc_max = hwy::DivCeil(Allocator::L3Bytes(), bytes_per_nc);
+     nc_max = hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc);
      nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max);
    }
    HWY_DASSERT(nc_max != 0);

@@ -351,6 +352,7 @@ class GenerateCandidates {
    }
  }

+ const Allocator2& allocator_;
  const size_t M_;
  const size_t K_;
  const size_t N_;

@@ -370,25 +372,26 @@ class GenerateCandidates {
}  // namespace

// Facade to avoid exposing `GenerateCandidates` in the header.
-std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
-                                   size_t sizeof_TC, size_t max_mr, size_t nr,
+std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
+                                   size_t K, size_t N, size_t sizeof_TC,
+                                   size_t max_mr, size_t nr,
                                    const IndexRangePartition& ranges_np,
                                    bool print_config) {
- return GenerateCandidates(M, K, N, sizeof_TC, max_mr, nr, ranges_np,
-                           print_config)();
+ return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
+                           ranges_np, print_config)();
}

// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package
// rows for that.
-static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,
-                         size_t num_packages) {
+static size_t NPMultiple(const Allocator2& allocator, size_t N,
+                         size_t sizeof_TC, size_t nr, size_t num_packages) {
- size_t np_multiple = Allocator::QuantumBytes() / sizeof_TC;
+ size_t np_multiple = allocator.QuantumBytes() / sizeof_TC;
  // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
  // `N` < 4096, this can cause significant load imbalance. If split unevenly,
  // choose a smaller multiple.
  if (N % (np_multiple * num_packages)) {
-   const size_t min_multiple = Allocator::LineBytes() / sizeof_TC;
+   const size_t min_multiple = allocator.LineBytes() / sizeof_TC;
    np_multiple =
        PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
    if (HWY_UNLIKELY(np_multiple == 0)) {

@@ -408,16 +411,14 @@ static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,

IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
                                           size_t sizeof_TC, size_t nr) const {
- const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages());
- return StaticPartition(IndexRange(0, N), num_packages,
-                        NPMultiple(N, sizeof_TC, nr, num_packages));
+ const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages());
+ return StaticPartition(
+     IndexRange(0, N), num_packages,
+     NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
}

-MatMulEnv::MatMulEnv(const BoundedTopology& topology, NestedPools& pools)
-    : parallel(topology, pools), storage(parallel) {
-  // Ensure Allocator:Init was called.
-  HWY_ASSERT(Allocator::LineBytes() != 0 && Allocator::VectorBytes() != 0);
-
+MatMulEnv::MatMulEnv(ThreadingContext2& ctx)
+    : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
  char cpu100[100];
  have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
}
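The kc/mc bounds above are plain capacity arithmetic. A standalone sketch with assumed cache sizes; the library reads the real values from allocator_.L1Bytes()/L2Bytes() at runtime, so every number below is illustrative only.

// Illustrative only: assumed 32 KiB L1, 1 MiB L2, rows_a = 4, nr = 4.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t l1 = 32 * 1024, l2 = 1024 * 1024;
  const size_t rows_a = 4, nr = 4, sizeof_bf16 = 2;
  const size_t bytes_ab = l1 * 3;  // deliberate overestimate, as noted above
  const size_t col_bytes = rows_a * sizeof_bf16 + nr * sizeof_bf16;
  const size_t kc_max = (bytes_ab + col_bytes - 1) / col_bytes;  // DivCeil
  const size_t line = 64, bytes_b = 256 * 1024;                  // assumed
  const size_t bytes_per_mc = kc_max * sizeof_bf16 + line;
  const size_t mc_max = (l2 - bytes_b + bytes_per_mc - 1) / bytes_per_mc;
  printf("kc_max=%zu mc_max=%zu\n", kc_max, mc_max);
  return 0;
}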
130 ops/matmul.h

@@ -24,11 +24,9 @@
#include <vector>

// IWYU pragma: begin_exports
-#include "compression/compress.h"
-#include "util/allocator.h"
#include "util/basics.h"
-#include "util/threading.h"
-#include "util/topology.h"
+#include "util/mat.h"
+#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"  // Span
#include "hwy/base.h"
#include "hwy/bit_set.h"

@@ -51,28 +49,25 @@ class MMParallel {
 public:
  static constexpr size_t kMaxPackages = 4;

- // Both references must outlive this object.
- MMParallel(const BoundedTopology& topology, NestedPools& pools)
-     : topology_(topology), pools_(pools) {
-   HWY_DASSERT(pools_.NumPackages() <= kMaxPackages);
+ // `ctx` must outlive this object.
+ MMParallel(ThreadingContext2& ctx) : ctx_(ctx) {
+   HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
  }

- // Used by tests.
- NestedPools& Pools() { return pools_; }
-
  // Initial static partitioning of B rows across packages.
  IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
                                 size_t sizeof_TC, size_t nr) const;

  // For `BindB` and `BindC`.
  size_t Node(size_t pkg_idx) const {
-   return topology_.GetCluster(pkg_idx, 0).Node();
+   return ctx_.topology.GetCluster(pkg_idx, 0).Node();
  }

  // Calls `func(pkg_idx)` for each package in parallel.
  template <class Func>
  void ForPkg(const size_t max_packages, const Func& func) {
-   pools_.AllPackages().Run(0, HWY_MIN(max_packages, pools_.NumPackages()),
+   ctx_.pools.AllPackages().Run(
+       0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
        [&](uint64_t task, size_t pkg_idx) {
          HWY_DASSERT(task == pkg_idx);
          (void)task;

@@ -87,10 +82,10 @@ class MMParallel {
                size_t pkg_idx, const Func& func) {
    HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
    // Single cluster: parallel-for over static partition of `range_np`.
-   hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
+   hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
    const size_t num_clusters = all_clusters.NumWorkers();
    if (num_clusters == 1) {
-     hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, 0);
+     hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0);
      const IndexRangePartition worker_ranges = StaticPartition(
          range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
      return ParallelizeOneRange(

@@ -106,7 +101,7 @@ class MMParallel {
    ParallelizeOneRange(
        nx_ranges, all_clusters,
        [&](const IndexRange& nx_range, const size_t cluster_idx) {
-         hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
+         hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
          // Parallel-for over sub-ranges of `cluster_range` within the cluster.
          const IndexRangePartition worker_ranges = StaticPartition(
              nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);

@@ -122,14 +117,14 @@ class MMParallel {
  void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
                      const IndexRangePartition& ranges_nc, size_t pkg_idx,
                      const Func& func) {
-   hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
+   hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
    // `all_clusters` is a pool with one worker per cluster in a package.
    const size_t num_clusters = all_clusters.NumWorkers();
    // Single (big) cluster: collapse two range indices into one parallel-for
    // to reduce the number of fork-joins.
    if (num_clusters == 1) {
      const size_t cluster_idx = 0;
-     hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
+     hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
      // Low-batch: avoid Divide/Remainder.
      if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
        return ParallelizeOneRange(

@@ -150,7 +145,7 @@ class MMParallel {
    ParallelizeOneRange(
        ranges_nc, all_clusters,
        [&](const IndexRange range_nc, size_t cluster_idx) {
-         hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
+         hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
          ParallelizeOneRange(
              ranges_mc, cluster,
              [&](const IndexRange& range_mc, size_t /*thread*/) {

@@ -163,32 +158,32 @@ class MMParallel {
  template <class Func>
  void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
                  const Func& func) {
-   pools_.Pool(pkg_idx).Run(
+   ctx_.pools.Pool(pkg_idx).Run(
        range_mc.begin(), range_mc.end(),
        [&](uint64_t row_a, size_t /*thread*/) { func(row_a); });
  }

 private:
- const BoundedTopology& topology_;
- NestedPools& pools_;
+ ThreadingContext2& ctx_;
};

template <typename TC>  // BF16/float for C, double for partial
-void BindC(size_t M, const RowPtr<TC>& C, MMParallel& parallel) {
- if (!Allocator::ShouldBind()) return;
+void BindC(const Allocator2& allocator, size_t M, const RowPtr<TC>& C,
+           MMParallel& parallel) {
+ if (!allocator.ShouldBind()) return;

  const IndexRangePartition ranges_np =
      parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR);
- const size_t quantum = Allocator::QuantumBytes() / sizeof(TC);
+ const size_t quantum = allocator.Quantum<TC>();
  bool ok = true;
  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
    const IndexRange& cols_c = ranges_np.Range(pkg_idx);
    const size_t node = parallel.Node(pkg_idx);
    for (size_t im = 0; im < M; ++im) {
-     // BindRowsToPackageNodes may not be page-aligned.
+     // `BindMemory` requires page alignment.
      const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
      const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
-     ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
+     ok &= allocator.BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
                                  node);
    }
  }
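BindC above rounds each package's column range inward to the allocator quantum so that only whole pages are bound to a NUMA node. A minimal standalone sketch of that rounding with assumed values (not taken from the library):

// Illustrative only: the quantum and column range are made up.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t quantum = 4096 / sizeof(float);  // e.g. one 4 KiB page of floats
  const size_t col_begin = 1500, col_end = 6000;
  const size_t begin = (col_begin + quantum - 1) / quantum * quantum;  // 2048
  const size_t end = col_end / quantum * quantum;                      // 5120
  printf("bind [%zu, %zu) of each row; the unaligned edges stay unbound\n",
         begin, end);
  return 0;
}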
@@ -212,38 +207,42 @@ class MMStorage {
  // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
  static constexpr size_t kMaxKC = 8 * 1024;

- explicit MMStorage(MMParallel& parallel) {
+ MMStorage(const Allocator2& allocator, MMParallel& parallel)
+     // Per-worker copies of `partial` would be wasteful. We instead allocate
+     // one instance of the maximum matrix extents because threads write at
+     // false-sharing-free granularity.
+     : partial_storage_(
+           AllocateAlignedRows<double>(allocator, Extents2D(kMaxM, kMaxN))),
+       // Same stride independent of the actual C.Cols() so we can pre-bind.
+       partial_(allocator, partial_storage_.All(), kMaxN,
+                StrideForCyclicOffsets(kMaxN, allocator.Quantum<double>())) {
    // Per-package allocation so each can decompress A into its own copy.
    parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
-     pkg_A_[pkg_idx] = AllocateAlignedRows<BF16>(Extents2D(kMaxM, kMaxK));
+     pkg_A_[pkg_idx] =
+         AllocateAlignedRows<BF16>(allocator, Extents2D(kMaxM, kMaxK));

-     if (Allocator::ShouldBind()) {
+     if (allocator.ShouldBind()) {
        const size_t node = parallel.Node(pkg_idx);
-       if (!Allocator::BindMemory(pkg_A_[pkg_idx].All(),
+       if (!allocator.BindMemory(pkg_A_[pkg_idx].All(),
                                   pkg_A_[pkg_idx].NumBytes(), node)) {
          HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
        }
      }
    });

-   // Per-worker copies of `partial` would be wasteful. We instead allocate
-   // one instance of the maximum matrix extents because threads write at
-   // false-sharing-free granularity.
-   partial_storage_ = AllocateAlignedRows<double>(Extents2D(kMaxM, kMaxN));
-   // Same stride independent of the actual C.Cols() so we can pre-bind.
-   partial_ = RowPtrD(partial_storage_.All(), kMaxN,
-                      StrideForCyclicOffsets<double>(kMaxN));
    // Avoid cross-package accesses.
-   BindC(kMaxM, partial_, parallel);
+   BindC(allocator, kMaxM, partial_, parallel);
  }

  // Returns per-package matrix view. Non-const so that `RowVectorBatch` is
  // non-const, because `RowPtr` requires a non-const pointer.
- RowPtrBF A(size_t pkg_idx, const Extents2D& extents) {
+ RowPtrBF A(const Allocator2& allocator, size_t pkg_idx,
+            const Extents2D& extents) {
    HWY_DASSERT(extents.rows <= kMaxM);
    HWY_DASSERT(extents.cols <= kMaxK);
-   const size_t stride = StrideForCyclicOffsets<BF16>(extents.cols);
-   return RowPtrBF(pkg_A_[pkg_idx].All(), extents.cols, stride);
+   const size_t stride =
+       StrideForCyclicOffsets(extents.cols, allocator.Quantum<BF16>());
+   return RowPtrBF(allocator, pkg_A_[pkg_idx].All(), extents.cols, stride);
  }

  RowPtrD Partial() const { return partial_; }

@@ -431,13 +430,15 @@ class MMConfig {
static_assert(sizeof(MMConfig) == 32);  // for faster indexing
#pragma pack(pop)

-std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
-                                   size_t sizeof_TC, size_t max_mr, size_t nr,
+std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
+                                   size_t K, size_t N, size_t sizeof_TC,
+                                   size_t max_mr, size_t nr,
                                    const IndexRangePartition& ranges_np,
                                    bool print_config);

// State machine for choosing the best `TConfig`, which is `MMConfig` for the
// main MatMul autotuner.
+// TODO: replace with hwy/auto_tune.h.
template <typename TConfig>
class MMAutoTune {
 public:
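Because `partial_` is now built from `partial_storage_` in the member initializer list, the two members must be declared in that order: C++ initializes members in declaration order, regardless of the order in the initializer list. A small illustrative example with hypothetical types, not from gemma.cpp:

// Illustrative only: a view member declared after the storage it references.
#include <cstddef>
#include <vector>

struct Storage {
  explicit Storage(size_t n) : buffer(n) {}
  std::vector<double> buffer;
};

struct View {
  View(double* data, size_t n) : data(data), n(n) {}
  double* data;
  size_t n;
};

struct Holder {
  explicit Holder(size_t n)
      : storage(n), view(storage.buffer.data(), n) {}  // storage is ready first
  Storage storage;  // declared before view, so it is constructed first
  View view;
};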
@@ -560,11 +561,11 @@ class MMKeys {
  }

  // Must only be called if not already present in `Keys()`.
- void Append(Key key) {
+ void Append(Key key, const Allocator2& allocator) {
    // Dynamic allocation because the test checks many more dimensions than
    // would be reasonable to pre-allocate. DIY for alignment and padding.
    if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
-     const size_t NU64 = Allocator::VectorBytes() / sizeof(Key);
+     const size_t NU64 = allocator.VectorBytes() / sizeof(Key);
      // Start at one vector so the size is always a multiple of N.
      if (HWY_UNLIKELY(capacity_ == 0)) {
        capacity_ = hwy::DivCeil(NU64, 2);  // will be doubled below

@@ -604,10 +605,12 @@ struct MMPerKey {
  MMAutoTune<MMParA> autotune_par_a[MMParallel::kMaxPackages];
};

-// Stores state shared across MatMul calls. Non-copyable.
+// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
+// `MatMulEnv`.
struct MatMulEnv {
- explicit MatMulEnv(const BoundedTopology& topology, NestedPools& pools);
+ explicit MatMulEnv(ThreadingContext2& ctx);

+ ThreadingContext2& ctx;
  bool have_timer_stop = false;

  // Enable binding: disabled in Gemma until tensors support it, enabled in

@@ -684,8 +687,9 @@ struct MMZone {
// `ofs` required for compressed T.
template <typename T>
struct ConstMat {
- ConstMat(const T* ptr, Extents2D extents, size_t stride, size_t ofs = 0)
-     : ptr(ptr), extents(extents), stride(stride), ofs(ofs) {
+ ConstMat() = default;
+ ConstMat(const T* ptr, Extents2D extents, size_t stride)
+     : ptr(ptr), extents(extents), stride(stride), ofs(0) {
    HWY_DASSERT(ptr != nullptr);
    HWY_DASSERT(stride >= extents.cols);
  }

@@ -717,15 +721,17 @@ struct ConstMat {
  float scale = 1.0f;

  // Offset to add to `ptr`; separate because T=NuqStream does not support
- // pointer arithmetic.
+ // pointer arithmetic. This is in units of weights, and does not have anything
+ // to do with the interleaved NUQ tables. It should be computed via `Row()`
+ // to take into account the stride.
  size_t ofs;
};

// For deducing T.
template <typename T>
-ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, size_t stride,
-                         size_t ofs = 0) {
- return ConstMat<T>(ptr, extents, stride, ofs);
+ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
+                         size_t stride) {
+ return ConstMat<T>(ptr, extents, stride);
}

// For A argument to MatMul (activations).

@@ -739,21 +745,21 @@ ConstMat<T> ConstMatFromBatch(size_t batch_size,
}

template <typename T>
-ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
+ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
  ConstMat<T> mat =
-     MakeConstMat(const_cast<T*>(m.data()), m.Extents(), m.Stride(), ofs);
- mat.scale = m.scale();
+     MakeConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride());
+ mat.scale = m.Scale();
  return mat;
}

template <typename TB>
-void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
-           MMParallel& parallel) {
- if (!Allocator::ShouldBind()) return;
+void BindB(const Allocator2& allocator, size_t N, size_t sizeof_TC,
+           const ConstMat<TB>& B, MMParallel& parallel) {
+ if (!allocator.ShouldBind()) return;

  const IndexRangePartition ranges_np =
      parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
- const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
+ const size_t quantum = allocator.Quantum<TB>();
  for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
    const IndexRange& rows_b = ranges_np.Range(pkg_idx);
    const size_t node = parallel.Node(pkg_idx);

@@ -765,7 +771,7 @@ void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
      begin = hwy::RoundUpTo(begin, quantum);
      end = hwy::RoundDownTo(end, quantum);
      if (HWY_LIKELY(begin != end)) {
-       Allocator::BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
+       allocator.BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
      }
    }
  }
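For readers unfamiliar with the `ptr`/`stride`/`ofs` trio in `ConstMat`, a hedged sketch of strided row addressing follows. The field names mirror the struct above, but the body of `Row()` is an assumption consistent with the comment that `ofs` "should be computed via Row() to take into account the stride"; the real accessor may differ.

// Illustrative sketch only, not the gemma.cpp definition.
#include <cstddef>

struct Extents2DExample {
  size_t rows, cols;
};

template <typename T>
struct ConstMatExample {
  const T* ptr = nullptr;
  Extents2DExample extents{0, 0};
  size_t stride = 0;  // >= extents.cols; padding allows aligned rows
  size_t ofs = 0;     // element offset of row 0 within ptr

  size_t Row(size_t r) const { return ofs + r * stride; }
  const T* RowPtr(size_t r) const { return ptr + Row(r); }
};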
@@ -15,29 +15,25 @@

// End to end test of MatMul, comparing against a reference implementation.

-#include "hwy/detect_compiler_arch.h"
+#include "hwy/detect_compiler_arch.h"  // IWYU pragma: keep
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
-// double-precision support.
+// double-precision support, and older x86 to speed up builds.
#if HWY_ARCH_ARM_V7
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
#else
-#define HWY_DISABLED_TARGETS HWY_SCALAR
+#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SSSE3 | HWY_SSE4)
#endif
#endif

#include <stddef.h>
#include <stdio.h>

-#include <memory>
-
-#include "compression/compress.h"
#include "compression/shared.h"
#include "ops/matmul.h"
-#include "util/allocator.h"
#include "util/basics.h"
-#include "util/threading.h"
-#include "hwy/base.h"
+#include "util/mat.h"
+#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// clang-format off

@@ -48,9 +44,9 @@
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
+#include "compression/test_util-inl.h"
#include "ops/dot-inl.h"
#include "ops/matmul-inl.h"
-#include "hwy/tests/test_util-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {

@@ -60,57 +56,6 @@ extern int64_t first_target;
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

-using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
-
-template <typename MatT>
-using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
-
-// Generates inputs: deterministic, within max SfpStream range.
-template <typename MatT>
-MatStoragePtr<MatT> GenerateMat(const Extents2D& extents,
-                                hwy::ThreadPool& pool) {
-  gcpp::CompressWorkingSet ws;
-  auto mat =
-      std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
-  FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
-  HWY_ASSERT(content);
-  const float scale = SfpStream::kMax / (mat->NumElements());
-  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
-    for (size_t c = 0; c < extents.cols; c++) {
-      float f = static_cast<float>(r * extents.cols + c) * scale;
-      if ((r + c) & 1) f = -f;  // Also generate some negative values.
-      content[r * extents.cols + c] = f;
-    }
-  });
-
-  CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
-  mat->set_scale(0.6f);  // Arbitrary value, different from 1.
-  return mat;
-}
-
-// extents describes the transposed matrix.
-template <typename MatT>
-MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
-                                          hwy::ThreadPool& pool) {
-  gcpp::CompressWorkingSet ws;
-  auto mat =
-      std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
-  FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
-  const float scale = SfpStream::kMax / (mat->NumElements());
-  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
-    for (size_t c = 0; c < extents.cols; c++) {
-      float f = static_cast<float>(c * extents.rows + r) * scale;
-      if ((r + c) & 1) f = -f;  // Also generate some negative values.
-      content[r * extents.cols + c] = f;
-    }
-  });
-
-  CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
-  // Arbitrary value, different from 1, must match GenerateMat.
-  mat->set_scale(0.6f);
-  return mat;
-}
-
// Returns 1-norm, used for estimating tolerable numerical differences.
double MaxRowAbsSum(const RowVectorBatch<float>& a) {
  double max_row_abs_sum = 0.0;
@@ -141,16 +86,19 @@ float MaxAbs(const RowVectorBatch<float>& a) {
template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
                 const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
+ const Allocator2& allocator = ThreadingContext2::Get().allocator;
  const hn::ScalableTag<float> df;
  const size_t cols = A.extents.cols;
  const size_t B_rows = B.extents.rows;
  // Round up for DecompressAndZeroPad.
- RowVectorBatch<float> a_batch = AllocateAlignedRows<float>(A.extents);
- RowVectorBatch<float> b_trans_batch = AllocateAlignedRows<float>(B.extents);
+ RowVectorBatch<float> a_batch =
+     AllocateAlignedRows<float>(allocator, A.extents);
+ RowVectorBatch<float> b_trans_batch =
+     AllocateAlignedRows<float>(allocator, B.extents);
  RowVectorBatch<float> c_batch =
-     AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
+     AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
  RowVectorBatch<float> c_slow_batch =
-     AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
+     AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
  HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
  for (size_t m = 0; m < A.extents.rows; ++m) {
    DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,

@@ -224,7 +172,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
  const IndexRange all_rows_c(0, A.Extents().rows);
  const IndexRange all_cols_c(0, C.Cols());

- NestedPools& pools = env.parallel.Pools();
+ NestedPools& pools = env.ctx.pools;
  hwy::ThreadPool& all_packages = pools.AllPackages();
  const IndexRangePartition get_row_c =
      StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);

@@ -232,7 +180,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
      get_row_c, all_packages,
      [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
        hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
-       const size_t multiple = Allocator::QuantumBytes() / sizeof(TB);
+       const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
        const IndexRangePartition get_col_c =
            StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
        ParallelizeOneRange(

@@ -262,7 +210,8 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
                MatMulEnv& env, int line) {
- hwy::ThreadPool& pool = env.parallel.Pools().Pool();
+ const Allocator2& allocator = env.ctx.allocator;
+ hwy::ThreadPool& pool = env.ctx.pools.Pool();
  fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
          rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
          TypeName<TC>());

@@ -274,24 +223,22 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
  const Extents2D B_extents(cols_bc, cols_a_rows_b);  // already transposed
  const Extents2D C_extents(rows_ac, cols_bc);

- MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
- MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
- RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
- RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
- HWY_ASSERT(a && b_trans);
+ MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool));
+ MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool));
+ RowVectorBatch<TC> c_slow_batch =
+     AllocateAlignedRows<TC>(allocator, C_extents);
+ RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);

- std::unique_ptr<MatStorageT<float>> add_storage;
- if (add) {
-   add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
-   HWY_ASSERT(add_storage);
-   add_storage->set_scale(1.0f);
- }
+ MatStorageT<float> add_storage =
+     add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
+         : MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
+ add_storage.SetScale(1.0f);

- const auto A = ConstMatFromWeights(*a);
- const auto B = ConstMatFromWeights(*b_trans);
- const float* add_row = add ? add_storage->data_scale1() : nullptr;
- const RowPtr<TC> C_slow = RowPtrFromBatch(c_slow_batch);
- const RowPtr<TC> C = RowPtrFromBatch(c_batch);
+ const auto A = ConstMatFromWeights(a);
+ const auto B = ConstMatFromWeights(b_trans);
+ const float* add_row = add ? add_storage.PackedScale1() : nullptr;
+ const RowPtr<TC> C_slow = RowPtrFromBatch(allocator, c_slow_batch);
+ const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);

  MatMulSlow(A, B, add_row, env, C_slow);
  // A few reps to get coverage of the various autotuned code paths.
@@ -312,22 +259,24 @@ void TestTiny() {
  if (HWY_TARGET != first_target) return;

  for (size_t max_packages : {1, 2}) {
-   const BoundedTopology topology(BoundedSlice(0, max_packages));
-   Allocator::Init(topology, /*enable_bind=*/true);
-   const size_t max_threads = 0;  // no limit
-   NestedPools pools(topology, max_threads, Tristate::kDefault);
+   ThreadingContext2::ThreadHostileInvalidate();
+   ThreadingArgs threading_args;
+   threading_args.bind = Tristate::kTrue;
+   threading_args.max_packages = max_packages;
+   ThreadingContext2::SetArgs(threading_args);
+   MatMulEnv env(ThreadingContext2::Get());
+   NestedPools& pools = env.ctx.pools;

#if GEMMA_DISABLE_TOPOLOGY
    if (max_packages == 2) break;  // we only have one package
#else
    // If less than the limit, we have already tested all num_packages.
-   if (topology.FullTopology().packages.size() < max_packages) break;
+   if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
#endif
    fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
-           topology.TopologyString(), pools.PinString());
+           env.ctx.topology.TopologyString(), pools.PinString());

-   Tristate use_spinning = Tristate::kDefault;
-   pools.MaybeStartSpinning(use_spinning);
-   MatMulEnv env(topology, pools);
+   pools.MaybeStartSpinning(threading_args.spin);

    for (size_t M = 1; M <= 12; ++M) {
      for (size_t K = 1; K <= 64; K *= 2) {

@@ -336,7 +285,7 @@ void TestTiny() {
        }
      }
    }
-   pools.MaybeStopSpinning(use_spinning);
+   pools.MaybeStopSpinning(threading_args.spin);
  }
}

@@ -347,12 +296,13 @@ void TestAllMatMul() {
    return;
  }

- const BoundedTopology topology;
- Allocator::Init(topology, /*enable_bind=*/true);
- NestedPools pools(topology);
- Tristate use_spinning = Tristate::kDefault;
- pools.MaybeStartSpinning(use_spinning);
- MatMulEnv env(topology, pools);
+ ThreadingContext2::ThreadHostileInvalidate();
+ ThreadingArgs threading_args;
+ threading_args.bind = Tristate::kTrue;
+ ThreadingContext2::SetArgs(threading_args);
+ MatMulEnv env(ThreadingContext2::Get());
+ NestedPools& pools = env.ctx.pools;
+ pools.MaybeStartSpinning(threading_args.spin);

  // Sizes seen in gemma_test 2B. Too slow for CI, enable on-demand.
  TestMatMul<F32>(1, 2048, 512, /*add=*/false, env, __LINE__);

@@ -417,6 +367,8 @@ void TestAllMatMul() {
  TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env, __LINE__);
+
+ pools.MaybeStopSpinning(threading_args.spin);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
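The tests above illustrate the new setup flow: configure `ThreadingArgs`, install them via `ThreadingContext2::SetArgs`, then build `MatMulEnv` from the shared context. A condensed sketch of that flow, omitting the test specifics; it only uses calls visible in the diff and should be read as an assumption about typical usage rather than documented API guidance.

// Minimal usage sketch, following the pattern in the tests above.
#include "ops/matmul.h"
#include "util/threading_context.h"

void RunWithContext() {
  gcpp::ThreadingArgs threading_args;           // threading/allocation flags
  threading_args.bind = gcpp::Tristate::kTrue;  // enable NUMA binding
  gcpp::ThreadingContext2::SetArgs(threading_args);

  gcpp::MatMulEnv env(gcpp::ThreadingContext2::Get());
  gcpp::NestedPools& pools = env.ctx.pools;
  pools.MaybeStartSpinning(threading_args.spin);
  // ... call MatMul(...) with ConstMat views and a RowPtr output ...
  pools.MaybeStopSpinning(threading_args.spin);
}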
@@ -50,8 +50,7 @@ template <class ArrayT, typename VT>
HWY_INLINE float Dot(const ArrayT& w, size_t w_ofs, const VT* vec_aligned,
                     size_t num) {
  const hn::ScalableTag<VT> d;
- return w.scale() * Dot(d, MakeConstSpan(w.data(), w.NumElements()), w_ofs,
-                        vec_aligned, num);
+ return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
}

// Simple version without tiling nor threading, but two offsets/outputs and
@@ -27,12 +27,13 @@
#include <type_traits>  // std::enable_if_t
#include <vector>

-#include "compression/compress.h"
+#include "util/allocator.h"
#include "util/basics.h"  // TokenAndProb
+#include "util/mat.h"
+#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/sort/order.h"
#include "hwy/contrib/sort/vqsort.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
#include "hwy/profiler.h"
#endif  // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_

@@ -807,12 +808,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
// Each output row is the average of a 4x4 block of input rows
template <typename T>
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
- Extents2D extents = input.Extents();
+ const Allocator2& allocator = ThreadingContext2::Get().allocator;
+ const Extents2D extents = input.Extents();
  // Input validation
  HWY_DASSERT(extents.rows == 4096);  // 64 * 64 = 4096 input rows
  // Create output with 256 rows and same number of columns
  const size_t out_rows = 256;  // 16 * 16 = 256 output rows
- RowVectorBatch<T> result(Extents2D{out_rows, extents.cols});
+ RowVectorBatch<T> result(allocator, Extents2D(out_rows, extents.cols));
  const size_t input_dim = 64;   // Input is 64×64
  const size_t output_dim = 16;  // Output is 16×16
  for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
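As a sanity check of the pooling above, here is a hedged sketch of the index mapping implied by "64x64 grid of input rows -> 16x16 grid of output rows, each averaging a 4x4 block"; the actual loop body may iterate in a different order.

// Illustrative only: which input rows feed one example output cell.
#include <cstddef>
#include <cstdio>

int main() {
  const size_t input_dim = 64, output_dim = 16;
  const size_t i = 3, j = 5;  // example output grid cell
  printf("output row %zu averages input rows:\n", i * output_dim + j);
  for (size_t di = 0; di < 4; ++di) {
    for (size_t dj = 0; dj < 4; ++dj) {
      printf("  %zu\n", (4 * i + di) * input_dim + (4 * j + dj));
    }
  }
  return 0;
}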
@@ -21,14 +21,16 @@
#include <cmath>

#include "util/allocator.h"
+#include "util/mat.h"
#include "hwy/base.h"

namespace gcpp {

static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
-   size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) {
+   const Allocator2& allocator, size_t qkv_dim, bool half_rope,
+   double base_frequency = 10000.0) {
  const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
- RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
+ RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2));
  for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
    const double freq_exponents =
        static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
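The hunk shows only the exponent computation; the inverse timescale it presumably feeds is the standard RoPE schedule base_frequency^(-2*dim/rope_dim). A hedged sketch follows; the loop body below the shown lines is assumed, not taken from the diff.

// Illustrative only: assumed continuation of the loop shown above.
#include <cmath>
#include <cstddef>
#include <cstdio>

int main() {
  const double base_frequency = 10000.0;
  const size_t rope_dim = 256;            // example qkv_dim with full rope
  for (size_t dim = 0; dim < 4; ++dim) {  // first few entries only
    const double freq_exponent =
        static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
    const double inv_timescale = 1.0 / std::pow(base_frequency, freq_exponent);
    printf("inv_timescale[%zu] = %g\n", dim, inv_timescale);
  }
  return 0;
}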
@@ -31,14 +31,12 @@
#include <random>
#include <vector>

-#include "compression/compress.h"  // BF16
#include "gemma/common.h"
-#include "gemma/configs.h"
#include "util/allocator.h"
-#include "util/app.h"
+#include "util/basics.h"  // BF16
+#include "util/mat.h"  // RowVectorBatch
#include "util/test_util.h"
-#include "util/threading.h"
-#include "hwy/base.h"
+#include "util/threading_context.h"
#include "hwy/tests/hwy_gtest.h"

// clang-format off

@@ -388,13 +386,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}

void TestRopeAndMulBy() {
- AppArgs app;
- BoundedTopology topology = CreateTopology(app);
- NestedPools pools = CreatePools(topology, app);
+ const Allocator2& allocator = ThreadingContext2::Get().allocator;

  ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
  int dim_qkv = config.layer_configs[0].qkv_dim;
- RowVectorBatch<float> x(Extents2D(1, dim_qkv));
+ RowVectorBatch<float> x(allocator, Extents2D(1, dim_qkv));

  std::mt19937 gen;
  gen.seed(0x12345678);

@@ -412,8 +408,8 @@ void TestRopeAndMulBy() {
  std::vector<float> qactual(dim_qkv);
  std::vector<float> kexpected(dim_qkv);
  std::vector<float> kactual(dim_qkv);
- RowVectorBatch<float> inv_timescale = gcpp::CreateInvTimescale(
-     config.layer_configs[0].qkv_dim,
+ RowVectorBatch<float> inv_timescale = CreateInvTimescale(
+     allocator, config.layer_configs[0].qkv_dim,
      config.layer_configs[0].post_qk == PostQKType::HalfRope);
  // Assert VectorizedRope computation is same as regular rope at different pos.
  for (int pos = 1; pos < 500; pos++) {
@@ -49,8 +49,8 @@ class PaliGemmaTest : public ::testing::Test {
};

void PaliGemmaTest::InitVit(const std::string& path) {
- ASSERT_NE(s_env->GetModel(), nullptr);
- Gemma& model = *(s_env->GetModel());
+ ASSERT_NE(s_env->GetGemma(), nullptr);
+ Gemma& model = *(s_env->GetGemma());
  image_tokens_ =
      ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
                            model.GetModelConfig().model_dim));

@@ -64,7 +64,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
}

std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
- Gemma& model = *(s_env->GetModel());
+ Gemma& model = *(s_env->GetGemma());
  s_env->MutableGen().seed(0x12345678);
  RuntimeConfig runtime_config = {.max_generated_tokens = 512,
                                  .gen = &s_env->MutableGen(),

@@ -92,7 +92,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
}

void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
- ASSERT_NE(s_env->GetModel(), nullptr);
+ ASSERT_NE(s_env->GetGemma(), nullptr);
  std::string path = "paligemma/testdata/image.ppm";
  InitVit(path);
  for (size_t i = 0; i < num_questions; ++i) {

@@ -104,7 +104,7 @@ void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
}

TEST_F(PaliGemmaTest, General) {
- ASSERT_NE(s_env->GetModel(), nullptr);
+ ASSERT_NE(s_env->GetGemma(), nullptr);
  static const char* kQA_3B_mix_224[][2] = {
      {"describe this image",
       "A large building with two towers stands tall on the water's edge."},

@@ -124,7 +124,7 @@ TEST_F(PaliGemmaTest, General) {
  };
  const char* (*qa)[2];
  size_t num;
- switch (s_env->GetModel()->Info().model) {
+ switch (s_env->GetGemma()->Info().model) {
    case Model::PALIGEMMA_224:
      qa = kQA_3B_mix_224;
      num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);

@@ -135,7 +135,7 @@ TEST_F(PaliGemmaTest, General) {
      break;
    default:
      FAIL() << "Unsupported model: "
-            << s_env->GetModel()->GetModelConfig().model_name;
+            << s_env->GetGemma()->GetModelConfig().model_name;
      break;
  }
  TestQuestions(qa, num);
@@ -21,12 +21,12 @@ pybind_extension(
     name = "gemma",
     srcs = ["gemma_py.cc"],
     deps = [
-        "//:app",
+        "//:allocator",
         "//:benchmark_helper",
+        "//:gemma_args",
         "//:gemma_lib",
         "//compression:sfp",
         "@highway//:hwy",
-        "@highway//:thread_pool",
     ],
 )
 
@@ -32,9 +32,9 @@
 #include "compression/shared.h"
 #include "evals/benchmark_helper.h"
 #include "gemma/gemma.h"
-#include "util/app.h"
+#include "gemma/gemma_args.h"
+#include "util/allocator.h"
 #include "hwy/base.h"
-#include "hwy/contrib/thread_pool/thread_pool.h"
 
 namespace py = pybind11;
 
@@ -48,8 +48,9 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
 class GemmaModel {
  public:
   GemmaModel(const gcpp::LoaderArgs& loader,
-             const gcpp::InferenceArgs& inference, const gcpp::AppArgs& app)
-      : gemma_(loader, inference, app), last_prob_(0.0f) {}
+             const gcpp::InferenceArgs& inference,
+             const gcpp::ThreadingArgs& threading)
+      : gemma_(threading, loader, inference), last_prob_(0.0f) {}
 
   // Generates a single example, given a prompt and a callback to stream the
   // generated tokens.
@@ -168,7 +169,8 @@ class GemmaModel {
   // Generate* will use this image. Throws an error for other models.
   void SetImage(const py::array_t<float, py::array::c_style |
                                              py::array::forcecast>& image) {
-    gcpp::Gemma& model = *(gemma_.GetModel());
+    const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
+    gcpp::Gemma& model = *(gemma_.GetGemma());
     if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) {
       throw std::invalid_argument("Not a PaliGemma model.");
     }
@@ -183,8 +185,8 @@ class GemmaModel {
     c_image.Set(height, width, ptr);
     const size_t image_size = model.GetModelConfig().vit_config.image_size;
     c_image.Resize(image_size, image_size);
-    image_tokens_ = gcpp::ImageTokens(gcpp::Extents2D(
-        model.GetModelConfig().vit_config.seq_len,
+    image_tokens_ = gcpp::ImageTokens(
+        allocator, gcpp::Extents2D(model.GetModelConfig().vit_config.seq_len,
                                    model.GetModelConfig().model_dim));
     gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
                                           .verbosity = 0};
@@ -199,7 +201,7 @@ class GemmaModel {
     if (image_tokens_.Cols() == 0) {
       throw std::invalid_argument("No image set.");
     }
-    gcpp::Gemma& model = *(gemma_.GetModel());
+    gcpp::Gemma& model = *(gemma_.GetGemma());
     gemma_.MutableGen().seed(seed);
     gcpp::RuntimeConfig& config = gemma_.MutableConfig();
     config.max_generated_tokens = max_generated_tokens;
@@ -247,7 +249,7 @@ class GemmaModel {
     return gemma_.StringFromTokens(token_ids);
   }
 
-  bool ModelIsLoaded() const { return gemma_.GetModel() != nullptr; }
+  bool ModelIsLoaded() const { return gemma_.GetGemma() != nullptr; }
 
  private:
  gcpp::GemmaEnv gemma_;
@@ -267,7 +269,7 @@ PYBIND11_MODULE(gemma, mod) {
         loader.weight_type_str = weight_type;
         gcpp::InferenceArgs inference;
         inference.max_generated_tokens = 512;
-        gcpp::AppArgs app;
+        gcpp::ThreadingArgs app;
         app.max_threads = max_threads;
         auto gemma =
             std::make_unique<GemmaModel>(loader, inference, app);
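A minimal sketch (not from the diff) of constructing the binding's model wrapper after this change; it mirrors the code above, with `ThreadingArgs` in place of the former `AppArgs`. Field values are placeholders.

```cpp
// Sketch only: assumes default-constructible arg structs as used above.
gcpp::LoaderArgs loader;           // --tokenizer/--weights set elsewhere
gcpp::InferenceArgs inference;
inference.max_generated_tokens = 512;
gcpp::ThreadingArgs threading;     // replaces gcpp::AppArgs
threading.max_threads = 8;         // hypothetical value
auto model = std::make_unique<GemmaModel>(loader, inference, threading);
```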
@ -130,233 +130,6 @@ size_t DetectTotalMiB(size_t page_bytes) {
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
static size_t line_bytes_;
|
|
||||||
static size_t vector_bytes_;
|
|
||||||
static size_t step_bytes_;
|
|
||||||
static size_t quantum_bytes_;
|
|
||||||
static size_t quantum_steps_;
|
|
||||||
static size_t l1_bytes_;
|
|
||||||
static size_t l2_bytes_;
|
|
||||||
static size_t l3_bytes_;
|
|
||||||
static bool should_bind_ = false;
|
|
||||||
|
|
||||||
size_t Allocator::LineBytes() { return line_bytes_; }
|
|
||||||
size_t Allocator::VectorBytes() { return vector_bytes_; }
|
|
||||||
size_t Allocator::StepBytes() { return step_bytes_; }
|
|
||||||
size_t Allocator::QuantumBytes() { return quantum_bytes_; }
|
|
||||||
size_t Allocator::QuantumSteps() { return quantum_steps_; }
|
|
||||||
size_t Allocator::L1Bytes() { return l1_bytes_; }
|
|
||||||
size_t Allocator::L2Bytes() { return l2_bytes_; }
|
|
||||||
size_t Allocator::L3Bytes() { return l3_bytes_; }
|
|
||||||
bool Allocator::ShouldBind() { return should_bind_; }
|
|
||||||
|
|
||||||
void Allocator::Init(const BoundedTopology& topology, bool enable_bind) {
|
|
||||||
line_bytes_ = DetectLineBytes();
|
|
||||||
vector_bytes_ = hwy::VectorBytes();
|
|
||||||
step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
|
|
||||||
quantum_bytes_ = step_bytes_; // may overwrite below
|
|
||||||
|
|
||||||
const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
|
|
||||||
if (const hwy::Cache* caches = hwy::DataCaches()) {
|
|
||||||
l1_bytes_ = caches[1].size_kib << 10;
|
|
||||||
l2_bytes_ = caches[2].size_kib << 10;
|
|
||||||
l3_bytes_ = (caches[3].size_kib << 10) * caches[3].cores_sharing;
|
|
||||||
} else { // Unknown, make reasonable assumptions.
|
|
||||||
l1_bytes_ = 32 << 10;
|
|
||||||
l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10;
|
|
||||||
}
|
|
||||||
if (l3_bytes_ == 0) {
|
|
||||||
l3_bytes_ = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) << 10;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prerequisites for binding:
|
|
||||||
// - supported by the OS (currently Linux only),
|
|
||||||
// - the page size is known and 'reasonably small', preferably less than
|
|
||||||
// a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
|
|
||||||
// - we successfully detected topology and there are multiple nodes;
|
|
||||||
// - there are multiple packages, because we shard by package_idx.
|
|
||||||
if constexpr (GEMMA_BIND) {
|
|
||||||
const size_t page_bytes = DetectPageSize();
|
|
||||||
if ((page_bytes != 0 && page_bytes <= 16 * 1024) &&
|
|
||||||
topology.NumNodes() > 1 && topology.NumPackages() > 1) {
|
|
||||||
if (enable_bind) {
|
|
||||||
// Ensure pages meet the alignment requirements of `AllocBytes`.
|
|
||||||
HWY_ASSERT(page_bytes >= quantum_bytes_);
|
|
||||||
quantum_bytes_ = page_bytes;
|
|
||||||
// Ensure MaxQuantumBytes() is an upper bound.
|
|
||||||
HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
|
|
||||||
quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
|
|
||||||
should_bind_ = true;
|
|
||||||
} else {
|
|
||||||
HWY_WARN(
|
|
||||||
"Multiple sockets but binding disabled. This reduces speed; "
|
|
||||||
"set or remove enable_bind to avoid this warning.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0);
|
|
||||||
quantum_steps_ = quantum_bytes_ / step_bytes_;
|
|
||||||
}
|
|
||||||
|
|
||||||
Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) {
|
|
||||||
// If we are not binding, the Highway allocator is cheaper than `mmap`, and
|
|
||||||
// defends against 2K aliasing.
|
|
||||||
if (!should_bind_) {
|
|
||||||
// Perf warning if Highway's alignment is less than we want.
|
|
||||||
if (HWY_ALIGNMENT < QuantumBytes()) {
|
|
||||||
HWY_WARN(
|
|
||||||
"HWY_ALIGNMENT %d < QuantumBytes %zu: either vector or cache lines "
|
|
||||||
"are huge, enable GEMMA_BIND to avoid this warning.",
|
|
||||||
HWY_ALIGNMENT, QuantumBytes());
|
|
||||||
}
|
|
||||||
auto p = hwy::AllocateAligned<uint8_t>(bytes);
|
|
||||||
// The `hwy::AlignedFreeUniquePtr` deleter is unfortunately specific to the
|
|
||||||
// alignment scheme in aligned_allocator.cc and does not work for
|
|
||||||
// already-aligned pointers as returned by `mmap`, hence we wrap the Highway
|
|
||||||
// pointer in our own deleter.
|
|
||||||
auto call_free = [](void* ptr, size_t /*bytes*/) {
|
|
||||||
hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
|
|
||||||
};
|
|
||||||
return PtrAndDeleter{p.release(), DeleterFree(call_free, bytes)};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Binding, or large vector/cache line size: use platform-specific allocator.
|
|
||||||
|
|
||||||
#if HWY_OS_LINUX && !defined(__ANDROID_API__)
|
|
||||||
// `move_pages` is documented to require an anonymous/private mapping or
|
|
||||||
// `MAP_SHARED`. A normal allocation might not suffice, so we use `mmap`.
|
|
||||||
// `Init` verified that the page size is a multiple of `QuantumBytes()`.
|
|
||||||
const int prot = PROT_READ | PROT_WRITE;
|
|
||||||
const int flags = MAP_ANONYMOUS | MAP_PRIVATE;
|
|
||||||
const int fd = -1;
|
|
||||||
// Encourage transparent hugepages by rounding up to a multiple of 2 MiB.
|
|
||||||
bytes = hwy::RoundUpTo(bytes, 2ull << 20);
|
|
||||||
void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
|
|
||||||
if (p == MAP_FAILED) p = nullptr;
|
|
||||||
const auto call_munmap = [](void* ptr, size_t bytes) {
|
|
||||||
const int ret = munmap(ptr, bytes);
|
|
||||||
HWY_ASSERT(ret == 0);
|
|
||||||
};
|
|
||||||
return PtrAndDeleter{p, DeleterFree(call_munmap, bytes)};
|
|
||||||
#elif HWY_OS_WIN
|
|
||||||
const auto call_free = [](void* ptr, size_t) { _aligned_free(ptr); };
|
|
||||||
const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
|
|
||||||
return PtrAndDeleter{_aligned_malloc(bytes, alignment),
|
|
||||||
DeleterFree(call_free, bytes)};
|
|
||||||
#else
|
|
||||||
return PtrAndDeleter{nullptr, DeleterFree(nullptr, 0)};
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
#if GEMMA_BIND && HWY_OS_LINUX
|
|
||||||
|
|
||||||
using Ret = long; // NOLINT(runtime/int)
|
|
||||||
using UL = unsigned long; // NOLINT(runtime/int)
|
|
||||||
static constexpr size_t ULBits = sizeof(UL) * 8;
|
|
||||||
|
|
||||||
// Calling via syscall avoids a dependency on libnuma.
|
|
||||||
struct SyscallWrappers {
|
|
||||||
static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes,
|
|
||||||
unsigned flags) {
|
|
||||||
MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL));
|
|
||||||
return syscall(__NR_mbind, ptr, bytes, mode, max_nodes, max_nodes, flags);
|
|
||||||
};
|
|
||||||
|
|
||||||
static Ret move_pages(int pid, UL count, void** pages, const int* nodes,
|
|
||||||
int* status, int flags) {
|
|
||||||
MaybeCheckInitialized(pages, count * sizeof(void*));
|
|
||||||
MaybeCheckInitialized(nodes, count * sizeof(int));
|
|
||||||
MaybeCheckInitialized(status, count * sizeof(int));
|
|
||||||
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
|
|
||||||
}
|
|
||||||
|
|
||||||
static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr,
|
|
||||||
unsigned flags) {
|
|
||||||
return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
// Returns the number of pages that are currently busy (hence not yet moved),
|
|
||||||
// and warns if there are any other reasons for not moving a page. Note that
|
|
||||||
// `move_pages` can return 0 regardless of whether all pages were moved.
|
|
||||||
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
|
|
||||||
const int* status) {
|
|
||||||
size_t num_busy = 0;
|
|
||||||
for (size_t i = 0; i < num_pages; ++i) {
|
|
||||||
if (status[i] == -EBUSY) {
|
|
||||||
++num_busy;
|
|
||||||
} else if (status[i] != static_cast<int>(node)) {
|
|
||||||
static std::atomic_flag first = ATOMIC_FLAG_INIT;
|
|
||||||
if (!first.test_and_set()) {
|
|
||||||
HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).",
|
|
||||||
status[i], i, pages[i], node, errno);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return num_busy;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) {
|
|
||||||
HWY_DASSERT(should_bind_);
|
|
||||||
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
|
|
||||||
|
|
||||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
|
||||||
// Ensure the requested `node` is allowed.
|
|
||||||
UL nodes[kMaxNodes / 64] = {0};
|
|
||||||
const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED
|
|
||||||
HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes,
|
|
||||||
nullptr, flags) == 0);
|
|
||||||
HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64)));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Avoid mbind because it does not report why it failed, which is most likely
|
|
||||||
// because pages are busy, in which case we want to know which.
|
|
||||||
|
|
||||||
// `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set.
|
|
||||||
const unsigned flags = 2; // MPOL_MF_MOVE
|
|
||||||
HWY_ASSERT(bytes % quantum_bytes_ == 0);
|
|
||||||
const size_t num_pages = bytes / quantum_bytes_;
|
|
||||||
std::vector<void*> pages;
|
|
||||||
pages.reserve(num_pages);
|
|
||||||
for (size_t i = 0; i < num_pages; ++i) {
|
|
||||||
pages.push_back(static_cast<uint8_t*>(ptr) + i * quantum_bytes_);
|
|
||||||
// Ensure the page is faulted in to prevent `move_pages` from failing,
|
|
||||||
// because freshly allocated pages may be mapped to a shared 'zero page'.
|
|
||||||
hwy::ZeroBytes(pages.back(), 8);
|
|
||||||
}
|
|
||||||
std::vector<int> nodes(num_pages, node);
|
|
||||||
std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));
|
|
||||||
|
|
||||||
Ret ret = SyscallWrappers::move_pages(
|
|
||||||
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
|
|
||||||
if (ret < 0) {
|
|
||||||
HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr,
|
|
||||||
bytes, node, errno, status[0]);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t num_busy =
|
|
||||||
CountBusyPages(num_pages, node, pages.data(), status.data());
|
|
||||||
if (HWY_UNLIKELY(num_busy != 0)) {
|
|
||||||
// Trying again is usually enough to succeed.
|
|
||||||
hwy::NanoSleep(5000);
|
|
||||||
(void)SyscallWrappers::move_pages(
|
|
||||||
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
|
|
||||||
const size_t still_busy =
|
|
||||||
CountBusyPages(num_pages, node, pages.data(), status.data());
|
|
||||||
if (HWY_UNLIKELY(still_busy != 0)) {
|
|
||||||
HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.",
|
|
||||||
still_busy, num_busy);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
#else
|
|
||||||
bool Allocator::BindMemory(void*, size_t, size_t) { return false; }
|
|
||||||
#endif // GEMMA_BIND && HWY_OS_LINUX
|
|
||||||
|
|
||||||
Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
|
Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
|
||||||
line_bytes_ = DetectLineBytes();
|
line_bytes_ = DetectLineBytes();
|
||||||
vector_bytes_ = hwy::VectorBytes();
|
vector_bytes_ = hwy::VectorBytes();
|
||||||
|
|
@@ -428,7 +201,7 @@ size_t Allocator2::FreeMiB() const {
 #endif
 }
 
-Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
+AlignedPtr2<uint8_t[]> Allocator2::AllocBytes(size_t bytes) const {
   // If we are not binding, the Highway allocator is cheaper than `mmap`, and
   // defends against 2K aliasing.
   if (!should_bind_) {
@@ -444,9 +217,10 @@ Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
     // alignment scheme in aligned_allocator.cc and does not work for
     // already-aligned pointers as returned by `mmap`, hence we wrap the Highway
     // pointer in our own deleter.
-    return PtrAndDeleter{p.release(), DeleterFunc2([](void* ptr) {
-                           hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
-                         })};
+    return AlignedPtr2<uint8_t[]>(p.release(), DeleterFunc2([](void* ptr) {
+                                    hwy::FreeAlignedBytes(ptr, nullptr,
+                                                          nullptr);
+                                  }));
   }
 
   // Binding, or large vector/cache line size: use platform-specific allocator.
@@ -460,20 +234,126 @@ Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
   const int fd = -1;
   void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
   if (p == MAP_FAILED) p = nullptr;
-  return PtrAndDeleter{p, DeleterFunc2([bytes](void* ptr) {
-                         HWY_ASSERT(munmap(ptr, bytes) == 0);
-                       })};
+  return AlignedPtr2<uint8_t[]>(static_cast<uint8_t*>(p),
+                                DeleterFunc2([bytes](void* ptr) {
+                                  HWY_ASSERT(munmap(ptr, bytes) == 0);
+                                }));
 #elif HWY_OS_WIN
   const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
-  return PtrAndDeleter{_aligned_malloc(bytes, alignment),
-                       DeleterFunc2([](void* ptr) { _aligned_free(ptr); })};
+  return AlignedPtr2<uint8_t[]>(
+      static_cast<uint8_t*>(_aligned_malloc(bytes, alignment)),
+      DeleterFunc2([](void* ptr) { _aligned_free(ptr); }));
 #else
-  return PtrAndDeleter{nullptr, DeleterFunc2()};
+  return AlignedPtr2<uint8_t[]>(nullptr, DeleterFunc2());
 #endif
 }
 
-bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
-  return Allocator::BindMemory(ptr, bytes, node);
-}
+#if GEMMA_BIND && HWY_OS_LINUX
+
+using Ret = long;          // NOLINT(runtime/int)
+using UL = unsigned long;  // NOLINT(runtime/int)
+static constexpr size_t ULBits = sizeof(UL) * 8;
+
+// Calling via syscall avoids a dependency on libnuma.
+struct SyscallWrappers {
+  static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes,
+                   unsigned flags) {
+    MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL));
+    return syscall(__NR_mbind, ptr, bytes, mode, max_nodes, max_nodes, flags);
+  };
+
+  static Ret move_pages(int pid, UL count, void** pages, const int* nodes,
+                        int* status, int flags) {
+    MaybeCheckInitialized(pages, count * sizeof(void*));
+    MaybeCheckInitialized(nodes, count * sizeof(int));
+    MaybeCheckInitialized(status, count * sizeof(int));
+    return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
+  }
+
+  static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr,
+                           unsigned flags) {
+    return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags);
+  }
+};
+
+// Returns the number of pages that are currently busy (hence not yet moved),
+// and warns if there are any other reasons for not moving a page. Note that
+// `move_pages` can return 0 regardless of whether all pages were moved.
+size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
+                      const int* status) {
+  size_t num_busy = 0;
+  for (size_t i = 0; i < num_pages; ++i) {
+    if (status[i] == -EBUSY) {
+      ++num_busy;
+    } else if (status[i] != static_cast<int>(node)) {
+      static std::atomic_flag first = ATOMIC_FLAG_INIT;
+      if (!first.test_and_set()) {
+        HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).",
+                 status[i], i, pages[i], node, errno);
+      }
+    }
+  }
+  return num_busy;
+}
+
+bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
+  HWY_DASSERT(should_bind_);
+  constexpr size_t kMaxNodes = 1024;  // valid for x86/x64, and "enough"
+
+  if constexpr (HWY_IS_DEBUG_BUILD) {
+    // Ensure the requested `node` is allowed.
+    UL nodes[kMaxNodes / 64] = {0};
+    const unsigned flags = 4;  // MPOL_F_MEMS_ALLOWED
+    HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes,
+                                              nullptr, flags) == 0);
+    HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64)));
+  }
+
+  // Avoid mbind because it does not report why it failed, which is most likely
+  // because pages are busy, in which case we want to know which.
+
+  // `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set.
+  const unsigned flags = 2;  // MPOL_MF_MOVE
+  HWY_ASSERT(bytes % quantum_bytes_ == 0);
+  const size_t num_pages = bytes / quantum_bytes_;
+  std::vector<void*> pages;
+  pages.reserve(num_pages);
+  for (size_t i = 0; i < num_pages; ++i) {
+    pages.push_back(static_cast<uint8_t*>(ptr) + i * quantum_bytes_);
+    // Ensure the page is faulted in to prevent `move_pages` from failing,
+    // because freshly allocated pages may be mapped to a shared 'zero page'.
+    hwy::ZeroBytes(pages.back(), 8);
+  }
+  std::vector<int> nodes(num_pages, node);
+  std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));
+
+  Ret ret = SyscallWrappers::move_pages(
+      /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
+  if (ret < 0) {
+    HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr,
+             bytes, node, errno, status[0]);
+    return false;
+  }
+
+  const size_t num_busy =
+      CountBusyPages(num_pages, node, pages.data(), status.data());
+  if (HWY_UNLIKELY(num_busy != 0)) {
+    // Trying again is usually enough to succeed.
+    hwy::NanoSleep(5000);
+    (void)SyscallWrappers::move_pages(
+        /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
+    const size_t still_busy =
+        CountBusyPages(num_pages, node, pages.data(), status.data());
+    if (HWY_UNLIKELY(still_busy != 0)) {
+      HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.",
+               still_busy, num_busy);
+    }
+  }
+  return true;
+}
+
+#else
+bool Allocator2::BindMemory(void*, size_t, size_t) const { return false; }
+#endif  // GEMMA_BIND && HWY_OS_LINUX
 
 }  // namespace gcpp
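Because `AllocBytes` now returns a plain `std::unique_ptr`-style `AlignedPtr2<uint8_t[]>`, callers such as `MatOwner` can hold the allocation directly. A hedged sketch of that usage follows; `ShouldBind()` as an instance accessor is an assumption, only `BindMemory` and the internal `should_bind_` flag appear in this diff.

```cpp
// Sketch only: assumes the Allocator2 interface shown in this diff.
const gcpp::Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
const size_t bytes = 1 << 20;
// Returns an AlignedPtr2<uint8_t[]> whose deleter frees/munmaps as needed.
gcpp::AlignedPtr2<uint8_t[]> storage = allocator.AllocBytes(bytes);
// If binding is enabled (assumed accessor), place the pages on a NUMA node;
// ptr and bytes must be multiples of QuantumBytes().
// allocator.BindMemory(storage.get(), bytes, /*node=*/0);
```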
325  util/allocator.h
@ -30,307 +30,8 @@
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
// IWYU pragma: end_exports
|
// IWYU pragma: end_exports
|
||||||
|
|
||||||
#include "hwy/aligned_allocator.h"
|
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The
|
|
||||||
// `bytes` argument is required for the latter.
|
|
||||||
using FreeFunc = void (*)(void* mem, size_t bytes);
|
|
||||||
|
|
||||||
// Custom deleter for std::unique_ptr that calls `FreeFunc`. T is POD.
|
|
||||||
class DeleterFree {
|
|
||||||
public:
|
|
||||||
// `MatStorageT` requires this to be default-constructible.
|
|
||||||
DeleterFree() : free_func_(nullptr), bytes_(0) {}
|
|
||||||
DeleterFree(FreeFunc free_func, size_t bytes)
|
|
||||||
: free_func_(free_func), bytes_(bytes) {}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void operator()(T* p) const {
|
|
||||||
free_func_(p, bytes_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
FreeFunc free_func_;
|
|
||||||
size_t bytes_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Wrapper that also calls the destructor for non-POD T.
|
|
||||||
class DeleterDtor {
|
|
||||||
public:
|
|
||||||
DeleterDtor() {}
|
|
||||||
DeleterDtor(size_t num, DeleterFree free) : num_(num), free_(free) {}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
void operator()(T* p) const {
|
|
||||||
for (size_t i = 0; i < num_; ++i) {
|
|
||||||
p[i].~T();
|
|
||||||
}
|
|
||||||
free_(p);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
size_t num_; // not the same as free_.bytes_ / sizeof(T)!
|
|
||||||
DeleterFree free_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Unique (move-only) pointer to an aligned array of POD T.
|
|
||||||
template <typename T>
|
|
||||||
using AlignedPtr = std::unique_ptr<T[], DeleterFree>;
|
|
||||||
// Unique (move-only) pointer to an aligned array of non-POD T.
|
|
||||||
template <typename T>
|
|
||||||
using AlignedClassPtr = std::unique_ptr<T[], DeleterDtor>;
|
|
||||||
|
|
||||||
// Both allocation, binding, and row accessors depend on the sizes of memory
|
|
||||||
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
|
|
||||||
// use `Monostate` (static members).
|
|
||||||
class Allocator {
|
|
||||||
public:
|
|
||||||
// Must be called at least once before any other function. Not thread-safe,
|
|
||||||
// hence only call this from the main thread.
|
|
||||||
// TODO: remove enable_bind once Gemma tensors support binding.
|
|
||||||
static void Init(const BoundedTopology& topology, bool enable_bind = false);
|
|
||||||
|
|
||||||
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
|
|
||||||
// ranges such that there will be no false sharing.
|
|
||||||
static size_t LineBytes();
|
|
||||||
// Bytes per full vector. Used to compute loop steps.
|
|
||||||
static size_t VectorBytes();
|
|
||||||
// Work granularity that avoids false sharing and partial vectors.
|
|
||||||
static size_t StepBytes(); // = HWY_MAX(LineBytes(), VectorBytes())
|
|
||||||
// Granularity like `StepBytes()`, but when NUMA may be involved.
|
|
||||||
static size_t QuantumBytes();
|
|
||||||
// Upper bound on `QuantumBytes()`, for stack allocations.
|
|
||||||
static constexpr size_t MaxQuantumBytes() { return 4096; }
|
|
||||||
static size_t QuantumSteps(); // = QuantumBytes() / StepBytes()
|
|
||||||
|
|
||||||
// L1 and L2 are typically per core.
|
|
||||||
static size_t L1Bytes();
|
|
||||||
static size_t L2Bytes();
|
|
||||||
// Clusters often share an L3. We return the total size per package.
|
|
||||||
static size_t L3Bytes();
|
|
||||||
|
|
||||||
// Returns pointer aligned to `QuantumBytes()`.
|
|
||||||
template <typename T>
|
|
||||||
static AlignedPtr<T> Alloc(size_t num) {
|
|
||||||
constexpr size_t kSize = sizeof(T);
|
|
||||||
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
|
|
||||||
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
|
|
||||||
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
|
|
||||||
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
|
|
||||||
// Fail if the `bytes = num * kSize` computation overflowed.
|
|
||||||
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
|
|
||||||
if (check != num) return AlignedPtr<T>();
|
|
||||||
|
|
||||||
PtrAndDeleter pd = AllocBytes(bytes);
|
|
||||||
return AlignedPtr<T>(static_cast<T*>(pd.p), pd.deleter);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Same as Alloc, but calls constructor(s) with `args`.
|
|
||||||
template <typename T, class... Args>
|
|
||||||
static AlignedClassPtr<T> AllocClasses(size_t num, Args&&... args) {
|
|
||||||
constexpr size_t kSize = sizeof(T);
|
|
||||||
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
|
|
||||||
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
|
|
||||||
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
|
|
||||||
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
|
|
||||||
// Fail if the `bytes = num * kSize` computation overflowed.
|
|
||||||
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
|
|
||||||
if (check != num) return AlignedClassPtr<T>();
|
|
||||||
|
|
||||||
PtrAndDeleter pd = AllocBytes(bytes);
|
|
||||||
T* p = static_cast<T*>(pd.p);
|
|
||||||
for (size_t i = 0; i < num; ++i) {
|
|
||||||
new (p + i) T(std::forward<Args>(args)...);
|
|
||||||
}
|
|
||||||
return AlignedClassPtr<T>(p, DeleterDtor(num, pd.deleter));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns whether `BindMemory` can/should be called, i.e. we have page-level
|
|
||||||
// control over memory placement and multiple packages and NUMA nodes.
|
|
||||||
static bool ShouldBind();
|
|
||||||
|
|
||||||
// Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
|
|
||||||
// typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
|
|
||||||
// Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
|
|
||||||
// `p` and `bytes` must be multiples of `QuantumBytes()`.
|
|
||||||
static bool BindMemory(void* p, size_t bytes, size_t node);
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Type-erased so this can be implemented in allocator.cc.
|
|
||||||
struct PtrAndDeleter {
|
|
||||||
void* p;
|
|
||||||
DeleterFree deleter;
|
|
||||||
};
|
|
||||||
static PtrAndDeleter AllocBytes(size_t bytes);
|
|
||||||
};
|
|
||||||
|
|
||||||
// Value of `stride` to pass to `RowVectorBatch` to enable the "cyclic offsets"
|
|
||||||
// optimization. If `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is
|
|
||||||
// typically 4KiB. To avoid remote accesses, we would thus pad each row to that,
|
|
||||||
// which results in 4K aliasing and/or cache conflict misses. `RowPtr` is able
|
|
||||||
// to prevent that by pulling rows forward by a cyclic offset, which is still a
|
|
||||||
// multiple of the cache line size. This requires an additional
|
|
||||||
// `Allocator::QuantumBytes()` of padding after also rounding up to that.
|
|
||||||
template <typename T>
|
|
||||||
constexpr size_t StrideForCyclicOffsets(size_t cols) {
|
|
||||||
const size_t quantum = Allocator::MaxQuantumBytes() / sizeof(T);
|
|
||||||
return hwy::RoundUpTo(cols, quantum) + quantum;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
|
||||||
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
|
|
||||||
// the memory.
|
|
||||||
template <typename T>
|
|
||||||
class RowVectorBatch {
|
|
||||||
public:
|
|
||||||
// Default ctor for Activations ctor.
|
|
||||||
RowVectorBatch() = default;
|
|
||||||
// Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
|
|
||||||
// we default to tightly packed rows (`stride = cols`).
|
|
||||||
// WARNING: not all call sites support `stride` != cols.
|
|
||||||
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
|
|
||||||
RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
|
|
||||||
if (stride == 0) {
|
|
||||||
stride_ = extents_.cols;
|
|
||||||
} else {
|
|
||||||
HWY_ASSERT(stride >= extents_.cols);
|
|
||||||
stride_ = stride;
|
|
||||||
}
|
|
||||||
// Allow binding the entire matrix.
|
|
||||||
const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
|
|
||||||
Allocator::QuantumBytes() / sizeof(T));
|
|
||||||
mem_ = Allocator::Alloc<T>(padded);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Move-only
|
|
||||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
|
||||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
|
||||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
|
||||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
|
||||||
|
|
||||||
size_t BatchSize() const { return extents_.rows; }
|
|
||||||
size_t Cols() const { return extents_.cols; }
|
|
||||||
size_t Stride() const { return stride_; }
|
|
||||||
Extents2D Extents() const { return extents_; }
|
|
||||||
|
|
||||||
// Returns the given row vector of length `Cols()`.
|
|
||||||
T* Batch(size_t batch_idx) {
|
|
||||||
HWY_DASSERT(batch_idx < BatchSize());
|
|
||||||
return mem_.get() + batch_idx * stride_;
|
|
||||||
}
|
|
||||||
const T* Batch(size_t batch_idx) const {
|
|
||||||
HWY_DASSERT(batch_idx < BatchSize());
|
|
||||||
return mem_.get() + batch_idx * stride_;
|
|
||||||
}
|
|
||||||
|
|
||||||
// For MatMul or other operations that process the entire batch at once.
|
|
||||||
// TODO: remove once we only use Mat.
|
|
||||||
T* All() { return mem_.get(); }
|
|
||||||
const T* Const() const { return mem_.get(); }
|
|
||||||
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
|
|
||||||
|
|
||||||
private:
|
|
||||||
AlignedPtr<T> mem_;
|
|
||||||
Extents2D extents_;
|
|
||||||
size_t stride_;
|
|
||||||
};
|
|
||||||
|
|
||||||
// Returns `num` rounded up to an odd number of cache lines. This is used to
|
|
||||||
// compute strides. An odd number of cache lines prevents 2K aliasing and is
|
|
||||||
// coprime with the cache associativity, which reduces conflict misses.
|
|
||||||
template <typename T>
|
|
||||||
static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
|
|
||||||
HWY_DASSERT(line_bytes >= 32);
|
|
||||||
HWY_DASSERT(line_bytes % sizeof(T) == 0);
|
|
||||||
const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes);
|
|
||||||
const size_t padded_num = (lines | 1) * line_bytes / sizeof(T);
|
|
||||||
HWY_DASSERT(padded_num >= num);
|
|
||||||
return padded_num;
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
|
||||||
RowVectorBatch<T> AllocateAlignedRows(Extents2D extents) {
|
|
||||||
return RowVectorBatch<T>(extents, StrideForCyclicOffsets<T>(extents.cols));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
|
|
||||||
// it is always float and does not support compressed T, but does support an
|
|
||||||
// arbitrary stride >= cols.
|
|
||||||
#pragma pack(push, 1) // power of two size
|
|
||||||
template <typename T>
|
|
||||||
class RowPtr {
|
|
||||||
public:
|
|
||||||
RowPtr() = default; // for `MMPtrs`.
|
|
||||||
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
|
|
||||||
: row0_(row0),
|
|
||||||
stride_(stride),
|
|
||||||
step_(static_cast<uint32_t>(Allocator::StepBytes())),
|
|
||||||
cols_(static_cast<uint32_t>(cols)),
|
|
||||||
row_mask_(Allocator::QuantumSteps() - 1) {
|
|
||||||
HWY_DASSERT(stride >= cols);
|
|
||||||
HWY_DASSERT(row_mask_ != ~size_t{0});
|
|
||||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
|
||||||
if (stride < StrideForCyclicOffsets<T>(cols)) {
|
|
||||||
static bool once;
|
|
||||||
if (!once) {
|
|
||||||
once = true;
|
|
||||||
HWY_WARN(
|
|
||||||
"Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
|
|
||||||
"T=%zu; this forces us to disable cyclic offsets.",
|
|
||||||
stride, cols, sizeof(T));
|
|
||||||
}
|
|
||||||
row_mask_ = 0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
|
|
||||||
|
|
||||||
T* HWY_RESTRICT Row(size_t r) const {
|
|
||||||
// How much of the previous row's padding to consume.
|
|
||||||
const size_t pad_bytes = (r & row_mask_) * step_;
|
|
||||||
HWY_DASSERT(pad_bytes < Allocator::QuantumBytes());
|
|
||||||
return row0_ + stride_ * r - pad_bytes;
|
|
||||||
}
|
|
||||||
size_t Cols() const { return cols_; }
|
|
||||||
|
|
||||||
size_t Stride() const { return stride_; }
|
|
||||||
void SetStride(size_t stride) {
|
|
||||||
HWY_DASSERT(stride >= Cols());
|
|
||||||
stride_ = stride;
|
|
||||||
// The caller might not have padded enough, so disable the padding in Row().
|
|
||||||
// Rows will now be exactly `stride` elements apart. This is used when
|
|
||||||
// writing to the KV cache via MatMul.
|
|
||||||
row_mask_ = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
|
||||||
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
|
|
||||||
HWY_DASSERT(c < cols_);
|
|
||||||
HWY_DASSERT(cols <= cols_ - c);
|
|
||||||
return RowPtr<T>(Row(r) + c, cols, stride_);
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
T* HWY_RESTRICT row0_;
|
|
||||||
size_t stride_;
|
|
||||||
uint32_t step_; // Copy from Allocator::LineBytes() to improve locality.
|
|
||||||
uint32_t cols_;
|
|
||||||
size_t row_mask_;
|
|
||||||
};
|
|
||||||
#pragma pack(pop)
|
|
||||||
|
|
||||||
using RowPtrBF = RowPtr<BF16>;
|
|
||||||
using RowPtrF = RowPtr<float>;
|
|
||||||
using RowPtrD = RowPtr<double>;
|
|
||||||
|
|
||||||
// For C argument to MatMul.
|
|
||||||
template <typename T>
|
|
||||||
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
|
|
||||||
return RowPtr<T>(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Custom deleter for types without a dtor, but where the deallocation requires
|
// Custom deleter for types without a dtor, but where the deallocation requires
|
||||||
// state, e.g. a lambda with *by-value* capture.
|
// state, e.g. a lambda with *by-value* capture.
|
||||||
class DeleterFunc2 {
|
class DeleterFunc2 {
|
||||||
|
|
@@ -420,15 +121,22 @@ class Allocator2 {
   size_t TotalMiB() const { return total_mib_; }
   size_t FreeMiB() const;
 
-  // Returns pointer aligned to `QuantumBytes()`.
+  // Returns byte pointer aligned to `QuantumBytes()`, without calling
+  // constructors nor destructors on deletion. Type-erased so this can be
+  // implemented in `allocator.cc` and called by `MatOwner`.
+  AlignedPtr2<uint8_t[]> AllocBytes(size_t bytes) const;
+
+  // Returns pointer aligned to `QuantumBytes()`, without calling constructors
+  // nor destructors on deletion.
   template <typename T>
   AlignedPtr2<T[]> Alloc(size_t num) const {
     const size_t bytes = num * sizeof(T);
     // Fail if the `bytes = num * sizeof(T)` computation overflowed.
     HWY_ASSERT(bytes / sizeof(T) == num);
 
-    PtrAndDeleter pd = AllocBytes(bytes);
-    return AlignedPtr2<T[]>(static_cast<T*>(pd.p), pd.deleter);
+    AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
+    return AlignedPtr2<T[]>(HWY_RCAST_ALIGNED(T*, p8.release()),
+                            p8.get_deleter());
   }
 
   // Same as Alloc, but calls constructor(s) with `args` and the deleter will
@@ -439,12 +147,12 @@ class Allocator2 {
     // Fail if the `bytes = num * sizeof(T)` computation overflowed.
     HWY_ASSERT(bytes / sizeof(T) == num);
 
-    PtrAndDeleter pd = AllocBytes(bytes);
-    T* p = static_cast<T*>(pd.p);
+    AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
+    T* p = HWY_RCAST_ALIGNED(T*, p8.release());
     for (size_t i = 0; i < num; ++i) {
       new (p + i) T(std::forward<Args>(args)...);
     }
-    return AlignedClassPtr2<T>(p, DeleterDtor2(num, pd.deleter));
+    return AlignedClassPtr2<T>(p, DeleterDtor2(num, p8.get_deleter()));
   }
 
   // Returns whether `BindMemory` can/should be called, i.e. we have page-level
@@ -458,13 +166,6 @@ class Allocator2 {
   bool BindMemory(void* p, size_t bytes, size_t node) const;
 
  private:
-  // Type-erased so this can be implemented in allocator.cc.
-  struct PtrAndDeleter {
-    void* p;
-    DeleterFunc2 deleter;
-  };
-  PtrAndDeleter AllocBytes(size_t bytes) const;
-
   size_t line_bytes_;
   size_t vector_bytes_;
   size_t step_bytes_;
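A short sketch (not from the diff) of the typed wrapper shown above: `Alloc<T>` is now a thin layer over the public `AllocBytes`, so the returned array behaves like an ordinary `unique_ptr` with a type-erased deleter. The element count is a placeholder.

```cpp
// Sketch only: Alloc<T> wraps AllocBytes as in the hunk above.
const gcpp::Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
gcpp::AlignedPtr2<float[]> floats = allocator.Alloc<float>(256);
floats[0] = 1.0f;  // aligned to QuantumBytes(); no ctor/dtor calls on free
```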
@@ -23,7 +23,7 @@
 #include <algorithm>  // std::transform
 #include <string>
 
-#include "compression/io.h"
+#include "compression/io.h"  // Path
 #include "util/basics.h"  // Tristate
 #include "hwy/base.h"  // HWY_ABORT
 
@@ -0,0 +1,100 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "util/mat.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "util/threading_context.h"
+#include "hwy/base.h"
+#include "hwy/per_target.h"  // VectorBytes
+#include "hwy/profiler.h"
+
+namespace gcpp {
+
+void CopyMat(const MatPtr& from, MatPtr& to) {
+  PROFILER_FUNC;
+  HWY_ASSERT(to.Rows() == from.Rows() && to.Cols() == from.Cols());
+  HWY_ASSERT(to.GetType() == from.GetType());
+  if (to.IsPacked() && from.IsPacked()) {
+    HWY_ASSERT(to.PackedBytes() == from.PackedBytes());
+    hwy::CopyBytes(from.Packed(), to.Packed(), to.PackedBytes());
+    return;
+  }
+  const size_t row_bytes = to.Cols() * to.ElementBytes();
+  for (size_t r = 0; r < to.Rows(); ++r) {
+    const uint8_t* from_row = from.RowT<uint8_t>(r);
+    uint8_t* to_row = to.RowT<uint8_t>(r);
+    hwy::CopyBytes(from_row, to_row, row_bytes);
+  }
+}
+
+void ZeroInit(MatPtr& mat) {
+  PROFILER_FUNC;
+  HWY_ASSERT_M(mat.HasPtr(), mat.Name());
+  if (mat.IsPacked()) {
+    hwy::ZeroBytes(mat.Packed(), mat.PackedBytes());
+    return;
+  }
+  const size_t row_bytes = mat.Cols() * mat.ElementBytes();
+  for (size_t r = 0; r < mat.Rows(); ++r) {
+    hwy::ZeroBytes(mat.RowT<uint8_t>(r), row_bytes);
+  }
+}
+
+// Returns `num` rounded up to an odd number of cache lines. This would also
+// prevent 4K aliasing and is coprime with the cache associativity, which
+// might reduce conflict misses, but we instead use `StrideForCyclicOffsets`.
+static size_t RoundUpToOddLines(size_t num, size_t line_bytes,
+                                size_t element_bytes) {
+  HWY_DASSERT(line_bytes >= 32);
+  HWY_DASSERT(line_bytes % element_bytes == 0);
+  const size_t lines = hwy::DivCeil(num * element_bytes, line_bytes);
+  const size_t padded_num = (lines | 1) * line_bytes / element_bytes;
+  HWY_DASSERT(padded_num >= num);
+  return padded_num;
+}
+
+static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
+                     MatPadding padding) {
+  switch (padding) {
+    case MatPadding::kPacked:
+    default:
+      return mat.Cols();
+    case MatPadding::kOdd:
+      return RoundUpToOddLines(mat.Cols(), allocator.LineBytes(),
+                               mat.ElementBytes());
+    case MatPadding::kCyclic:
+      return StrideForCyclicOffsets(
+          mat.Cols(), allocator.QuantumBytes() / mat.ElementBytes());
+  }
+}
+
+void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
+  const Allocator2& allocator = ThreadingContext2::Get().allocator;
+  const size_t stride = Stride(allocator, mat, padding);
+  const size_t num = mat.Rows() * stride;
+  // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
+  // might not be enough, hence add extra. `MatT` is at least one byte, which
+  // is half of BF16, hence adding `VectorBytes` *elements* is enough.
+  const size_t bytes = (num + hwy::VectorBytes()) * mat.ElementBytes();
+  // Allow binding the entire matrix.
+  const size_t padded_bytes =
+      hwy::RoundUpTo(bytes, allocator.QuantumBytes() / mat.ElementBytes());
+  storage_ = allocator.AllocBytes(padded_bytes);
+  mat.SetPtr(storage_.get(), stride);
+}
+}  // namespace gcpp
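A hedged usage sketch (not from the diff) of the new `MatOwner`/`MatPtrT` pairing introduced above: the owner allocates padded storage via the global allocator and the `MatPtr` records the resulting stride, so callers use row accessors rather than assuming packed layout. The tensor name and extents are placeholders.

```cpp
// Sketch only: assumes the MatPtrT/MatOwner interfaces from util/mat.h below.
gcpp::MatPtrT<float> mat("w_q", gcpp::Extents2D(/*rows=*/4, /*cols=*/128));
gcpp::MatOwner owner;
owner.AllocateFor(mat, gcpp::MatPadding::kOdd);  // sets ptr and stride
float* row2 = mat.Row(2);  // row-based access honours the stride
row2[0] = 1.0f;
```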
@ -0,0 +1,532 @@
|
||||||
|
// Copyright 2023 Google LLC
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
// you may not use this file except in compliance with the License.
|
||||||
|
// You may obtain a copy of the License at
|
||||||
|
//
|
||||||
|
// http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
//
|
||||||
|
// Unless required by applicable law or agreed to in writing, software
|
||||||
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
// See the License for the specific language governing permissions and
|
||||||
|
// limitations under the License.
|
||||||
|
|
||||||
|
// Tensor metadata and in-memory representation.
|
||||||
|
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
|
||||||
|
#define THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
|
||||||
|
|
||||||
|
#include <stddef.h>
|
||||||
|
#include <stdint.h>
|
||||||
|
|
||||||
|
#include <random>
|
||||||
|
#include <string>
|
||||||
|
|
||||||
|
// IWYU pragma: begin_exports
|
||||||
|
#include "compression/fields.h"
|
||||||
|
#include "compression/shared.h" // Type
|
||||||
|
#include "gemma/tensor_index.h"
|
||||||
|
#include "util/allocator.h"
|
||||||
|
#include "util/basics.h" // Extents2D
|
||||||
|
// IWYU pragma: end_exports
|
||||||
|
#include "hwy/base.h"
|
||||||
|
|
||||||
|
namespace gcpp {
|
||||||
|
|
||||||
|
// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
|
||||||
|
// or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
|
||||||
|
// to store hetereogeneous tensor references in a vector.
|
||||||
|
//
|
||||||
|
// Copyable, (de)serializable via `fields.h` for `model_store.h`.
|
||||||
|
class MatPtr : public IFields {
|
||||||
|
public:
|
||||||
|
MatPtr() = default;
|
||||||
|
// `name`: see `SetName`. Note that `stride` is initially `cols` and only
|
||||||
|
// differs after deserializing, or calling `SetPtr`.
|
||||||
|
MatPtr(const char* name, Type type, Extents2D extents)
|
||||||
|
: rows_(static_cast<uint32_t>(extents.rows)),
|
||||||
|
cols_(static_cast<uint32_t>(extents.cols)) {
|
||||||
|
SetName(name);
|
||||||
|
SetType(type);
|
||||||
|
SetPtr(nullptr, cols_);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Copying allowed because the metadata is small.
|
||||||
|
MatPtr(const MatPtr& other) = default;
|
||||||
|
MatPtr& operator=(const MatPtr& other) = default;
|
||||||
|
|
||||||
|
virtual ~MatPtr() = default;
|
||||||
|
|
||||||
|
// Only for use by ctor, `AllocateFor` and 'loading' memory-mapped tensors.
|
||||||
|
void SetPtr(void* ptr, size_t stride) {
|
||||||
|
HWY_ASSERT(stride >= Cols());
|
||||||
|
ptr_ = ptr;
|
||||||
|
stride_ = static_cast<uint32_t>(stride);
|
||||||
|
|
||||||
|
// NUQ streams must not be padded because that would change the position of
|
||||||
|
// the group tables.
|
||||||
|
if (type_ == Type::kNUQ) HWY_ASSERT(IsPacked());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool HasPtr() const { return ptr_ != nullptr; }
|
||||||
|
|
||||||
|
bool IsPacked() const { return stride_ == cols_; }
|
||||||
|
|
||||||
|
const void* Packed() const {
|
||||||
|
HWY_DASSERT_M(IsPacked(), name_.c_str());
|
||||||
|
return ptr_;
|
||||||
|
}
|
||||||
|
void* Packed() {
|
||||||
|
HWY_DASSERT_M(IsPacked(), name_.c_str());
|
||||||
|
return ptr_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns size in bytes for purposes of copying/initializing or I/O. Must
|
||||||
|
// only be called if `IsPacked`.
|
||||||
|
size_t PackedBytes() const {
|
||||||
|
HWY_DASSERT_M(IsPacked(), name_.c_str());
|
||||||
|
// num_elements_ already includes the NUQ tables.
|
||||||
|
return num_elements_ * element_bytes_;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Works for any kind of padding.
|
||||||
|
template <typename T>
|
||||||
|
T* MutableRowT(size_t row) const {
|
||||||
|
HWY_DASSERT(row < rows_);
|
||||||
|
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
|
||||||
|
}
|
||||||
|
template <typename T>
|
||||||
|
T* RowT(size_t row) {
|
||||||
|
    HWY_DASSERT(row < rows_);
    return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
  }
  template <typename T>
  const T* RowT(size_t row) const {
    HWY_DASSERT(row < rows_);
    return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
  }

  Type GetType() const { return type_; }
  void SetType(Type type) {
    type_ = type;
    element_bytes_ = static_cast<uint32_t>(hwy::DivCeil(TypeBits(type), 8));
    num_elements_ = static_cast<uint32_t>(ComputeNumElements(type, Extents()));
  }

  bool IsEmpty() const { return rows_ == 0 || cols_ == 0; }
  size_t Rows() const { return rows_; }
  size_t Cols() const { return cols_; }
  Extents2D Extents() const { return Extents2D(rows_, cols_); }

  // Offset by which to advance pointers to the next row.
  size_t Stride() const { return stride_; }

  // For use by `BlobStore`, `CopyMat` and `ZeroInit`.
  size_t ElementBytes() const { return element_bytes_; }

  // Decoded elements should be multiplied by this to restore their original
  // range. This is required because `SfpStream` can only encode a limited
  // range of magnitudes.
  float Scale() const { return scale_; }
  void SetScale(float scale) { scale_ = scale; }

  // Name is a terse identifier. `MakeKey` in `blob_store.cc` requires that it
  // be <= 16 bytes including prefixes/suffixes. The initial name set by the
  // ctor is for the tensor, but `ForEachTensor` in `weights.h` adds a
  // per-layer suffix, and when loading, we call `SetName` with that.
  const char* Name() const override { return name_.c_str(); }
  void SetName(const char* name) {
    name_ = name;
    HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name);
  }

  void VisitFields(IFieldsVisitor& visitor) override {
    // Order determines the order of serialization and must not change.
    visitor(name_);
    visitor(type_);
    visitor(element_bytes_);
    visitor(num_elements_);
    visitor(rows_);
    visitor(cols_);
    visitor(scale_);
    visitor(stride_);
  }

 protected:
  // For initializing `num_elements_`: "elements" are how many objects we
  // actually store in order to represent rows * cols values. For NUQ, this is
  // greater because it includes additional per-group tables. This is the only
  // place where we compute this fixup. Note that elements are independent of
  // padding, which is anyway not supported for NUQ because `compress-inl.h`
  // assumes a contiguous stream for its group indexing.
  static size_t ComputeNumElements(Type type, Extents2D extents) {
    const size_t num_elements = extents.Area();
    if (type == Type::kNUQ) {
      // `CompressedArrayElements` is a wrapper function that has the same
      // effect, but that requires a template argument, not `type`.
      return NuqStream::PackedEnd(num_elements);
    }
    return num_elements;
  }

  std::string name_;  // See `SetName`.
  Type type_;

  // Most members are u32 because that is the preferred type of fields.h.

  // Bytes per element. This is fully determined by `type_`, but stored here
  // for convenience and backward compatibility.
  uint32_t element_bytes_ = 0;
  // Number of elements to store (including NUQ tables but not padding).
  // This is a function of `type_` and `Extents()` and stored for compatibility.
  uint32_t num_elements_ = 0;
  uint32_t rows_ = 0;
  uint32_t cols_ = 0;
  float scale_ = 1.0f;  // multiplier for each value, for MatMul.

  // Non-owning pointer, must not be freed. The underlying memory must outlive
  // this object.
  void* ptr_ = nullptr;  // not serialized

  // Offset by which to advance pointers to the next row, >= `cols_`.
  uint32_t stride_;
};

// Non-type erased version of `MatPtr`. Use this when operating on the values.
template <typename MatT>
class MatPtrT : public MatPtr {
 public:
  // Runtime-specified shape.
  MatPtrT(const char* name, Extents2D extents)
      : MatPtr(name, TypeEnum<MatT>(), extents) {}
  // Take shape from `TensorInfo` to avoid duplicating it in the caller.
  MatPtrT(const char* name, const TensorInfo* tensor)
      : MatPtrT<MatT>(name, ExtentsFromInfo(tensor)) {}
  // Find `TensorInfo` by name in `TensorIndex`.
  MatPtrT(const char* name, const TensorIndex& tensor_index)
      : MatPtrT<MatT>(name, tensor_index.FindName(name)) {}

  // Copying allowed because the metadata is small.
  MatPtrT(const MatPtr& other) : MatPtr(other) {}
  MatPtrT& operator=(const MatPtr& other) {
    MatPtr::operator=(other);
    return *this;
  }
  MatPtrT(const MatPtrT& other) = default;
  MatPtrT& operator=(const MatPtrT& other) = default;

  // Returns the entire tensor for use by `backprop/*`. Verifies layout is
  // `kPacked`. Preferably call `Row` instead, which works for either layout.
  MatT* Packed() {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    return HWY_RCAST_ALIGNED(MatT*, ptr_);
  }
  const MatT* Packed() const {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    return HWY_RCAST_ALIGNED(const MatT*, ptr_);
  }
  // As `Packed()`, plus checks the scale is 1.0 because callers will ignore it.
  // This is typically used for `MatMul` bias vectors and norm weights.
  const MatT* PackedScale1() const {
    HWY_DASSERT(Scale() == 1.0f);
    return Packed();
  }

  const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
  MatT* Row(size_t row) { return this->RowT<MatT>(row); }

  // For `compress-inl.h` functions, which assume contiguous streams and thus
  // require packed layout.
  PackedSpan<const MatT> Span() const {
    HWY_ASSERT(IsPacked());
    return MakeConstSpan(Row(0), num_elements_);
  }
  PackedSpan<MatT> Span() {
    HWY_ASSERT(IsPacked());
    return MakeSpan(Row(0), num_elements_);
  }
};

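For orientation, a minimal usage sketch of the row-based access this interface provides; the name and shape are illustrative, and `MatOwner`/`MatPadding` are declared further below in this header:

  // Sketch only: MatPtrT holds metadata; a separate MatOwner allocates storage.
  MatPtrT<float> mat("w_test", Extents2D(4, 128));
  MatOwner owner;
  owner.AllocateFor(mat, MatPadding::kPacked);  // sets pointer and stride
  for (size_t r = 0; r < mat.Rows(); ++r) {
    float* HWY_RESTRICT row = mat.Row(r);  // valid for any padding mode
    for (size_t c = 0; c < mat.Cols(); ++c) row[c] = 0.0f;
  }
  float* all = mat.Packed();  // contiguous view, only valid for kPacked
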
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`.
template <class Func, typename... Args>
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
                            Args&&... args) {
  HWY_ASSERT(base != nullptr);
  if (type == Type::kF32) {
    return func(dynamic_cast<MatPtrT<float>*>(base),
                std::forward<Args>(args)...);
  } else if (type == Type::kBF16) {
    return func(dynamic_cast<MatPtrT<BF16>*>(base),
                std::forward<Args>(args)...);
  } else if (type == Type::kSFP) {
    return func(dynamic_cast<MatPtrT<SfpStream>*>(base),
                std::forward<Args>(args)...);
  } else if (type == Type::kNUQ) {
    return func(dynamic_cast<MatPtrT<NuqStream>*>(base),
                std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Type %d unknown.", static_cast<int>(type));
  }
}

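A hedged sketch of a `CallUpcasted` call site; the helper below is hypothetical (not part of this header) and assumes `mat.GetType()` matches the dynamic type of `mat`:

  // Prints the per-element byte count via the typed view; each branch of
  // CallUpcasted instantiates the lambda for one MatPtrT<T>.
  void PrintElementSize(MatPtr& mat) {
    CallUpcasted(mat.GetType(), &mat, [](auto* upcasted) {
      fprintf(stderr, "%s: %zu byte(s) per stored element\n", upcasted->Name(),
              sizeof(*upcasted->Row(0)));
    });
  }
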
void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);

template <typename T>
void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
  std::normal_distribution<T> dist(0.0, stddev);
  for (size_t r = 0; r < x.Rows(); ++r) {
    T* row = x.Row(r);
    for (size_t c = 0; c < x.Cols(); ++c) {
      row[c] = dist(gen);
    }
  }
}

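A small sketch of how a test might combine these helpers; it assumes `MatStorageT`/`MakePacked`, declared below, and the values are illustrative:

  std::mt19937 gen(/*seed=*/123);
  MatStorageT<float> w = MakePacked<float>("w_rand", /*rows=*/4, /*cols=*/16);
  RandInit(w, /*stddev=*/0.02f, gen);  // fills row by row via Row()
  ZeroInit(w);                         // type-erased reset via MatPtr&
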
// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
// `Allocator2::ShouldBind()`, `Allocator2::QuantumBytes()` is typically 4KiB.
// To avoid remote accesses, we would thus pad each row to that, which results
// in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent
// that by pulling rows forward by a cyclic offset, which is still a multiple
// of the cache line size. This requires an additional
// `Allocator2::QuantumBytes()` of padding after also rounding up to that,
// which considerably increases size for tall and skinny tensors.
static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
  return hwy::RoundUpTo(cols, quantum) + quantum;
}
// Constexpr version (upper bound) for allocating storage in MatMul.
template <typename T>
constexpr size_t MaxStrideForCyclicOffsets(size_t cols) {
  constexpr size_t quantum = Allocator2::MaxQuantum<T>();
  return hwy::RoundUpTo(cols, quantum) + quantum;
}

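A worked example of the padding this implies, assuming a 4 KiB quantum and float elements, i.e. quantum = 4096 / sizeof(float) = 1024 elements:

  // StrideForCyclicOffsets(2048, 1024) = RoundUpTo(2048, 1024) + 1024 = 3072
  // StrideForCyclicOffsets( 100, 1024) = RoundUpTo( 100, 1024) + 1024 = 2048
  // The second case shows the overhead for tall, skinny tensors.
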
// Our tensors are always row-major. This enum indicates how much (if any)
// padding comes after each row.
enum class MatPadding {
  // None, stride == cols. `compress-inl.h` requires this layout because its
  // interface assumes a continuous 1D array, without awareness of rows. Note
  // that tensors which were written via `compress-inl.h` (i.e. most in
  // `BlobStore`) are not padded, which also extends to memory-mapped tensors.
  // However, `BlobStore` is able to insert padding via row-wise I/O when
  // reading from disk via `Mode::kRead`.
  //
  // `backprop/*` also requires this layout because it indexes directly into
  // the storage instead of calling `Row()`.
  kPacked,
  // Enough to round up to an odd number of cache lines, which can reduce
  // cache conflict misses or 4K aliasing.
  kOdd,
  // Enough to enable the "cyclic offsets" optimization for `MatMul`.
  kCyclic,
};

// Type-erased, allows storing `AlignedPtr2<T[]>` for various T in the same
// vector.
class MatOwner {
 public:
  MatOwner() = default;
  // Allow move for `MatStorageT`.
  MatOwner(MatOwner&&) = default;
  MatOwner& operator=(MatOwner&&) = default;

  // Allocates the type/extents indicated by `mat` and sets its pointer.
  void AllocateFor(MatPtr& mat, MatPadding padding);

 private:
  AlignedPtr2<uint8_t[]> storage_;
};

// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
// tests to allocate and access tensors of a known type. By contrast, the
// heterogeneous model weights are owned by vectors of `MatOwner`.
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
 public:
  MatStorageT(const char* name, Extents2D extents, MatPadding padding)
      : MatPtrT<MatT>(name, extents) {
    owner_.AllocateFor(*this, padding);
  }
  ~MatStorageT() = default;

  // Allow move for backprop/activations.
  MatStorageT(MatStorageT&&) = default;
  MatStorageT& operator=(MatStorageT&&) = default;

 private:
  MatOwner owner_;
};

// Helper factory function for use by `backprop/` to avoid specifying the
// `MatPadding` argument everywhere.
template <typename T>
MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
  return MatStorageT<T>(name, Extents2D(rows, cols), MatPadding::kPacked);
}

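This is the factory the backprop changes in this commit refer to; a brief sketch of the intended call pattern (the name and shape are illustrative):

  // Packed layout, so callers may index the entire tensor via Packed().
  MatStorageT<float> grad = MakePacked<float>("grad_w", /*rows=*/8, /*cols=*/256);
  float* HWY_RESTRICT all = grad.Packed();
  for (size_t i = 0; i < grad.Rows() * grad.Cols(); ++i) all[i] = 0.0f;
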
// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
// seekable (non-NUQ) T. This has less metadata, but support for cyclic offsets.
#pragma pack(push, 1)  // power of two size
template <typename T>
class RowPtr {
 public:
  RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols,
         size_t stride)
      : row0_(row0),
        stride_(stride),
        row_mask_(
            static_cast<uint32_t>(allocator.QuantumStepMask() & 0xFFFFFFFFu)),
        cols_(static_cast<uint32_t>(cols)),
        step_bytes_(static_cast<uint32_t>(allocator.StepBytes())),
        quantum_bytes_(allocator.QuantumBytes()) {
    HWY_DASSERT(stride >= cols);
    HWY_DASSERT(row_mask_ != ~uint32_t{0});
    if (stride < StrideForCyclicOffsets(cols, quantum_bytes_ / sizeof(T))) {
      row_mask_ = 0;
      if constexpr (HWY_IS_DEBUG_BUILD) {
        static bool once;
        if (stride != cols && !once) {
          once = true;
          HWY_WARN(
              "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
              "T=%zu; this forces us to disable cyclic offsets.",
              stride, cols, sizeof(T));
        }
      }
    }
  }

  RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols)
      : RowPtr(allocator, row0, cols, cols) {}

  T* HWY_RESTRICT Row(size_t r) const {
    // How much of the previous row's padding to consume.
    const size_t pad_bytes = (r & row_mask_) * step_bytes_;
    HWY_DASSERT(pad_bytes < static_cast<size_t>(quantum_bytes_));
    return row0_ + stride_ * r - pad_bytes;
  }
  size_t Cols() const { return static_cast<size_t>(cols_); }

  size_t Stride() const { return stride_; }
  void SetStride(size_t stride) {
    HWY_DASSERT(stride >= Cols());
    stride_ = stride;
    // The caller might not have padded enough, so disable the padding in Row().
    // Rows will now be exactly `stride` elements apart. This is used when
    // writing to the KV cache via MatMul.
    row_mask_ = 0;
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
  RowPtr<T> View(size_t r, size_t c, size_t cols) const {
    HWY_DASSERT(c < Cols());
    HWY_DASSERT(cols <= Cols() - c);
    return RowPtr<T>(Row(r) + c, cols, stride_, row_mask_, step_bytes_,
                     quantum_bytes_);
  }

 private:
  // For `View()`.
  RowPtr(T* new_row0, size_t new_cols, size_t stride, uint32_t row_mask,
         uint32_t step_bytes, uint32_t quantum_bytes)
      : row0_(new_row0),
        stride_(stride),
        row_mask_(row_mask),
        cols_(new_cols),
        step_bytes_(step_bytes),
        quantum_bytes_(quantum_bytes) {}

  T* HWY_RESTRICT row0_;
  size_t stride_;
  uint32_t row_mask_;
  uint32_t cols_;
  uint32_t step_bytes_;
  uint32_t quantum_bytes_;
};
#pragma pack(pop)

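A brief sketch of the view and stride behavior; `allocator` and `buffer` are assumed to exist and the sizes are illustrative:

  // Wrap an existing padded buffer. View() narrows to a column range without
  // copying; SetStride() disables cyclic offsets, e.g. for KV-cache writes.
  RowPtr<float> rows(allocator, buffer, /*cols=*/2048, /*stride=*/3072);
  RowPtr<float> right_half = rows.View(/*r=*/0, /*c=*/1024, /*cols=*/1024);
  rows.SetStride(2048);  // rows are now exactly stride elements apart
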
using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;

// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory. Unlike `MatPtr`, this lacks metadata.
// TODO: replace with `MatStorageT`.
template <typename T>
class RowVectorBatch {
 public:
  // Default ctor for Activations ctor.
  RowVectorBatch() = default;
  // Main ctor, called from Activations::Allocate. If `stride` = 0, the
  // default, we default to tightly packed rows (`stride = cols`).
  // WARNING: not all call sites support `stride` != cols.
  // TODO: once they do, remove stride and behave like AllocateAlignedRows here.
  RowVectorBatch(const Allocator2& allocator, Extents2D extents,
                 size_t stride = 0)
      : extents_(extents) {
    if (stride == 0) {
      stride_ = extents_.cols;
    } else {
      HWY_ASSERT(stride >= extents_.cols);
      stride_ = stride;
    }
    // Allow binding the entire matrix.
    const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
                                         allocator.QuantumBytes() / sizeof(T));
    mem_ = allocator.Alloc<T>(padded);
  }

  // Move-only
  RowVectorBatch(RowVectorBatch&) noexcept = delete;
  RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
  RowVectorBatch(RowVectorBatch&&) noexcept = default;
  RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;

  size_t BatchSize() const { return extents_.rows; }
  size_t Cols() const { return extents_.cols; }
  size_t Stride() const { return stride_; }
  Extents2D Extents() const { return extents_; }

  // Returns the given row vector of length `Cols()`.
  T* Batch(size_t batch_idx) {
    HWY_DASSERT(batch_idx < BatchSize());
    return mem_.get() + batch_idx * stride_;
  }
  const T* Batch(size_t batch_idx) const {
    HWY_DASSERT(batch_idx < BatchSize());
    return mem_.get() + batch_idx * stride_;
  }

  // For MatMul or other operations that process the entire batch at once.
  // TODO: remove once we only use Mat.
  T* All() { return mem_.get(); }
  const T* Const() const { return mem_.get(); }
  size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }

 private:
  AlignedPtr2<T[]> mem_;
  Extents2D extents_;
  size_t stride_;
};

template <typename T>
RowPtr<T> RowPtrFromBatch(const Allocator2& allocator,
                          RowVectorBatch<T>& row_vectors) {
  return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
                   row_vectors.Stride());
}

template <typename T>
RowVectorBatch<T> AllocateAlignedRows(const Allocator2& allocator,
                                      Extents2D extents) {
  return RowVectorBatch<T>(
      allocator, extents,
      StrideForCyclicOffsets(extents.cols,
                             allocator.QuantumBytes() / sizeof(T)));
}

}  // namespace gcpp
#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_

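A sketch of how activation code might tie these pieces together; `allocator` is an assumed `const Allocator2&` (e.g. obtained from the threading context):

  // Batch of 4 rows of 2048 floats, padded for cyclic offsets, then wrapped
  // in the lightweight RowPtr view that MatMul consumes.
  RowVectorBatch<float> x =
      AllocateAlignedRows<float>(allocator, Extents2D(4, 2048));
  RowPtr<float> x_rows = RowPtrFromBatch(allocator, x);
  float* HWY_RESTRICT row0 = x_rows.Row(0);
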
@@ -13,12 +13,14 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "util/threading.h"
+#include "util/threading.h"  // NOT threading_context..
+// to ensure there is no deadlock.
 
 #include <stdio.h>
 
 #include <algorithm>  // std::sort
 #include <atomic>
+#include <memory>
 #include <optional>
 #include <vector>
 
@@ -69,7 +71,7 @@ class Pinning {
     const int bytes_written =
         snprintf(buf, sizeof(buf), "P%zu X%02zu C%03d", pkg_idx, cluster_idx,
                  static_cast<int>(task));
-    HWY_ASSERT(bytes_written < sizeof(buf));
+    HWY_ASSERT(bytes_written < static_cast<int>(sizeof(buf)));
     hwy::SetThreadName(buf, 0);  // does not support varargs
 
     if (HWY_LIKELY(want_pin_)) {
@@ -107,16 +109,16 @@ static Pinning& GetPinning() {
   return pinning;
 }
 
-static PoolPtr MakePool(size_t num_workers,
+static PoolPtr MakePool(const Allocator2& allocator, size_t num_workers,
                         std::optional<size_t> node = std::nullopt) {
   // `ThreadPool` expects the number of threads to create, which is one less
   // than the number of workers, but avoid underflow if zero.
   const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
-  PoolPtr ptr = Allocator::AllocClasses<hwy::ThreadPool>(1, num_threads);
+  PoolPtr ptr = allocator.AllocClasses<hwy::ThreadPool>(1, num_threads);
   const size_t bytes =
-      hwy::RoundUpTo(sizeof(hwy::ThreadPool), Allocator::QuantumBytes());
-  if (node.has_value() && Allocator::ShouldBind()) {
-    Allocator::BindMemory(ptr.get(), bytes, node.value());
+      hwy::RoundUpTo(sizeof(hwy::ThreadPool), allocator.QuantumBytes());
+  if (node.has_value() && allocator.ShouldBind()) {
+    allocator.BindMemory(ptr.get(), bytes, node.value());
   }
   return ptr;
 }
@@ -133,21 +135,21 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) {
   return max;
 }
 
-NestedPools::NestedPools(const BoundedTopology& topology, size_t max_threads,
+NestedPools::NestedPools(const BoundedTopology& topology,
+                         const Allocator2& allocator, size_t max_threads,
                          Tristate pin) {
   GetPinning().SetPolicy(pin);
   packages_.resize(topology.NumPackages());
-  all_packages_ = MakePool(packages_.size());
+  all_packages_ = MakePool(allocator, packages_.size());
   const size_t max_workers_per_package =
       DivideMaxAcross(max_threads, packages_.size());
   // Each worker in all_packages_, including the main thread, will be the
-  // calling thread of an all_clusters[0].Run, and hence pinned to one of the
+  // calling thread of an all_clusters->Run, and hence pinned to one of the
   // `cluster.lps` if `pin`.
-  all_packages_[0].Run(
-      0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) {
+  all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) {
     HWY_ASSERT(pkg_idx == thread);  // each thread has one task
     packages_[pkg_idx] =
-        Package(topology, pkg_idx, max_workers_per_package);
+        Package(topology, allocator, pkg_idx, max_workers_per_package);
   });
 
   all_pinned_ = GetPinning().AllPinned(&pin_string_);
@@ -172,28 +174,29 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
   return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
 }
 
-NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx,
+NestedPools::Package::Package(const BoundedTopology& topology,
+                              const Allocator2& allocator, size_t pkg_idx,
                               size_t max_workers_per_package) {
   // Pre-allocate because elements are set concurrently.
   clusters_.resize(topology.NumClusters(pkg_idx));
   const size_t max_workers_per_cluster =
       DivideMaxAcross(max_workers_per_package, clusters_.size());
 
-  all_clusters_ =
-      MakePool(clusters_.size(), topology.GetCluster(pkg_idx, 0).Node());
+  all_clusters_ = MakePool(allocator, clusters_.size(),
+                           topology.GetCluster(pkg_idx, 0).Node());
   // Parallel so we also pin the calling worker in `all_clusters` to
   // `cluster.lps`.
-  all_clusters_[0].Run(
-      0, all_clusters_[0].NumWorkers(), [&](size_t cluster_idx, size_t thread) {
+  all_clusters_->Run(
+      0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
        HWY_ASSERT(cluster_idx == thread);  // each thread has one task
        const BoundedTopology::Cluster& cluster =
            topology.GetCluster(pkg_idx, cluster_idx);
-       clusters_[cluster_idx] =
-           MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster),
-                    cluster.Node());
+       clusters_[cluster_idx] = MakePool(
+           allocator, CapIfNonZero(cluster.Size(), max_workers_per_cluster),
+           cluster.Node());
        // Pin workers AND the calling thread from `all_clusters`.
        GetPinning().MaybePin(pkg_idx, cluster_idx, cluster,
-                             clusters_[cluster_idx][0]);
+                             *clusters_[cluster_idx]);
       });
 }

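The constructor change above means callers now pass the allocator explicitly instead of relying on static `Allocator` state; a hypothetical call site under the new interface:

  void UseNestedPools(const BoundedTopology& topology,
                      const Allocator2& allocator) {
    NestedPools pools(topology, allocator);  // max_threads/pin keep defaults
    pools.AllPackages().Run(0, pools.NumPackages(),
                            [&](uint64_t pkg_idx, size_t thread) {
                              // per-package work here
                            });
  }
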
@@ -23,6 +23,7 @@
 
 // IWYU pragma: begin_exports
 #include "util/allocator.h"
+#include "util/args.h"
 #include "util/basics.h"  // Tristate
 #include "util/topology.h"
 #include "hwy/base.h"  // HWY_ASSERT
@@ -37,7 +38,7 @@ namespace gcpp {
 
 // Page-aligned on NUMA systems so we can bind to a NUMA node. This also allows
 // moving because it is a typedef to `std::unique_ptr`.
-using PoolPtr = AlignedClassPtr<hwy::ThreadPool>;
+using PoolPtr = AlignedClassPtr2<hwy::ThreadPool>;
 
 // Creates a hierarchy of thread pools according to `BoundedTopology`: one with
 // a thread per enabled package; for each of those, one with a thread per
@@ -73,10 +74,8 @@ class NestedPools {
   // would cause huge slowdowns when spinning, the `BoundedSlice` arguments
   // only impose upper bounds on the number of detected packages and clusters
   // rather than defining the actual number of threads.
-  //
-  // Caller must have called `Allocator::Init` before this.
-  NestedPools(const BoundedTopology& topology, size_t max_threads = 0,
-              Tristate pin = Tristate::kDefault);
+  NestedPools(const BoundedTopology& topology, const Allocator2& allocator,
+              size_t max_threads = 0, Tristate pin = Tristate::kDefault);
 
   bool AllPinned() const { return all_pinned_; }
 
@@ -103,7 +102,7 @@ class NestedPools {
   }
 
   size_t NumPackages() const { return packages_.size(); }
-  hwy::ThreadPool& AllPackages() { return all_packages_[0]; }
+  hwy::ThreadPool& AllPackages() { return *all_packages_; }
   hwy::ThreadPool& AllClusters(size_t pkg_idx) {
     HWY_DASSERT(pkg_idx < NumPackages());
     return packages_[pkg_idx].AllClusters();
@@ -149,36 +148,36 @@ class NestedPools {
   class Package {
    public:
     Package() = default;  // for vector
-    Package(const BoundedTopology& topology, size_t pkg_idx,
-            size_t max_workers_per_package);
+    Package(const BoundedTopology& topology, const Allocator2& allocator,
+            size_t pkg_idx, size_t max_workers_per_package);
 
     size_t NumClusters() const { return clusters_.size(); }
     size_t MaxWorkersPerCluster() const {
       size_t max_workers_per_cluster = 0;
       for (const PoolPtr& cluster : clusters_) {
         max_workers_per_cluster =
-            HWY_MAX(max_workers_per_cluster, cluster[0].NumWorkers());
+            HWY_MAX(max_workers_per_cluster, cluster->NumWorkers());
       }
       return max_workers_per_cluster;
     }
     size_t TotalWorkers() const {
       size_t total_workers = 0;
       for (const PoolPtr& cluster : clusters_) {
-        total_workers += cluster[0].NumWorkers();
+        total_workers += cluster->NumWorkers();
      }
      return total_workers;
    }
 
-    hwy::ThreadPool& AllClusters() { return all_clusters_[0]; }
+    hwy::ThreadPool& AllClusters() { return *all_clusters_; }
     hwy::ThreadPool& Cluster(size_t cluster_idx) {
       HWY_DASSERT(cluster_idx < clusters_.size());
-      return clusters_[cluster_idx][0];
+      return *clusters_[cluster_idx];
     }
 
     void SetWaitMode(hwy::PoolWaitMode wait_mode) {
-      all_clusters_[0].SetWaitMode(wait_mode);
+      all_clusters_->SetWaitMode(wait_mode);
       for (PoolPtr& cluster : clusters_) {
-        cluster[0].SetWaitMode(wait_mode);
+        cluster->SetWaitMode(wait_mode);
      }
    }
 
@@ -188,7 +187,7 @@ class NestedPools {
   };  // Package
 
   void SetWaitMode(hwy::PoolWaitMode wait_mode) {
-    all_packages_[0].SetWaitMode(wait_mode);
+    all_packages_->SetWaitMode(wait_mode);
     for (Package& package : packages_) {
       package.SetWaitMode(wait_mode);
     }

@@ -33,6 +33,13 @@ static std::mutex s_ctx_mutex;
   s_ctx_mutex.unlock();
 }
 
+/*static*/ bool ThreadingContext2::IsInitialized() {
+  s_ctx_mutex.lock();
+  const bool initialized = !!s_ctx;
+  s_ctx_mutex.unlock();
+  return initialized;
+}
+
 /*static*/ ThreadingContext2& ThreadingContext2::Get() {
   // We do not bother with double-checked locking because it requires an
   // atomic pointer, but we prefer to use unique_ptr for simplicity. Also,

@@ -98,6 +98,10 @@ class ThreadingContext2 {
   // is expected to be called early in the program, before threading starts.
   static void SetArgs(const ThreadingArgs& args);
 
+  // Returns whether `Get()` has already been called, typically used to avoid
+  // calling `SetArgs` after that, because it would assert.
+  static bool IsInitialized();
+
   // Returns a reference to the singleton after initializing it if necessary.
   // When initializing, uses the args passed to `SetArgs`, or defaults.
   //

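A sketch of the guard pattern `IsInitialized` enables; the `ThreadingArgs` fields are not shown in this diff, so default construction is only an assumption:

  // Configure threading once, before anything calls Get(); afterwards SetArgs
  // would assert, hence the IsInitialized() check.
  ThreadingArgs args;  // assumed default-constructible; set fields as needed
  if (!ThreadingContext2::IsInitialized()) {
    ThreadingContext2::SetArgs(args);
  }
  NestedPools& pools = ThreadingContext2::Get().pools;
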
@@ -13,8 +13,6 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-#include "util/threading.h"
-
 #include <stddef.h>
 #include <stdio.h>
 
@@ -22,9 +20,9 @@
 
 #include "gmock/gmock.h"
 #include "gtest/gtest.h"
-#include "util/allocator.h"
 #include "util/basics.h"
-#include "hwy/aligned_allocator.h"
+#include "util/threading_context.h"
+#include "hwy/aligned_allocator.h"  // Span
 #include "hwy/auto_tune.h"
 #include "hwy/base.h"  // HWY_ASSERT
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -385,9 +383,7 @@ TEST(ThreadingTest, BenchJoin) {
     }
   };
 
-  BoundedTopology topology;
-  Allocator::Init(topology, true);
-  NestedPools pools(topology);
+  NestedPools& pools = ThreadingContext2::Get().pools;
   // Use last package because the main thread has been pinned to it.
   const size_t pkg_idx = pools.NumPackages() - 1;
 