mirror of https://github.com/google/gemma.cpp.git
Major refactor of allocator/args:
use new ThreadingContext2 instead of monostate/init in each frontend
Add ThreadingArgs (replaces AppArgs)
backprop: use Packed() accessor and MakePacked factory and row-based access to allow for stride
compress_weights: remove, moving to py-only exporter instead
Move MatPtr to mat.h and revise interface:
- Generic MatOwner
- rename accessors to Packed*
- support stride/row accessors, fix RowPtr stride
Add TypeBits(Type)
Move GenerateMat to test_util-inl for sharing between matmul test/bench
Move internal init to gemma.cc to avoid duplication
Rename GemmaEnv model_ to gemma_ for disambiguating vs upcoming ModelStorage
Remove --compressed_weights, use --weights instead.
tensor_index: add ExtentsFromInfo and TensorIndexLLM/Img
Allocator: use normal unique_ptr for AllocBytes so users can call directly
threading: use -> because AlignedPtr no longer assumes arrays

PiperOrigin-RevId: 745918637
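Much of the diff below mechanically applies the revised mat.h interface named in the message above. As a hedged illustration of that pattern, the sketch that follows mirrors the shape visible in the hunks (the MakePacked<T> factory, the Packed() accessor, and free-function ZeroInit/CopyMat); the signatures are inferred from this diff rather than quoted from util/mat.h, so treat it as an assumption-laden sketch, not authoritative API documentation.

```cpp
// Usage sketch only: MakePacked/Packed()/ZeroInit/CopyMat appear here with the
// signatures they seem to have in the hunks below, not as documented API.
#include <cstddef>

#include "util/mat.h"  // MakePacked, MatStorageT

namespace gcpp {

// Allocates two packed row-major matrices and touches them via the renamed
// accessors, mirroring how ForwardLayer and the backprop tests now do it.
inline void ExamplePackedUsage(size_t seq_len, size_t model_dim) {
  // Previously: MatStorageT<float> x("x", seq_len, model_dim);
  auto x = MakePacked<float>("x", seq_len, model_dim);
  auto y = MakePacked<float>("y", seq_len, model_dim);

  ZeroInit(x);    // free function; replaces the old x.ZeroInit() member
  CopyMat(x, y);  // free function CopyMat(src, dst); replaces memcpy on .data()

  // Raw pointer access is now spelled Packed() instead of data().
  float* rows = y.Packed();
  rows[0] = 1.0f;  // packed layout: row r starts at rows + r * model_dim
}

}  // namespace gcpp
```

This is why the test hunks below swap `memcpy(dst.data(), src.data(), ...)` and member `ZeroInit()` calls for `CopyMat(src, dst)` and `ZeroInit(mat)`, and rename every `.data()` access to `.Packed()`.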
This commit is contained in:
parent bef91a3f03
commit 8532da47f7

BUILD.bazel (220 changed lines)
@ -19,7 +19,10 @@ license(
|
|||
# Dual-licensed Apache 2 and 3-clause BSD.
|
||||
licenses(["notice"])
|
||||
|
||||
exports_files(["LICENSE"])
|
||||
exports_files([
|
||||
"LICENSE",
|
||||
".github/workflows/build.yml",
|
||||
])
|
||||
|
||||
cc_library(
|
||||
name = "basics",
|
||||
|
|
@ -29,6 +32,16 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:io", # Path
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
# Split from :threading to break a circular dependency with :allocator.
|
||||
cc_library(
|
||||
name = "topology",
|
||||
|
|
@ -59,6 +72,7 @@ cc_library(
|
|||
hdrs = ["util/threading.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":args",
|
||||
":basics",
|
||||
":topology",
|
||||
# Placeholder for container detection, do not remove
|
||||
|
|
@ -68,14 +82,26 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "threading_context",
|
||||
srcs = ["util/threading_context.cc"],
|
||||
hdrs = ["util/threading_context.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":args",
|
||||
":basics",
|
||||
":threading",
|
||||
":topology",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "threading_test",
|
||||
srcs = ["util/threading_test.cc"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":threading",
|
||||
"@googletest//:gtest_main",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:auto_tune",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
|
|
@ -97,6 +123,65 @@ cc_library(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = [
|
||||
"gemma/common.cc",
|
||||
"gemma/configs.cc",
|
||||
"gemma/tensor_index.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/common.h",
|
||||
"gemma/configs.h",
|
||||
"gemma/tensor_index.h",
|
||||
],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:fields",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy", # base.h
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "configs_test",
|
||||
srcs = ["gemma/configs_test.cc"],
|
||||
deps = [
|
||||
":common",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensor_index_test",
|
||||
srcs = ["gemma/tensor_index_test.cc"],
|
||||
deps = [
|
||||
":basics",
|
||||
":common",
|
||||
":weights",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "mat",
|
||||
srcs = ["util/mat.cc"],
|
||||
hdrs = ["util/mat.h"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":common",
|
||||
":threading_context",
|
||||
"//compression:fields",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
"@highway//:profiler",
|
||||
],
|
||||
)
|
||||
|
||||
# For building all tests in one command, so we can test several.
|
||||
test_suite(
|
||||
name = "ops_tests",
|
||||
|
|
@ -123,8 +208,9 @@ cc_library(
|
|||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":mat",
|
||||
":threading",
|
||||
":topology",
|
||||
":threading_context",
|
||||
"//compression:compress",
|
||||
"@highway//:algo",
|
||||
"@highway//:bit_set",
|
||||
|
|
@ -148,10 +234,9 @@ cc_test(
|
|||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":app",
|
||||
":ops",
|
||||
":test_util",
|
||||
":threading",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"//compression:test_util",
|
||||
|
|
@ -174,13 +259,13 @@ cc_test(
|
|||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":app",
|
||||
":basics",
|
||||
":common",
|
||||
":mat",
|
||||
":ops",
|
||||
":test_util",
|
||||
":threading",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark", #buildcleaner: keep
|
||||
|
|
@ -196,6 +281,7 @@ cc_test(
|
|||
# for test_suite.
|
||||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":mat",
|
||||
":ops",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
|
|
@ -214,12 +300,13 @@ cc_test(
|
|||
# for test_suite.
|
||||
tags = ["ops_tests"],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":mat",
|
||||
":ops",
|
||||
":threading",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"//compression:test_util",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:thread_pool",
|
||||
|
|
@ -238,12 +325,12 @@ cc_test(
|
|||
"ops_tests", # for test_suite.
|
||||
],
|
||||
deps = [
|
||||
":allocator",
|
||||
":basics",
|
||||
":ops",
|
||||
":threading",
|
||||
":threading_context",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:compress",
|
||||
"//compression:test_util",
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:nanobenchmark",
|
||||
|
|
@ -252,55 +339,13 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "common",
|
||||
srcs = [
|
||||
"gemma/common.cc",
|
||||
"gemma/configs.cc",
|
||||
"gemma/tensor_index.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"gemma/common.h",
|
||||
"gemma/configs.h",
|
||||
"gemma/tensor_index.h",
|
||||
],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:fields",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy", # base.h
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "configs_test",
|
||||
srcs = ["gemma/configs_test.cc"],
|
||||
deps = [
|
||||
":common",
|
||||
"@googletest//:gtest_main",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "tensor_index_test",
|
||||
srcs = ["gemma/tensor_index_test.cc"],
|
||||
deps = [
|
||||
":basics",
|
||||
":common",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "weights",
|
||||
srcs = ["gemma/weights.cc"],
|
||||
hdrs = ["gemma/weights.h"],
|
||||
deps = [
|
||||
":common",
|
||||
":mat",
|
||||
"//compression:blob_store",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
|
|
@ -361,16 +406,17 @@ cc_library(
|
|||
":basics",
|
||||
":common",
|
||||
":ops",
|
||||
":mat",
|
||||
":tokenizer",
|
||||
":kv_cache",
|
||||
":weights",
|
||||
":threading",
|
||||
"//compression:compress",
|
||||
":threading_context",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
"@highway//:bit_set",
|
||||
"@highway//:nanobenchmark", # timer
|
||||
"@highway//:profiler",
|
||||
"@highway//:thread_pool",
|
||||
|
|
@ -390,25 +436,14 @@ cc_library(
|
|||
)
|
||||
|
||||
cc_library(
|
||||
name = "args",
|
||||
hdrs = ["util/args.h"],
|
||||
deps = [
|
||||
":basics",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "app",
|
||||
hdrs = ["util/app.h"],
|
||||
name = "gemma_args",
|
||||
hdrs = ["gemma/gemma_args.h"],
|
||||
deps = [
|
||||
":args",
|
||||
":basics",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":threading",
|
||||
"//compression:io",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
|
|
@ -420,20 +455,15 @@ cc_library(
|
|||
srcs = ["evals/benchmark_helper.cc"],
|
||||
hdrs = ["evals/benchmark_helper.h"],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":common",
|
||||
":cross_entropy",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":kv_cache",
|
||||
":ops",
|
||||
":threading",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
":threading_context",
|
||||
"@google_benchmark//:benchmark",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
"@highway//:nanobenchmark",
|
||||
"@highway//:topology",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
@ -451,7 +481,7 @@ cc_test(
|
|||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
],
|
||||
|
|
@ -470,8 +500,7 @@ cc_test(
|
|||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_lib",
|
||||
":tokenizer",
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
],
|
||||
|
|
@ -481,14 +510,13 @@ cc_binary(
|
|||
name = "gemma",
|
||||
srcs = ["gemma/run.cc"],
|
||||
deps = [
|
||||
":app",
|
||||
":args",
|
||||
":benchmark_helper",
|
||||
":common",
|
||||
":gemma_args",
|
||||
":gemma_lib",
|
||||
":ops",
|
||||
":threading",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
":threading_context",
|
||||
"//compression:sfp",
|
||||
"//paligemma:image",
|
||||
"@highway//:hwy",
|
||||
|
|
@ -594,10 +622,10 @@ cc_library(
|
|||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":mat",
|
||||
":ops",
|
||||
":prompt",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"@highway//:dot",
|
||||
"@highway//:hwy", # base.h
|
||||
"@highway//:thread_pool",
|
||||
|
|
@ -614,9 +642,9 @@ cc_library(
|
|||
],
|
||||
deps = [
|
||||
":common",
|
||||
":mat",
|
||||
":prompt",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
)
|
||||
|
|
@ -631,11 +659,11 @@ cc_test(
|
|||
deps = [
|
||||
":backprop_scalar",
|
||||
":common",
|
||||
":mat",
|
||||
":prompt",
|
||||
":sampler",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:compress",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
@ -652,17 +680,16 @@ cc_test(
|
|||
"mem": "28g",
|
||||
},
|
||||
deps = [
|
||||
":allocator",
|
||||
":backprop",
|
||||
":backprop_scalar",
|
||||
":common",
|
||||
":mat",
|
||||
":ops",
|
||||
":prompt",
|
||||
":sampler",
|
||||
":threading",
|
||||
":threading_context",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"//compression:compress",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"@highway//:hwy",
|
||||
"@highway//:hwy_test_util",
|
||||
"@highway//:thread_pool",
|
||||
|
|
@ -676,6 +703,7 @@ cc_library(
|
|||
deps = [
|
||||
":allocator",
|
||||
":common",
|
||||
":mat",
|
||||
":weights",
|
||||
"//compression:compress",
|
||||
"@highway//:hwy",
|
||||
|
|
@ -685,9 +713,7 @@ cc_library(
|
|||
|
||||
cc_test(
|
||||
name = "optimize_test",
|
||||
srcs = [
|
||||
"backprop/optimize_test.cc",
|
||||
],
|
||||
srcs = ["backprop/optimize_test.cc"],
|
||||
exec_properties = {
|
||||
# Avoid linker OOMs when building with sanitizer instrumentation.
|
||||
"mem": "28g",
|
||||
|
|
@ -704,7 +730,7 @@ cc_test(
|
|||
":sampler",
|
||||
":threading",
|
||||
":weights",
|
||||
"@googletest//:gtest_main",
|
||||
"@googletest//:gtest_main", # buildcleaner: keep
|
||||
"//compression:sfp",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
|
|
|
|||
|
|
@ -75,6 +75,7 @@ set(SOURCES
|
|||
gemma/common.h
|
||||
gemma/configs.cc
|
||||
gemma/configs.h
|
||||
gemma/gemma_args.h
|
||||
gemma/gemma-inl.h
|
||||
gemma/gemma.cc
|
||||
gemma/gemma.h
|
||||
|
|
@ -102,12 +103,14 @@ set(SOURCES
|
|||
paligemma/image.h
|
||||
util/allocator.cc
|
||||
util/allocator.h
|
||||
util/app.h
|
||||
util/args.h
|
||||
util/basics.h
|
||||
util/mat.cc
|
||||
util/mat.h
|
||||
util/test_util.h
|
||||
util/threading.cc
|
||||
util/threading.h
|
||||
util/threading_context.cc
|
||||
util/threading_context.h
|
||||
util/topology.cc
|
||||
util/topology.h
|
||||
)
|
||||
|
|
@ -197,8 +200,5 @@ endif() # GEMMA_ENABLE_TESTS
|
|||
|
||||
## Tools
|
||||
|
||||
add_executable(compress_weights compression/compress_weights.cc)
|
||||
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
|
||||
|
||||
add_executable(migrate_weights compression/migrate_weights.cc)
|
||||
target_link_libraries(migrate_weights libgemma hwy hwy_contrib)
|
||||
|
|
|
|||
|
|
@ -20,24 +20,30 @@
|
|||
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h" // MatStorageT
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "util/mat.h" // MatStorageT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
template <typename T>
|
||||
struct ForwardLayer {
|
||||
ForwardLayer(const LayerConfig& config, size_t seq_len)
|
||||
: input("input", seq_len, config.model_dim),
|
||||
pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim),
|
||||
qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim),
|
||||
att("att", seq_len * config.heads, seq_len),
|
||||
att_out("att_out", seq_len * config.heads, config.qkv_dim),
|
||||
att_post1("att_post1", seq_len, config.model_dim),
|
||||
attention_out("attention_out", seq_len, config.model_dim),
|
||||
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim),
|
||||
ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2),
|
||||
ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim),
|
||||
: input(MakePacked<T>("input", seq_len, config.model_dim)),
|
||||
pre_att_rms_out(
|
||||
MakePacked<T>("pre_att_rms_out", seq_len, config.model_dim)),
|
||||
qkv(MakePacked<T>("qkv", seq_len * (config.heads + 2), config.qkv_dim)),
|
||||
att(MakePacked<T>("att", seq_len * config.heads, seq_len)),
|
||||
att_out(
|
||||
MakePacked<T>("att_out", seq_len * config.heads, config.qkv_dim)),
|
||||
att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)),
|
||||
attention_out(
|
||||
MakePacked<T>("attention_out", seq_len, config.model_dim)),
|
||||
bf_pre_ffw_rms_out(
|
||||
MakePacked<T>("bf_preFF_rms_out", seq_len, config.model_dim)),
|
||||
ffw_hidden(
|
||||
MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)),
|
||||
ffw_hidden_gated(
|
||||
MakePacked<T>("ffw_hidden_gated", seq_len, config.ff_hidden_dim)),
|
||||
layer_config(config) {}
|
||||
|
||||
MatStorageT<T> input;
|
||||
|
|
@ -56,12 +62,12 @@ struct ForwardLayer {
|
|||
template <typename T>
|
||||
struct ForwardPass {
|
||||
ForwardPass(const ModelConfig& config)
|
||||
: final_layer_output("final_layer_output", config.seq_len,
|
||||
config.model_dim),
|
||||
final_norm_output("final_norm_output", config.seq_len,
|
||||
config.model_dim),
|
||||
logits("logits", config.seq_len, config.vocab_size),
|
||||
probs("probs", config.seq_len, config.vocab_size),
|
||||
: final_layer_output(
|
||||
MakePacked<T>("fin_layer_out", config.seq_len, config.model_dim)),
|
||||
final_norm_output(
|
||||
MakePacked<T>("fin_norm_out", config.seq_len, config.model_dim)),
|
||||
logits(MakePacked<T>("logits", config.seq_len, config.vocab_size)),
|
||||
probs(MakePacked<T>("probs", config.seq_len, config.vocab_size)),
|
||||
weights_config(config) {
|
||||
for (const auto& layer_config : config.layer_configs) {
|
||||
layers.emplace_back(layer_config, config.seq_len);
|
||||
|
|
|
|||
|
|
@ -128,7 +128,7 @@ static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
|
|||
HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void RMSNormVJP(
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormVJP(
|
||||
const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
|
||||
const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
|
||||
float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
|
||||
|
|
@ -153,10 +153,9 @@ static HWY_NOINLINE void RMSNormVJP(
|
|||
}
|
||||
}
|
||||
|
||||
static HWY_NOINLINE void InputEmbeddingVJP(
|
||||
const float* weights, const std::vector<int>& prompt,
|
||||
const float scaling, const float* HWY_RESTRICT v,
|
||||
float* HWY_RESTRICT grad, size_t model_dim) {
|
||||
static HWY_NOINLINE HWY_MAYBE_UNUSED void InputEmbeddingVJP(
|
||||
const float* weights, const std::vector<int>& prompt, const float scaling,
|
||||
const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) {
|
||||
HWY_ASSERT(!prompt.empty());
|
||||
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
|
||||
int token = prompt[pos];
|
||||
|
|
@ -182,17 +181,18 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
|||
static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
|
||||
HWY_ASSERT(num_tokens <= seq_len);
|
||||
|
||||
MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
|
||||
MatMulVJP(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(),
|
||||
next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
|
||||
grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);
|
||||
grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t hidden_offset = pos * ff_hidden_dim * 2;
|
||||
const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
|
||||
const float* HWY_RESTRICT f_out =
|
||||
forward.ffw_hidden.Packed() + hidden_offset;
|
||||
const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
|
||||
const float* HWY_RESTRICT b_out_gated =
|
||||
backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
|
||||
float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
|
||||
backward.ffw_hidden_gated.Packed() + pos * ff_hidden_dim;
|
||||
float* HWY_RESTRICT b_out = backward.ffw_hidden.Packed() + hidden_offset;
|
||||
float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
|
|
@ -206,38 +206,39 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
|||
}
|
||||
}
|
||||
|
||||
MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
|
||||
backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
|
||||
num_tokens, grad.gating_einsum_w.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(), pool);
|
||||
RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
|
||||
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
|
||||
MatMulVJP(weights.gating_einsum_w.Packed(),
|
||||
forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
|
||||
model_dim, ff_hidden_dim * 2, num_tokens,
|
||||
grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(),
|
||||
pool);
|
||||
RMSNormVJP(
|
||||
weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
|
||||
backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens,
|
||||
grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(next_layer_grad + pos * model_dim,
|
||||
backward.attention_out.data() + pos * model_dim, model_dim);
|
||||
backward.attention_out.Packed() + pos * model_dim, model_dim);
|
||||
}
|
||||
|
||||
backward.qkv.ZeroInit();
|
||||
ZeroInit(backward.qkv);
|
||||
|
||||
MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
|
||||
backward.attention_out.data(), heads, qkv_dim, model_dim,
|
||||
num_tokens, grad.attn_vec_einsum_w.data(),
|
||||
backward.att_out.data(), pool);
|
||||
MultiHeadMatMulVJP(
|
||||
weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
|
||||
backward.attention_out.Packed(), heads, qkv_dim, model_dim, num_tokens,
|
||||
grad.attn_vec_einsum_w.Packed(), backward.att_out.Packed(), pool);
|
||||
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t aoffset = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
|
||||
const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
|
||||
const float* HWY_RESTRICT b_att_out =
|
||||
backward.att_out.data() + (pos * heads + head) * qkv_dim;
|
||||
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
|
||||
backward.att_out.Packed() + (pos * heads + head) * qkv_dim;
|
||||
float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
|
||||
const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
|
||||
float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
|
||||
const float* HWY_RESTRICT f_v2 = forward.qkv.Packed() + v2offs;
|
||||
float* HWY_RESTRICT b_v2 = backward.qkv.Packed() + v2offs;
|
||||
b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
|
||||
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
|
||||
}
|
||||
|
|
@ -247,8 +248,8 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
|||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t aoffset = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
|
||||
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
|
||||
const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
|
||||
float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
|
||||
SoftmaxVJP(f_head_att, b_head_att, pos + 1);
|
||||
}
|
||||
}
|
||||
|
|
@ -257,13 +258,13 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
|||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
|
||||
const size_t aoffs = head * seq_len + pos * heads * seq_len;
|
||||
const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
|
||||
const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
|
||||
float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
|
||||
const float* HWY_RESTRICT f_q = forward.qkv.Packed() + qoffs;
|
||||
const float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffs;
|
||||
float* HWY_RESTRICT b_q = backward.qkv.Packed() + qoffs;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
|
||||
const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
|
||||
float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
|
||||
const float* HWY_RESTRICT f_k2 = forward.qkv.Packed() + k2offs;
|
||||
float* HWY_RESTRICT b_k2 = backward.qkv.Packed() + k2offs;
|
||||
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
|
||||
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
|
||||
}
|
||||
|
|
@ -272,28 +273,30 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
|||
|
||||
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
|
||||
float* HWY_RESTRICT b_kv =
|
||||
backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
|
||||
backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim;
|
||||
Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
|
||||
}
|
||||
|
||||
for (size_t head = 0; head < heads; ++head) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
float* HWY_RESTRICT b_q =
|
||||
backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
|
||||
backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim;
|
||||
MulByConst(query_scale, b_q, qkv_dim);
|
||||
Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
|
||||
backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
|
||||
grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
|
||||
RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
|
||||
backward.pre_att_rms_out.data(), model_dim, num_tokens,
|
||||
grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
|
||||
MatMulVJP(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
|
||||
backward.qkv.Packed(), model_dim, (heads + 2) * qkv_dim, num_tokens,
|
||||
grad.qkv_einsum_w.Packed(), backward.pre_att_rms_out.Packed(),
|
||||
pool);
|
||||
RMSNormVJP(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
|
||||
backward.pre_att_rms_out.Packed(), model_dim, num_tokens,
|
||||
grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
|
||||
pool);
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(backward.attention_out.data() + pos * model_dim,
|
||||
backward.input.data() + pos * model_dim, model_dim);
|
||||
AddFrom(backward.attention_out.Packed() + pos * model_dim,
|
||||
backward.input.Packed() + pos * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -353,47 +356,48 @@ void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
|
|||
HWY_DASSERT(prompt.context_size < prompt.tokens.size());
|
||||
const size_t num_tokens = prompt.tokens.size() - 1;
|
||||
|
||||
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
|
||||
CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
|
||||
kVocabSize);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
|
||||
backward.logits.data() + pos * kVocabSize,
|
||||
kVocabSize);
|
||||
SoftmaxVJP(forward.probs.Packed() + pos * kVocabSize,
|
||||
backward.logits.Packed() + pos * kVocabSize, kVocabSize);
|
||||
}
|
||||
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
|
||||
backward.logits.data() + pos * kVocabSize, kVocabSize);
|
||||
SoftcapVJP(config.final_cap, forward.logits.Packed() + pos * kVocabSize,
|
||||
backward.logits.Packed() + pos * kVocabSize, kVocabSize);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJP(weights.embedder_input_embedding.data(),
|
||||
forward.final_norm_output.data(), backward.logits.data(), model_dim,
|
||||
kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
|
||||
backward.final_norm_output.data(), pool);
|
||||
MatMulVJP(weights.embedder_input_embedding.Packed(),
|
||||
forward.final_norm_output.Packed(), backward.logits.Packed(),
|
||||
model_dim, kVocabSize, num_tokens,
|
||||
grad.embedder_input_embedding.Packed(),
|
||||
backward.final_norm_output.Packed(), pool);
|
||||
|
||||
RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
|
||||
backward.final_norm_output.data(), model_dim, num_tokens,
|
||||
grad.final_norm_scale.data(), backward.final_layer_output.data(),
|
||||
pool);
|
||||
RMSNormVJP(weights.final_norm_scale.Packed(),
|
||||
forward.final_layer_output.Packed(),
|
||||
backward.final_norm_output.Packed(), model_dim, num_tokens,
|
||||
grad.final_norm_scale.Packed(),
|
||||
backward.final_layer_output.Packed(), pool);
|
||||
|
||||
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
|
||||
auto layer_config = config.layer_configs[layer];
|
||||
// TODO(szabadka) Implement Griffin layer vjp.
|
||||
HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
|
||||
float* next_layer_grad = layer + 1 < kLayers
|
||||
? backward.layers[layer + 1].input.data()
|
||||
: backward.final_layer_output.data();
|
||||
? backward.layers[layer + 1].input.Packed()
|
||||
: backward.final_layer_output.Packed();
|
||||
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
|
||||
num_tokens, *grad.GetLayer(layer), backward.layers[layer],
|
||||
inv_timescale, pool);
|
||||
}
|
||||
|
||||
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
|
||||
kEmbScaling, backward.layers[0].input.data(),
|
||||
grad.embedder_input_embedding.data(), model_dim);
|
||||
InputEmbeddingVJP(weights.embedder_input_embedding.Packed(), prompt.tokens,
|
||||
kEmbScaling, backward.layers[0].input.Packed(),
|
||||
grad.embedder_input_embedding.Packed(), model_dim);
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@ -17,9 +17,8 @@
|
|||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@
|
|||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
|
|||
|
|
@ -211,62 +211,65 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
|
|||
const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
|
||||
const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
|
||||
|
||||
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
|
||||
grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
|
||||
kFFHiddenDim, num_tokens);
|
||||
MatMulVJPT(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), dy,
|
||||
grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(),
|
||||
model_dim, kFFHiddenDim, num_tokens);
|
||||
|
||||
GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
|
||||
backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
|
||||
GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(),
|
||||
backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens);
|
||||
|
||||
MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
|
||||
backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
|
||||
MatMulVJPT(weights.gating_einsum_w.Packed(),
|
||||
forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
|
||||
grad.gating_einsum_w.Packed(),
|
||||
backward.bf_pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
|
||||
num_tokens);
|
||||
|
||||
RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
|
||||
backward.bf_pre_ffw_rms_out.data(),
|
||||
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
|
||||
model_dim, num_tokens);
|
||||
RMSNormVJPT(
|
||||
weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
|
||||
backward.bf_pre_ffw_rms_out.Packed(), grad.pre_ffw_norm_scale.Packed(),
|
||||
backward.attention_out.Packed(), model_dim, num_tokens);
|
||||
|
||||
AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
|
||||
AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim);
|
||||
|
||||
MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
|
||||
backward.attention_out.data(),
|
||||
grad.attn_vec_einsum_w.data(), backward.att_out.data(),
|
||||
kHeads, model_dim, qkv_dim, num_tokens);
|
||||
MultiHeadMatMulVJPT(
|
||||
weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
|
||||
backward.attention_out.Packed(), grad.attn_vec_einsum_w.Packed(),
|
||||
backward.att_out.Packed(), kHeads, model_dim, qkv_dim, num_tokens);
|
||||
|
||||
MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
|
||||
backward.att_out.data(), backward.qkv.data(),
|
||||
backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
|
||||
|
||||
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
|
||||
MixByAttentionVJP(forward.qkv.Packed(), forward.att.Packed(),
|
||||
backward.att_out.Packed(), backward.qkv.Packed(),
|
||||
backward.att.Packed(), num_tokens, kHeads, qkv_dim,
|
||||
seq_len);
|
||||
|
||||
MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
|
||||
backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
|
||||
MaskedSoftmaxVJPT(forward.att.Packed(), backward.att.Packed(), num_tokens,
|
||||
kHeads, seq_len);
|
||||
|
||||
MaskedAttentionVJP(forward.qkv.Packed(), backward.att.Packed(),
|
||||
backward.qkv.Packed(), num_tokens, kHeads, qkv_dim,
|
||||
seq_len);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
|
||||
T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
|
||||
MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
|
||||
}
|
||||
|
||||
for (int pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
|
||||
T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
|
||||
for (size_t h = 0; h <= kHeads; ++h) {
|
||||
Rope(qkv + h * qkv_dim, qkv_dim, -pos);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
|
||||
backward.qkv.data(), grad.qkv_einsum_w.data(),
|
||||
backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
|
||||
num_tokens);
|
||||
RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
|
||||
backward.pre_att_rms_out.data(),
|
||||
grad.pre_attention_norm_scale.data(), backward.input.data(),
|
||||
MatMulVJPT(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
|
||||
backward.qkv.Packed(), grad.qkv_einsum_w.Packed(),
|
||||
backward.pre_att_rms_out.Packed(), (kHeads + 2) * qkv_dim,
|
||||
model_dim, num_tokens);
|
||||
RMSNormVJPT(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
|
||||
backward.pre_att_rms_out.Packed(),
|
||||
grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
|
||||
model_dim, num_tokens);
|
||||
|
||||
AddFromT(backward.attention_out.data(), backward.input.data(),
|
||||
AddFromT(backward.attention_out.Packed(), backward.input.Packed(),
|
||||
num_tokens * model_dim);
|
||||
}
|
||||
|
||||
|
|
@ -307,41 +310,42 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
|
|||
const std::vector<int> tokens = prompt.tokens;
|
||||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
|
||||
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
|
||||
CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
|
||||
vocab_size);
|
||||
|
||||
SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
|
||||
SoftmaxVJPT(forward.probs.Packed(), backward.logits.Packed(), vocab_size,
|
||||
num_tokens);
|
||||
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t i = 0; i < num_tokens; ++i) {
|
||||
SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
|
||||
backward.logits.data() + i * vocab_size, vocab_size);
|
||||
SoftcapVJPT(config.final_cap, forward.logits.Packed() + i * vocab_size,
|
||||
backward.logits.Packed() + i * vocab_size, vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
MatMulVJPT(
|
||||
weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
|
||||
backward.logits.data(), grad.embedder_input_embedding.data(),
|
||||
backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
|
||||
MatMulVJPT(weights.embedder_input_embedding.Packed(),
|
||||
forward.final_norm_output.Packed(), backward.logits.Packed(),
|
||||
grad.embedder_input_embedding.Packed(),
|
||||
backward.final_norm_output.Packed(), vocab_size, model_dim,
|
||||
num_tokens);
|
||||
|
||||
RMSNormVJPT(weights.final_norm_scale.data(),
|
||||
forward.final_layer_output.data(),
|
||||
backward.final_norm_output.data(), grad.final_norm_scale.data(),
|
||||
backward.final_layer_output.data(), model_dim, num_tokens);
|
||||
RMSNormVJPT(
|
||||
weights.final_norm_scale.Packed(), forward.final_layer_output.Packed(),
|
||||
backward.final_norm_output.Packed(), grad.final_norm_scale.Packed(),
|
||||
backward.final_layer_output.Packed(), model_dim, num_tokens);
|
||||
|
||||
for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
|
||||
T* next_layer_grad = layer + 1 < layers
|
||||
? backward.layers[layer + 1].input.data()
|
||||
: backward.final_layer_output.data();
|
||||
? backward.layers[layer + 1].input.Packed()
|
||||
: backward.final_layer_output.Packed();
|
||||
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
|
||||
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
|
||||
}
|
||||
|
||||
const T kEmbScaling = EmbeddingScaling(model_dim);
|
||||
InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
|
||||
kEmbScaling, backward.layers[0].input.data(),
|
||||
grad.embedder_input_embedding.data(), model_dim);
|
||||
InputEmbeddingVJPT(weights.embedder_input_embedding.Packed(), tokens,
|
||||
kEmbScaling, backward.layers[0].input.Packed(),
|
||||
grad.embedder_input_embedding.Packed(), model_dim);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -31,9 +31,9 @@
|
|||
#include "backprop/prompt.h"
|
||||
#include "backprop/sampler.h"
|
||||
#include "backprop/test_util.h"
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/mat.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -44,14 +44,14 @@ TEST(BackPropTest, MatMulVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> weights("weights", kRows, kCols);
|
||||
MatStorageT<T> x("x", kTokens, kCols);
|
||||
MatStorageT<T> grad("grad", kRows, kCols);
|
||||
MatStorageT<T> dx("dx", kTokens, kCols);
|
||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
|
||||
MatStorageT<TC> c_x("c_x", kTokens, kCols);
|
||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
||||
MatStorageT<T> dy("dy", kTokens, kRows);
|
||||
auto weights = MakePacked<T>("weights", kRows, kCols);
|
||||
auto x = MakePacked<T>("x", kTokens, kCols);
|
||||
auto grad = MakePacked<T>("grad", kRows, kCols);
|
||||
auto dx = MakePacked<T>("dx", kTokens, kCols);
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
|
||||
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||
auto dy = MakePacked<T>("dy", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
|
|
@ -60,12 +60,13 @@ TEST(BackPropTest, MatMulVJP) {
|
|||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
|
||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||
MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
|
||||
kTokens);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||
};
|
||||
grad.ZeroInit();
|
||||
MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
|
||||
kRows, kCols, kTokens);
|
||||
ZeroInit(grad);
|
||||
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
||||
dx.Packed(), kRows, kCols, kTokens);
|
||||
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
|
||||
}
|
||||
|
|
@ -79,14 +80,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> weights("weights", kRows, kCols * kHeads);
|
||||
MatStorageT<T> x("x", kTokens, kCols * kHeads);
|
||||
MatStorageT<T> grad("grad", kRows, kCols * kHeads);
|
||||
MatStorageT<T> dx("dx", kTokens, kCols * kHeads);
|
||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
|
||||
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
|
||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
||||
MatStorageT<T> dy("dy", kTokens, kRows);
|
||||
auto weights = MakePacked<T>("weights", kRows, kCols * kHeads);
|
||||
auto x = MakePacked<T>("x", kTokens, kCols * kHeads);
|
||||
auto grad = MakePacked<T>("grad", kRows, kCols * kHeads);
|
||||
auto dx = MakePacked<T>("dx", kTokens, kCols * kHeads);
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
|
||||
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||
auto dy = MakePacked<T>("dy", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
|
|
@ -95,13 +96,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
|
|||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
|
||||
kCols, kTokens);
|
||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||
MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
|
||||
kRows, kCols, kTokens);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||
};
|
||||
grad.ZeroInit();
|
||||
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
|
||||
dx.data(), kHeads, kRows, kCols, kTokens);
|
||||
ZeroInit(grad);
|
||||
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
|
||||
grad.Packed(), dx.Packed(), kHeads, kRows, kCols,
|
||||
kTokens);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
|
||||
}
|
||||
|
|
@ -113,14 +115,14 @@ TEST(BackPropTest, RMSNormVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> weights("weights", N, 1);
|
||||
MatStorageT<T> grad("grad", N, 1);
|
||||
MatStorageT<T> x("x", K, N);
|
||||
MatStorageT<T> dx("dx", K, N);
|
||||
MatStorageT<T> dy("dy", K, N);
|
||||
MatStorageT<TC> c_weights("c_weights", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", K, N);
|
||||
MatStorageT<TC> c_y("c_y", K, N);
|
||||
auto weights = MakePacked<T>("weights", N, 1);
|
||||
auto grad = MakePacked<T>("grad", N, 1);
|
||||
auto x = MakePacked<T>("x", K, N);
|
||||
auto dx = MakePacked<T>("dx", K, N);
|
||||
auto dy = MakePacked<T>("dy", K, N);
|
||||
auto c_weights = MakePacked<TC>("c_weights", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", K, N);
|
||||
auto c_y = MakePacked<TC>("c_y", K, N);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0 * (1 << iter), gen);
|
||||
|
|
@ -129,12 +131,12 @@ TEST(BackPropTest, RMSNormVJP) {
|
|||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
|
||||
return DotT(dy.data(), c_y.data(), K * N);
|
||||
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
|
||||
return DotT(dy.Packed(), c_y.Packed(), K * N);
|
||||
};
|
||||
grad.ZeroInit();
|
||||
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
|
||||
N, K);
|
||||
ZeroInit(grad);
|
||||
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
|
||||
dx.Packed(), N, K);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
|
||||
}
|
||||
|
|
@ -145,24 +147,24 @@ TEST(BackPropTest, SoftmaxVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", N, 1);
|
||||
MatStorageT<T> dx("dx", N, 1);
|
||||
MatStorageT<T> dy("dy", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", N, 1);
|
||||
MatStorageT<TC> c_y("c_y", N, 1);
|
||||
auto x = MakePacked<T>("x", N, 1);
|
||||
auto dx = MakePacked<T>("dx", N, 1);
|
||||
auto dy = MakePacked<T>("dy", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", N, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", N, 1);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
memcpy(c_y.data(), c_x.data(), c_x.SizeBytes());
|
||||
Softmax(c_y.data(), N);
|
||||
return DotT(dy.data(), c_y.data(), N);
|
||||
CopyMat(c_x, c_y);
|
||||
Softmax(c_y.Packed(), N);
|
||||
return DotT(dy.Packed(), c_y.Packed(), N);
|
||||
};
|
||||
Softmax(x.data(), N);
|
||||
memcpy(dx.data(), dy.data(), dx.SizeBytes());
|
||||
SoftmaxVJPT(x.data(), dx.data(), N);
|
||||
Softmax(x.Packed(), N);
|
||||
CopyMat(dy, dx);
|
||||
SoftmaxVJPT(x.Packed(), dx.Packed(), N);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
|
@ -175,26 +177,25 @@ TEST(BackPropTest, MaskedSoftmaxVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", N, 1);
|
||||
MatStorageT<T> dy("dy", N, 1);
|
||||
MatStorageT<T> dx("dx", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", N, 1);
|
||||
MatStorageT<TC> c_y("c_y", N, 1);
|
||||
dx.ZeroInit();
|
||||
auto x = MakePacked<T>("x", N, 1);
|
||||
auto dy = MakePacked<T>("dy", N, 1);
|
||||
auto dx = MakePacked<T>("dx", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", N, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", N, 1);
|
||||
ZeroInit(dx);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
memcpy(c_y.data(), c_x.data(),
|
||||
kTokens * kHeads * kSeqLen * sizeof(c_x.At(0)));
|
||||
MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
|
||||
return DotT(dy.data(), c_y.data(), N);
|
||||
CopyMat(c_x, c_y);
|
||||
MaskedSoftmax(c_y.Packed(), kTokens, kHeads, kSeqLen);
|
||||
return DotT(dy.Packed(), c_y.Packed(), N);
|
||||
};
|
||||
MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
|
||||
memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0)));
|
||||
MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
|
||||
MaskedSoftmax(x.Packed(), kTokens, kHeads, kSeqLen);
|
||||
CopyMat(dy, dx);
|
||||
MaskedSoftmaxVJPT(x.Packed(), dx.Packed(), kTokens, kHeads, kSeqLen);
|
||||
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
|
@ -204,11 +205,11 @@ TEST(BackPropTest, SoftcapVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", N, 1);
|
||||
MatStorageT<T> dx("dx", N, 1);
|
||||
MatStorageT<T> dy("dy", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", N, 1);
|
||||
MatStorageT<TC> c_y("c_y", N, 1);
|
||||
auto x = MakePacked<T>("x", N, 1);
|
||||
auto dx = MakePacked<T>("dx", N, 1);
|
||||
auto dy = MakePacked<T>("dy", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", N, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", N, 1);
|
||||
|
||||
constexpr float kCap = 30.0f;
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
|
|
@ -216,13 +217,13 @@ TEST(BackPropTest, SoftcapVJP) {
|
|||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0)));
|
||||
Softcap(kCap, c_y.data(), N);
|
||||
return DotT(dy.data(), c_y.data(), N);
|
||||
CopyMat(c_x, c_y);
|
||||
Softcap(kCap, c_y.Packed(), N);
|
||||
return DotT(dy.Packed(), c_y.Packed(), N);
|
||||
};
|
||||
Softcap(kCap, x.data(), N);
|
||||
memcpy(dx.data(), dy.data(), dx.SizeBytes());
|
||||
SoftcapVJPT(kCap, x.data(), dx.data(), N);
|
||||
Softcap(kCap, x.Packed(), N);
|
||||
CopyMat(dy, dx);
|
||||
SoftcapVJPT(kCap, x.Packed(), dx.Packed(), N);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
|
||||
}
|
||||
}
|
||||
|
|
@ -233,9 +234,9 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", K, V);
|
||||
MatStorageT<T> dx("dx", K, V);
|
||||
MatStorageT<TC> c_x("c_x", K, V);
|
||||
auto x = MakePacked<T>("x", K, V);
|
||||
auto dx = MakePacked<T>("dx", K, V);
|
||||
auto c_x = MakePacked<TC>("c_x", K, V);
|
||||
Prompt prompt;
|
||||
prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
|
||||
|
||||
|
|
@ -243,13 +244,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
|
|||
for (int iter = 0; iter < 10; ++iter) {
|
||||
prompt.context_size = 1 + (iter % 6);
|
||||
RandInit(x, 1.0 * (1 << iter), gen);
|
||||
Softcap(kCap, x.data(), V * K);
|
||||
Softmax(x.data(), V, K);
|
||||
CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
|
||||
Softcap(kCap, x.Packed(), V * K);
|
||||
Softmax(x.Packed(), V, K);
|
||||
CrossEntropyLossGrad(x.Packed(), dx.Packed(), prompt, V);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
return CrossEntropyLoss(c_x.data(), prompt, V);
|
||||
};
|
||||
auto func = [&]() { return CrossEntropyLoss(c_x.Packed(), prompt, V); };
|
||||
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
|
@ -260,21 +259,21 @@ TEST(BackPropTest, GatedGeluVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", K, 2 * N);
|
||||
MatStorageT<T> dx("dx", K, 2 * N);
|
||||
MatStorageT<T> dy("dy", K, N);
|
||||
MatStorageT<TC> c_x("c_x", K, 2 * N);
|
||||
MatStorageT<TC> c_y("c_y", K, N);
|
||||
auto x = MakePacked<T>("x", K, 2 * N);
|
||||
auto dx = MakePacked<T>("dx", K, 2 * N);
|
||||
auto dy = MakePacked<T>("dy", K, N);
|
||||
auto c_x = MakePacked<TC>("c_x", K, 2 * N);
|
||||
auto c_y = MakePacked<TC>("c_y", K, N);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0, gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
GatedGelu(c_x.data(), c_y.data(), N, K);
|
||||
return DotT(dy.data(), c_y.data(), N * K);
|
||||
GatedGelu(c_x.Packed(), c_y.Packed(), N, K);
|
||||
return DotT(dy.Packed(), c_y.Packed(), N * K);
|
||||
};
|
||||
GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
|
||||
GatedGeluVJP(x.Packed(), dy.Packed(), dx.Packed(), N, K);
|
||||
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
|
@ -289,25 +288,25 @@ TEST(BackPropTest, MaskedAttentionVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> x("x", kQKVSize, 1);
|
||||
MatStorageT<T> dx("dx", kQKVSize, 1);
|
||||
MatStorageT<T> dy("dy", kOutSize, 1);
|
||||
MatStorageT<TC> c_x("c_x", kQKVSize, 1);
|
||||
MatStorageT<TC> c_y("c_y", kOutSize, 1);
|
||||
dx.ZeroInit();
|
||||
c_y.ZeroInit();
|
||||
auto x = MakePacked<T>("x", kQKVSize, 1);
|
||||
auto dx = MakePacked<T>("dx", kQKVSize, 1);
|
||||
auto dy = MakePacked<T>("dy", kOutSize, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", kQKVSize, 1);
|
||||
auto c_y = MakePacked<TC>("c_y", kOutSize, 1);
|
||||
ZeroInit(dx);
|
||||
ZeroInit(c_y);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(x, 1.0, gen);
|
||||
Complexify(x, c_x);
|
||||
RandInit(dy, 1.0, gen);
|
||||
auto func = [&]() {
|
||||
MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
|
||||
MaskedAttention(c_x.Packed(), c_y.Packed(), kTokens, kHeads, kQKVDim,
|
||||
kSeqLen);
|
||||
return DotT(dy.data(), c_y.data(), kOutSize);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kOutSize);
|
||||
};
|
||||
MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
|
||||
kTokens, kHeads, kQKVDim, kSeqLen);
|
||||
MaskedAttentionVJP(x.Packed(), dy.Packed(), dx.Packed(), kTokens, kHeads,
|
||||
kQKVDim, kSeqLen);
|
||||
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
|
||||
}
|
||||
}
|
||||
|
|
@ -323,17 +322,17 @@ TEST(BackPropTest, MixByAttentionVJP) {
|
|||
std::mt19937 gen(42);
|
||||
using T = double;
|
||||
using TC = std::complex<T>;
|
||||
MatStorageT<T> qkv("qkv", kQKVSize, 1);
|
||||
MatStorageT<T> dqkv("dqkv", kQKVSize, 1);
|
||||
MatStorageT<T> attn("attn", kAttnSize, 1);
|
||||
MatStorageT<T> dattn("dattn", kAttnSize, 1);
|
||||
MatStorageT<T> dy("dy", kOutSize, 1);
|
||||
MatStorageT<TC> c_qkv("c_qkv", kQKVSize, 1);
|
||||
MatStorageT<TC> c_attn("c_attn", kAttnSize, 1);
|
||||
MatStorageT<TC> c_y("c_y", kOutSize, 1);
|
||||
dqkv.ZeroInit();
|
||||
dattn.ZeroInit();
|
||||
c_y.ZeroInit();
|
||||
auto qkv = MakePacked<T>("qkv", kQKVSize, 1);
|
||||
auto dqkv = MakePacked<T>("dqkv", kQKVSize, 1);
|
||||
auto attn = MakePacked<T>("attn", kAttnSize, 1);
|
||||
auto dattn = MakePacked<T>("dattn", kAttnSize, 1);
|
||||
auto dy = MakePacked<T>("dy", kOutSize, 1);
auto c_qkv = MakePacked<TC>("c_qkv", kQKVSize, 1);
auto c_attn = MakePacked<TC>("c_attn", kAttnSize, 1);
auto c_y = MakePacked<TC>("c_y", kOutSize, 1);
ZeroInit(dqkv);
ZeroInit(dattn);
ZeroInit(c_y);

for (int iter = 0; iter < 10; ++iter) {
RandInit(qkv, 1.0, gen);

@ -342,12 +341,12 @@ TEST(BackPropTest, MixByAttentionVJP) {
Complexify(attn, c_attn);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
MixByAttention(c_qkv.Packed(), c_attn.Packed(), c_y.Packed(), kTokens,
kHeads, kQKVDim, kSeqLen);
return DotT(dy.Packed(), c_y.Packed(), kOutSize);
};
MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
MixByAttentionVJP(qkv.Packed(), attn.Packed(), dy.Packed(), dqkv.Packed(),
dattn.Packed(), kTokens, kHeads, kQKVDim, kSeqLen);
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
}

@ -360,11 +359,11 @@ TEST(BackPropTest, InputEmbeddingVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", kVocabSize, kModelDim);
MatStorageT<T> grad("grad", kVocabSize, kModelDim);
MatStorageT<T> dy("dy", kSeqLen, kModelDim);
MatStorageT<TC> c_weights("c_weights", kVocabSize, kModelDim);
MatStorageT<TC> c_y("c_y", kSeqLen, kModelDim);
auto weights = MakePacked<T>("weights", kVocabSize, kModelDim);
auto grad = MakePacked<T>("grad", kVocabSize, kModelDim);
auto dy = MakePacked<T>("dy", kSeqLen, kModelDim);
auto c_weights = MakePacked<TC>("c_weights", kVocabSize, kModelDim);
auto c_y = MakePacked<TC>("c_y", kSeqLen, kModelDim);
std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
size_t num_tokens = tokens.size() - 1;

@ -373,12 +372,13 @@ TEST(BackPropTest, InputEmbeddingVJP) {
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
auto func = [&]() {
InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
};
grad.ZeroInit();
InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
InputEmbedding(c_weights.Packed(), tokens, TC(3.0), c_y.Packed(),
kModelDim);
return DotT(dy.Packed(), c_y.Packed(), num_tokens * kModelDim);
};
ZeroInit(grad);
InputEmbeddingVJPT(weights.Packed(), tokens, 3.0, dy.Packed(),
grad.Packed(), kModelDim);
TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
}
}

@ -410,8 +410,7 @@ TEST(BackPropTest, LayerVJP) {
using T = double;
using TC = std::complex<T>;
ModelConfig config = TestConfig();
TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
/*reshape_att=*/false);
const TensorIndex tensor_index = TensorIndexLLM(config, size_t{0});
const size_t kOutputSize = config.seq_len * config.model_dim;
LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);

@ -419,15 +418,15 @@ TEST(BackPropTest, LayerVJP) {
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
MatStorageT<T> y("y", kOutputSize, 1);
MatStorageT<T> dy("dy", kOutputSize, 1);
MatStorageT<TC> c_y("c_y", kOutputSize, 1);
auto y = MakePacked<T>("y", kOutputSize, 1);
auto dy = MakePacked<T>("dy", kOutputSize, 1);
auto c_y = MakePacked<TC>("c_y", kOutputSize, 1);
const size_t num_tokens = 3;
std::vector<MatStorage> layer_storage;
std::vector<MatOwner> layer_storage;
weights.Allocate(layer_storage);
grad.Allocate(layer_storage);
c_weights.Allocate(layer_storage);
backward.input.ZeroInit();
ZeroInit(backward.input);

for (size_t iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0, gen);

@ -436,12 +435,12 @@ TEST(BackPropTest, LayerVJP) {
Complexify(weights, c_weights);
Complexify(forward.input, c_forward.input);
auto func = [&]() {
ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
ApplyLayer(c_weights, c_forward, num_tokens, c_y.Packed());
return DotT(dy.Packed(), c_y.Packed(), num_tokens * config.model_dim);
};
grad.ZeroInit(/*layer_idx=*/0);
ApplyLayer(weights, forward, num_tokens, y.data());
LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
ApplyLayer(weights, forward, num_tokens, y.Packed());
LayerVJP(weights, forward, dy.Packed(), grad, backward, num_tokens);
TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11,
__LINE__);
TestGradient(grad, c_weights, func, 1e-11);
|
|
|
|||
|
|
@ -33,8 +33,10 @@
#include "backprop/test_util.h"
#include "gemma/configs.h"
#include "ops/ops.h"
#include "util/threading.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// clang-format off
#undef HWY_TARGET_INCLUDE

@ -46,33 +48,45 @@
// After highway.h
#include "backprop/backward-inl.h"
#include "backprop/forward-inl.h"
#include "compression/compress.h"
#include "ops/ops-inl.h"
#include "util/allocator.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {

hwy::ThreadPool& ThreadHostileGetPool() {
// Assume this is only called at the top level, i.e. not in a thread. Then we
// can safely call `SetArgs` only once, because it would assert otherwise.
// This is preferable to calling `ThreadHostileInvalidate`, because we would
// repeat the topology initialization for every test.
if (!ThreadingContext2::IsInitialized()) {
gcpp::ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 8;
threading_args.pin = Tristate::kFalse;
ThreadingContext2::SetArgs(threading_args);
}
return ThreadingContext2::Get().pools.Pool();
}

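The helper above is the new ThreadingContext2 flow this commit introduces in place of per-test BoundedTopology/NestedPools setup. As a rough usage sketch assembled only from calls visible in this diff (ThreadingArgs, ThreadingContext2::SetArgs/Get, pools.Pool(), pool.Run), here is how a hypothetical standalone frontend might obtain a pool; the main() wrapper and the chosen argument values are illustrative assumptions, not code from the repository.

#include <cstdint>

#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

int main() {
  gcpp::ThreadingArgs threading_args;                 // defaults, then overrides
  threading_args.max_clusters = 8;                    // illustrative value
  threading_args.pin = gcpp::Tristate::kFalse;
  gcpp::ThreadingContext2::SetArgs(threading_args);   // once, at startup
  hwy::ThreadPool& pool = gcpp::ThreadingContext2::Get().pools.Pool();
  pool.Run(0, 16, [](uint64_t task, size_t /*thread*/) {
    (void)task;  // per-task work would go here
  });
  return 0;
}
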
void TestMatMulVJP() {
|
||||
static const size_t kRows = 8;
|
||||
static const size_t kCols = 64;
|
||||
static const size_t kTokens = 5;
|
||||
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
|
||||
Allocator::Init(topology);
|
||||
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
|
||||
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
std::mt19937 gen(42);
|
||||
MatStorageT<float> weights("weights", kRows, kCols);
|
||||
MatStorageT<float> x("x", kTokens, kCols);
|
||||
MatStorageT<float> dy("dy", kTokens, kRows);
|
||||
MatStorageT<float> grad("grad", kRows, kCols);
|
||||
MatStorageT<float> dx("dx", kTokens, kCols);
|
||||
MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols);
|
||||
MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols);
|
||||
auto weights = MakePacked<float>("weights", kRows, kCols);
|
||||
auto x = MakePacked<float>("x", kTokens, kCols);
|
||||
auto dy = MakePacked<float>("dy", kTokens, kRows);
|
||||
auto grad = MakePacked<float>("grad", kRows, kCols);
|
||||
auto dx = MakePacked<float>("dx", kTokens, kCols);
|
||||
auto grad_scalar = MakePacked<float>("grad_scalar", kRows, kCols);
|
||||
auto dx_scalar = MakePacked<float>("dx_scalar", kTokens, kCols);
|
||||
using TC = std::complex<double>;
|
||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
|
||||
MatStorageT<TC> c_x("c_x", kTokens, kCols);
|
||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
|
||||
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f * (1 << iter), gen);
|
||||
|
|
@ -81,19 +95,20 @@ void TestMatMulVJP() {
|
|||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
|
||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||
MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
|
||||
kTokens);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||
};
|
||||
|
||||
grad.ZeroInit();
|
||||
MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
|
||||
grad.data(), dx.data(), pools.Pool());
|
||||
ZeroInit(grad);
|
||||
MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens,
|
||||
grad.Packed(), dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
||||
|
||||
grad_scalar.ZeroInit();
|
||||
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||
dx_scalar.data(), kRows, kCols, kTokens);
|
||||
ZeroInit(grad_scalar);
|
||||
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
|
||||
dx_scalar.Packed(), kRows, kCols, kTokens);
|
||||
TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__);
|
||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
|
||||
}
|
||||
|
|
@ -104,21 +119,19 @@ void TestMultiHeadMatMulVJP() {
|
|||
static const size_t kCols = 16;
|
||||
static const size_t kHeads = 4;
|
||||
static const size_t kTokens = 3;
|
||||
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
|
||||
Allocator::Init(topology);
|
||||
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
std::mt19937 gen(42);
|
||||
MatStorageT<float> weights("weights", kRows, kCols * kHeads);
|
||||
MatStorageT<float> x("x", kTokens, kCols * kHeads);
|
||||
MatStorageT<float> grad("grad", kRows, kCols * kHeads);
|
||||
MatStorageT<float> dx("dx", kTokens, kCols * kHeads);
|
||||
MatStorageT<float> dy("dy", kTokens, kRows);
|
||||
MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols * kHeads);
|
||||
MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols * kHeads);
|
||||
auto weights = MakePacked<float>("weights", kRows, kCols * kHeads);
|
||||
auto x = MakePacked<float>("x", kTokens, kCols * kHeads);
|
||||
auto grad = MakePacked<float>("grad", kRows, kCols * kHeads);
|
||||
auto dx = MakePacked<float>("dx", kTokens, kCols * kHeads);
|
||||
auto dy = MakePacked<float>("dy", kTokens, kRows);
|
||||
auto grad_scalar = MakePacked<float>("grad_scalar", kRows, kCols * kHeads);
|
||||
auto dx_scalar = MakePacked<float>("dx_scalar", kTokens, kCols * kHeads);
|
||||
using TC = std::complex<double>;
|
||||
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
|
||||
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
|
||||
MatStorageT<TC> c_y("c_y", kTokens, kRows);
|
||||
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
|
||||
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
|
||||
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f * (1 << iter), gen);
|
||||
|
|
@ -127,20 +140,21 @@ void TestMultiHeadMatMulVJP() {
|
|||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
|
||||
kCols, kTokens);
|
||||
return DotT(dy.data(), c_y.data(), kTokens * kRows);
|
||||
MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
|
||||
kRows, kCols, kTokens);
|
||||
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
|
||||
};
|
||||
|
||||
grad.ZeroInit();
|
||||
MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
|
||||
kRows, kTokens, grad.data(), dx.data(), pools.Pool());
|
||||
ZeroInit(grad);
|
||||
MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols,
|
||||
kRows, kTokens, grad.Packed(), dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
||||
|
||||
grad_scalar.ZeroInit();
|
||||
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
|
||||
ZeroInit(grad_scalar);
|
||||
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
|
||||
grad_scalar.Packed(), dx_scalar.Packed(), kHeads, kRows,
|
||||
kCols, kTokens);
|
||||
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
|
||||
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
|
||||
}
|
||||
|
|
@ -149,21 +163,19 @@ void TestMultiHeadMatMulVJP() {
|
|||
void TestRMSNormVJP() {
|
||||
static const size_t K = 2;
|
||||
static const size_t N = 64;
|
||||
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
|
||||
Allocator::Init(topology);
|
||||
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
std::mt19937 gen(42);
|
||||
MatStorageT<float> weights("weights", N, 1);
|
||||
MatStorageT<float> x("x", K, N);
|
||||
MatStorageT<float> grad("grad", N, 1);
|
||||
MatStorageT<float> dx("dx", K, N);
|
||||
MatStorageT<float> dy("dy", K, N);
|
||||
MatStorageT<float> grad_scalar("grad_scalar", N, 1);
|
||||
MatStorageT<float> dx_scalar("dx_scalar", K, N);
|
||||
auto weights = MakePacked<float>("weights", N, 1);
|
||||
auto x = MakePacked<float>("x", K, N);
|
||||
auto grad = MakePacked<float>("grad", N, 1);
|
||||
auto dx = MakePacked<float>("dx", K, N);
|
||||
auto dy = MakePacked<float>("dy", K, N);
|
||||
auto grad_scalar = MakePacked<float>("grad_scalar", N, 1);
|
||||
auto dx_scalar = MakePacked<float>("dx_scalar", K, N);
|
||||
using TC = std::complex<double>;
|
||||
MatStorageT<TC> c_weights("c_weights", N, 1);
|
||||
MatStorageT<TC> c_x("c_x", K, N);
|
||||
MatStorageT<TC> c_y("c_y", K, N);
|
||||
auto c_weights = MakePacked<TC>("c_weights", N, 1);
|
||||
auto c_x = MakePacked<TC>("c_x", K, N);
|
||||
auto c_y = MakePacked<TC>("c_y", K, N);
|
||||
|
||||
for (int iter = 0; iter < 10; ++iter) {
|
||||
RandInit(weights, 1.0f * (1 << iter), gen);
|
||||
|
|
@ -172,19 +184,19 @@ void TestRMSNormVJP() {
|
|||
Complexify(weights, c_weights);
|
||||
Complexify(x, c_x);
|
||||
auto func = [&]() {
|
||||
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
|
||||
return DotT(dy.data(), c_y.data(), K * N);
|
||||
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
|
||||
return DotT(dy.Packed(), c_y.Packed(), K * N);
|
||||
};
|
||||
|
||||
grad.ZeroInit();
|
||||
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
|
||||
dx.data(), pools.Pool());
|
||||
ZeroInit(grad);
|
||||
RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(),
|
||||
dx.Packed(), pool);
|
||||
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
|
||||
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
|
||||
|
||||
grad_scalar.ZeroInit();
|
||||
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
|
||||
dx_scalar.data(), N, K);
|
||||
ZeroInit(grad_scalar);
|
||||
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
|
||||
dx_scalar.Packed(), N, K);
|
||||
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
|
||||
TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
|
||||
}
|
||||
|
|
@ -215,9 +227,7 @@ static ModelConfig TestConfig() {
|
|||
|
||||
void TestEndToEnd() {
|
||||
std::mt19937 gen(42);
|
||||
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
|
||||
Allocator::Init(topology);
|
||||
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
|
||||
hwy::ThreadPool& pool = ThreadHostileGetPool();
|
||||
ModelConfig config = TestConfig();
|
||||
WeightsWrapper<float> weights(config);
|
||||
WeightsWrapper<float> grad(config);
|
||||
|
|
@ -232,7 +242,7 @@ void TestEndToEnd() {
|
|||
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
|
||||
|
||||
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
|
||||
config.layer_configs[0].qkv_dim,
|
||||
ThreadingContext2::Get().allocator, config.layer_configs[0].qkv_dim,
|
||||
config.layer_configs[0].post_qk == PostQKType::HalfRope);
|
||||
for (const Prompt& prompt : batch) {
|
||||
ReverseSequenceSampler::LogPrompt(prompt);
|
||||
|
|
@ -242,13 +252,13 @@ void TestEndToEnd() {
|
|||
|
||||
float loss1 = CrossEntropyLossForwardPass(
|
||||
prompt.tokens, prompt.context_size, weights.get(), forward1,
|
||||
inv_timescale, pools.Pool());
|
||||
inv_timescale, pool);
|
||||
|
||||
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
|
||||
|
||||
grad.ZeroInit();
|
||||
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
|
||||
backward, inv_timescale, pools.Pool());
|
||||
backward, inv_timescale, pool);
|
||||
|
||||
Complexify(weights.get(), c_weights.get());
|
||||
auto func = [&]() {
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@

#include <complex>

#include "compression/compress.h" // MatStorageT
#include "util/mat.h"

namespace gcpp {

@ -60,7 +60,9 @@ void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {

template <typename T>
void MulByConstAndAddT(T c, const MatPtrT<T>& x, MatPtrT<T>& out) {
MulByConstAndAddT(c, x.data(), out.data(), x.NumElements());
for (size_t r = 0; r < x.Rows(); ++r) {
MulByConstAndAddT(c, x.Row(r), out.Row(r), x.Cols());
}
}

template<typename T>

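The MulByConstAndAddT change above switches from one flat NumElements() loop to per-row access, which stays correct once a matrix's stride can exceed its column count. A minimal, self-contained sketch of that pattern follows; ToyMat and its members are hypothetical stand-ins for illustration, not the MatPtrT interface.

#include <cstddef>
#include <vector>

// Row-major matrix whose rows may be padded: stride >= cols.
struct ToyMat {
  size_t rows, cols, stride;
  std::vector<float> data;  // rows * stride elements
  float* Row(size_t r) { return data.data() + r * stride; }
  const float* Row(size_t r) const { return data.data() + r * stride; }
};

// out += c * x, visiting only the first `cols` elements of each row so that
// any padding between rows is never read or written.
void MulByConstAndAdd(float c, const ToyMat& x, ToyMat& out) {
  for (size_t r = 0; r < x.rows; ++r) {
    const float* x_row = x.Row(r);
    float* out_row = out.Row(r);
    for (size_t i = 0; i < x.cols; ++i) out_row[i] += c * x_row[i];
  }
}
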
@ -28,6 +28,7 @@
|
|||
#include "gemma/configs.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
|
|
@ -50,16 +51,17 @@ HWY_BEFORE_NAMESPACE();
|
|||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
template <typename ArrayT>
|
||||
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
|
||||
template <typename T>
|
||||
void InputEmbedding(const MatPtrT<T>& weights, const std::vector<int>& prompt,
|
||||
const float scaling, float* HWY_RESTRICT output,
|
||||
size_t model_dim, size_t vocab_size) {
|
||||
const hn::ScalableTag<float> df;
|
||||
HWY_ASSERT(!prompt.empty());
|
||||
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
|
||||
int token = prompt[pos];
|
||||
DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size),
|
||||
token * model_dim, output + pos * model_dim,
|
||||
const auto span = weights.Span();
|
||||
HWY_ASSERT(span.num == model_dim * vocab_size);
|
||||
DecompressAndZeroPad(df, span, token * model_dim, output + pos * model_dim,
|
||||
model_dim);
|
||||
MulByConst(scaling, output + pos * model_dim, model_dim);
|
||||
}
|
||||
|
|
@ -109,27 +111,27 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
|
|||
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
|
||||
HWY_ASSERT(num_tokens <= kSeqLen);
|
||||
|
||||
ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
|
||||
activations.input.data(), model_dim, num_tokens,
|
||||
activations.pre_att_rms_out.data(), pool);
|
||||
ApplyRMSNorm(weights.pre_attention_norm_scale.Packed(),
|
||||
activations.input.Packed(), model_dim, num_tokens,
|
||||
activations.pre_att_rms_out.Packed(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
|
||||
activations.pre_att_rms_out.data() + pos * model_dim,
|
||||
activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
|
||||
activations.pre_att_rms_out.Packed() + pos * model_dim,
|
||||
activations.qkv.Packed() + pos * (kHeads + 2) * kQKVDim, pool);
|
||||
}
|
||||
const size_t num_tasks = kHeads * num_tokens;
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
float* HWY_RESTRICT k =
|
||||
activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
||||
activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
|
||||
Rope(k, kQKVDim, inv_timescale.Const(), pos);
|
||||
}
|
||||
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
|
||||
const size_t head = task % kHeads;
|
||||
const size_t pos = task / kHeads;
|
||||
float* HWY_RESTRICT q =
|
||||
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
|
||||
activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
|
||||
Rope(q, kQKVDim, inv_timescale.Const(), pos);
|
||||
MulByConst(query_scale, q, kQKVDim);
|
||||
});
|
||||
|
|
@ -138,12 +140,12 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
|
|||
const size_t head = task % kHeads;
|
||||
const size_t pos = task / kHeads;
|
||||
const float* HWY_RESTRICT q =
|
||||
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
|
||||
activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations.att.data() + (pos * kHeads + head) * kSeqLen;
|
||||
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const float* HWY_RESTRICT k2 =
|
||||
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
|
||||
activations.qkv.Packed() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
|
||||
const float score = Dot(q, k2, kQKVDim);
|
||||
head_att[pos2] = score;
|
||||
}
|
||||
|
|
@ -153,7 +155,7 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
|
|||
const size_t head = task % kHeads;
|
||||
const size_t pos = task / kHeads;
|
||||
float* HWY_RESTRICT head_att =
|
||||
activations.att.data() + (pos * kHeads + head) * kSeqLen;
|
||||
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
|
||||
Softmax(head_att, pos + 1);
|
||||
});
|
||||
|
||||
|
|
@ -161,51 +163,51 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
|
|||
const size_t head = task % kHeads;
|
||||
const size_t pos = task / kHeads;
|
||||
const float* HWY_RESTRICT head_att =
|
||||
activations.att.data() + (pos * kHeads + head) * kSeqLen;
|
||||
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
|
||||
float* HWY_RESTRICT att_out =
|
||||
activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
|
||||
activations.att_out.Packed() + (pos * kHeads + head) * kQKVDim;
|
||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
float* HWY_RESTRICT v2 =
|
||||
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
|
||||
float* HWY_RESTRICT v2 = activations.qkv.Packed() +
|
||||
(pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
|
||||
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
|
||||
}
|
||||
});
|
||||
|
||||
activations.attention_out.ZeroInit();
|
||||
ZeroInit(activations.attention_out);
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
for (size_t head = 0; head < kHeads; ++head) {
|
||||
MatVec(
|
||||
weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
|
||||
MatVec(weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
|
||||
kQKVDim,
|
||||
activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
|
||||
activations.att_post1.data() + pos * model_dim, pool);
|
||||
AddFrom(activations.att_post1.data() + pos * model_dim,
|
||||
activations.attention_out.data() + pos * model_dim, model_dim);
|
||||
activations.att_out.Packed() + pos * kHeads * kQKVDim +
|
||||
head * kQKVDim,
|
||||
activations.att_post1.Packed() + pos * model_dim, pool);
|
||||
AddFrom(activations.att_post1.Packed() + pos * model_dim,
|
||||
activations.attention_out.Packed() + pos * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(activations.input.data() + pos * model_dim,
|
||||
activations.attention_out.data() + pos * model_dim, model_dim);
|
||||
AddFrom(activations.input.Packed() + pos * model_dim,
|
||||
activations.attention_out.Packed() + pos * model_dim, model_dim);
|
||||
}
|
||||
|
||||
ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
|
||||
activations.attention_out.data(), model_dim, num_tokens,
|
||||
activations.bf_pre_ffw_rms_out.data(), pool);
|
||||
ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(),
|
||||
activations.attention_out.Packed(), model_dim, num_tokens,
|
||||
activations.bf_pre_ffw_rms_out.Packed(), pool);
|
||||
const size_t kFFHiddenDim = config.ff_hidden_dim;
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
|
||||
activations.bf_pre_ffw_rms_out.data() + pos * model_dim,
|
||||
activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
|
||||
activations.bf_pre_ffw_rms_out.Packed() + pos * model_dim,
|
||||
activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool);
|
||||
}
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
const size_t hidden_offset = pos * kFFHiddenDim * 2;
|
||||
const float* HWY_RESTRICT out =
|
||||
activations.ffw_hidden.data() + hidden_offset;
|
||||
activations.ffw_hidden.Packed() + hidden_offset;
|
||||
const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
|
||||
float* HWY_RESTRICT out_gated =
|
||||
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
|
||||
activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim;
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using DF = hn::ScalableTag<float>;
|
||||
DF df;
|
||||
|
|
@ -217,11 +219,11 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
|
|||
}
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
|
||||
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
|
||||
activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim,
|
||||
output + pos * model_dim, pool);
|
||||
}
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
AddFrom(activations.attention_out.data() + pos * model_dim,
|
||||
AddFrom(activations.attention_out.Packed() + pos * model_dim,
|
||||
output + pos * model_dim, model_dim);
|
||||
}
|
||||
}
|
||||
|
|
@ -247,44 +249,43 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
|
|||
const size_t num_tokens = prompt.size() - 1;
|
||||
|
||||
InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
|
||||
forward.layers[0].input.data(), model_dim, vocab_size);
|
||||
forward.layers[0].input.Packed(), model_dim, vocab_size);
|
||||
|
||||
for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
|
||||
auto type = config.layer_configs[layer].type;
|
||||
// TODO(szabadka) Implement Griffin layer.
|
||||
HWY_ASSERT(type == LayerAttentionType::kGemma);
|
||||
float* HWY_RESTRICT output = layer + 1 < layers
|
||||
? forward.layers[layer + 1].input.data()
|
||||
: forward.final_layer_output.data();
|
||||
? forward.layers[layer + 1].input.Packed()
|
||||
: forward.final_layer_output.Packed();
|
||||
ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
|
||||
num_tokens, output, inv_timescale, pool);
|
||||
}
|
||||
|
||||
ApplyRMSNorm(weights.final_norm_scale.data(),
|
||||
forward.final_layer_output.data(), model_dim, num_tokens,
|
||||
forward.final_norm_output.data(), pool);
|
||||
ApplyRMSNorm(weights.final_norm_scale.Packed(),
|
||||
forward.final_layer_output.Packed(), model_dim, num_tokens,
|
||||
forward.final_norm_output.Packed(), pool);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
|
||||
forward.final_norm_output.data() + pos * model_dim,
|
||||
forward.logits.data() + pos * vocab_size, pool);
|
||||
forward.final_norm_output.Packed() + pos * model_dim,
|
||||
forward.logits.Packed() + pos * vocab_size, pool);
|
||||
}
|
||||
|
||||
if (config.final_cap > 0.0f) {
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size,
|
||||
vocab_size);
|
||||
LogitsSoftCap(config.final_cap,
|
||||
forward.logits.Packed() + pos * vocab_size, vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
|
||||
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
|
||||
CopyMat(forward.logits, forward.probs);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
Softmax(forward.probs.data() + pos * vocab_size, vocab_size);
|
||||
Softmax(forward.probs.Packed() + pos * vocab_size, vocab_size);
|
||||
}
|
||||
|
||||
return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
|
||||
return CrossEntropyLoss(forward.probs.Packed(), prompt, context_size,
|
||||
vocab_size, pool);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -17,9 +17,7 @@
|
|||
|
||||
#include "backprop/activations.h"
|
||||
#include "backprop/prompt.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
|
|
|
|||
|
|
@ -180,54 +180,59 @@ void ApplyLayer(const LayerWeightsPtrs<T>& weights,
|
|||
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
|
||||
static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));
|
||||
|
||||
RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
|
||||
activations.pre_att_rms_out.data(), model_dim, num_tokens);
|
||||
RMSNormT(weights.pre_attention_norm_scale.Packed(),
|
||||
activations.input.Packed(), activations.pre_att_rms_out.Packed(),
|
||||
model_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
|
||||
activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens);
|
||||
MatMulT(weights.qkv_einsum_w.Packed(), activations.pre_att_rms_out.Packed(),
|
||||
activations.qkv.Packed(), (heads + 2) * qkv_dim, model_dim,
|
||||
num_tokens);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
|
||||
T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
|
||||
for (size_t h = 0; h <= heads; ++h) {
|
||||
Rope(qkv + h * qkv_dim, qkv_dim, pos);
|
||||
}
|
||||
}
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
|
||||
T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
|
||||
MulByConstT(query_scale, qkv, heads * qkv_dim);
|
||||
}
|
||||
|
||||
MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens,
|
||||
heads, qkv_dim, seq_len);
|
||||
MaskedAttention(activations.qkv.Packed(), activations.att.Packed(),
|
||||
num_tokens, heads, qkv_dim, seq_len);
|
||||
|
||||
MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len);
|
||||
MaskedSoftmax(activations.att.Packed(), num_tokens, heads, seq_len);
|
||||
|
||||
MixByAttention(activations.qkv.data(), activations.att.data(),
|
||||
activations.att_out.data(), num_tokens, heads, qkv_dim,
|
||||
MixByAttention(activations.qkv.Packed(), activations.att.Packed(),
|
||||
activations.att_out.Packed(), num_tokens, heads, qkv_dim,
|
||||
seq_len);
|
||||
|
||||
MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
|
||||
activations.attention_out.data(), heads, model_dim, qkv_dim,
|
||||
MultiHeadMatMul(weights.attn_vec_einsum_w.Packed(),
|
||||
activations.att_out.Packed(),
|
||||
activations.attention_out.Packed(), heads, model_dim, qkv_dim,
|
||||
num_tokens);
|
||||
|
||||
AddFromT(activations.input.data(), activations.attention_out.data(),
|
||||
AddFromT(activations.input.Packed(), activations.attention_out.Packed(),
|
||||
num_tokens * model_dim);
|
||||
|
||||
RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
|
||||
activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens);
|
||||
RMSNormT(weights.pre_ffw_norm_scale.Packed(),
|
||||
activations.attention_out.Packed(),
|
||||
activations.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
|
||||
activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim,
|
||||
MatMulT(weights.gating_einsum_w.Packed(),
|
||||
activations.bf_pre_ffw_rms_out.Packed(),
|
||||
activations.ffw_hidden.Packed(), ff_hidden_dim * 2, model_dim,
|
||||
num_tokens);
|
||||
|
||||
GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
|
||||
ff_hidden_dim, num_tokens);
|
||||
GatedGelu(activations.ffw_hidden.Packed(),
|
||||
activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output,
|
||||
model_dim, ff_hidden_dim, num_tokens);
|
||||
MatMulT(weights.linear_w.Packed(), activations.ffw_hidden_gated.Packed(),
|
||||
output, model_dim, ff_hidden_dim, num_tokens);
|
||||
|
||||
AddFromT(activations.attention_out.data(), output, num_tokens * model_dim);
|
||||
AddFromT(activations.attention_out.Packed(), output, num_tokens * model_dim);
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
|
|
@ -258,35 +263,35 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
|
|||
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
|
||||
|
||||
const T kEmbScaling = EmbeddingScaling(model_dim);
|
||||
InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling,
|
||||
forward.layers[0].input.data(), model_dim);
|
||||
InputEmbedding(weights.embedder_input_embedding.Packed(), tokens, kEmbScaling,
|
||||
forward.layers[0].input.Packed(), model_dim);
|
||||
|
||||
for (size_t layer = 0; layer < layers; ++layer) {
|
||||
T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data()
|
||||
: forward.final_layer_output.data();
|
||||
T* output = layer + 1 < layers ? forward.layers[layer + 1].input.Packed()
|
||||
: forward.final_layer_output.Packed();
|
||||
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
|
||||
output);
|
||||
}
|
||||
|
||||
RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(),
|
||||
forward.final_norm_output.data(), model_dim, num_tokens);
|
||||
RMSNormT(weights.final_norm_scale.Packed(),
|
||||
forward.final_layer_output.Packed(),
|
||||
forward.final_norm_output.Packed(), model_dim, num_tokens);
|
||||
|
||||
MatMulT(weights.embedder_input_embedding.data(),
|
||||
forward.final_norm_output.data(), forward.logits.data(), vocab_size,
|
||||
model_dim, num_tokens);
|
||||
MatMulT(weights.embedder_input_embedding.Packed(),
|
||||
forward.final_norm_output.Packed(), forward.logits.Packed(),
|
||||
vocab_size, model_dim, num_tokens);
|
||||
|
||||
for (size_t pos = 0; pos < num_tokens; ++pos) {
|
||||
if (config.final_cap > 0.0f) {
|
||||
Softcap(config.final_cap, forward.logits.data() + pos * vocab_size,
|
||||
Softcap(config.final_cap, forward.logits.Packed() + pos * vocab_size,
|
||||
vocab_size);
|
||||
}
|
||||
}
|
||||
|
||||
memcpy(forward.probs.data(), forward.logits.data(),
|
||||
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
|
||||
Softmax(forward.probs.data(), vocab_size, num_tokens);
|
||||
CopyMat(forward.logits, forward.probs);
|
||||
Softmax(forward.probs.Packed(), vocab_size, num_tokens);
|
||||
|
||||
return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size);
|
||||
return CrossEntropyLoss(forward.probs.Packed(), prompt, vocab_size);
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -41,11 +41,14 @@
namespace gcpp {

TEST(OptimizeTest, GradientDescent) {
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
Allocator::Init(topology);
NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
MatMulEnv env(topology, pools);
hwy::ThreadPool& pool = pools.Pool();
gcpp::ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.pin = Tristate::kFalse;
ThreadingContext2::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get());
const Allocator2& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool();
std::mt19937 gen(42);

const ModelInfo info = {

@ -64,7 +67,7 @@ TEST(OptimizeTest, GradientDescent) {
KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);

RowVectorBatch<float> inv_timescale = CreateInvTimescale(
config.layer_configs[0].qkv_dim,
allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);

Gemma gemma(GemmaTokenizer(), info, env);

@ -18,9 +18,9 @@
#include <cmath>

#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

@ -39,11 +39,11 @@ class AdamUpdater {

void operator()(const char* name, const MatPtr& grad, MatPtr& weights,
MatPtr& grad_m, MatPtr& grad_v) {
const float* HWY_RESTRICT g = grad.data<float>();
float* HWY_RESTRICT w = weights.data<float>();
float* HWY_RESTRICT m = grad_m.data<float>();
float* HWY_RESTRICT v = grad_v.data<float>();
for (size_t i = 0; i < grad.NumElements(); ++i) {
const float* HWY_RESTRICT g = grad.RowT<float>(0);
float* HWY_RESTRICT w = weights.RowT<float>(0);
float* HWY_RESTRICT m = grad_m.RowT<float>(0);
float* HWY_RESTRICT v = grad_v.RowT<float>(0);
for (size_t i = 0; i < grad.Extents().Area(); ++i) {
m[i] *= beta1_;
m[i] += cbeta1_ * g[i];
v[i] *= beta2_;

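The AdamUpdater hunk above only shows the start of the per-element moment update. For orientation, here is a hedged sketch of a complete Adam step in the same elementwise style; the bias-correction form and hyperparameter names are the standard algorithm, not code taken from this file.

#include <cmath>
#include <cstddef>

// One Adam step over n parameters: w is updated in place, m/v are the first
// and second moment accumulators, g is the gradient, t the 1-based step count.
void AdamStep(float* w, const float* g, float* m, float* v, size_t n,
              float lr, float beta1, float beta2, float eps, int t) {
  const float cbeta1 = 1.0f - beta1, cbeta2 = 1.0f - beta2;
  const float bias1 = 1.0f - std::pow(beta1, static_cast<float>(t));
  const float bias2 = 1.0f - std::pow(beta2, static_cast<float>(t));
  for (size_t i = 0; i < n; ++i) {
    m[i] = beta1 * m[i] + cbeta1 * g[i];         // first moment
    v[i] = beta2 * v[i] + cbeta2 * g[i] * g[i];  // second moment
    const float mhat = m[i] / bias1;
    const float vhat = v[i] / bias2;
    w[i] -= lr * mhat / (std::sqrt(vhat) + eps);
  }
}
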
@ -24,21 +24,13 @@
#include <vector>

#include "gtest/gtest.h"
#include "compression/compress.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

template <typename T>
void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
std::normal_distribution<T> dist(0.0, stddev);
for (size_t i = 0; i < x.NumElements(); ++i) {
x.At(i) = dist(gen);
}
}

// TODO: make a member of Layer<T>.
template <typename T>
void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {

@ -62,8 +54,12 @@ void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {

template <typename T, typename U>
void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
for (size_t i = 0; i < x.NumElements(); ++i) {
c_x.At(i) = std::complex<U>(x.At(i), 0.0);
for (size_t r = 0; r < x.Rows(); ++r) {
const T* row = x.Row(r);
std::complex<U>* c_row = c_x.Row(r);
for (size_t c = 0; c < x.Cols(); ++c) {
c_row[c] = std::complex<U>(row[c], 0.0);
}
}
}

@ -87,14 +83,14 @@ void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
}
}

// Somewhat duplicates ModelWeightsStorage, but that has neither double nor
// Somewhat duplicates WeightsOwner, but that has neither double nor
// complex types allowed and it would cause code bloat to add them there.
template <typename T>
class WeightsWrapper {
public:
explicit WeightsWrapper(const ModelConfig& config)
: pool_(0), weights_(config) {
weights_.Allocate(data_, pool_);
weights_.Allocate(owners_, pool_);
}

const ModelWeightsPtrs<T>& get() const { return weights_; }

@ -106,7 +102,7 @@ class WeightsWrapper {

private:
hwy::ThreadPool pool_;
std::vector<MatStorage> data_;
std::vector<MatOwner> owners_;
ModelWeightsPtrs<T> weights_;
};

@ -116,13 +112,18 @@ void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
double sum0 = 0;
double sum1 = 0;
double sum01 = 0;
for (size_t i = 0; i < actual.NumElements(); ++i) {
sum0 += actual.At(i) * actual.At(i);
sum1 += expected.At(i) * expected.At(i);
sum01 += actual.At(i) * expected.At(i);
ASSERT_NEAR(actual.At(i), expected.At(i),
std::max(max_abs_err, std::abs(expected.At(i)) * max_rel_err))
<< "line: " << line << " dim=" << expected.NumElements() << " i=" << i;
for (size_t r = 0; r < actual.Rows(); ++r) {
const T* actual_row = actual.Row(r);
const U* expected_row = expected.Row(r);
for (size_t c = 0; c < actual.Cols(); ++c) {
sum0 += actual_row[c] * actual_row[c];
sum1 += expected_row[c] * expected_row[c];
sum01 += actual_row[c] * expected_row[c];
ASSERT_NEAR(
actual_row[c], expected_row[c],
std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err))
<< "line: " << line << " r " << r << " c " << c;
}
}
if (sum0 > 1e-40) {
double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);

@ -148,15 +149,19 @@ void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
template <typename FUNC, typename T, typename U>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
FUNC func, U step, T max_abs_err, T max_rel_err, int line) {
MatStorageT<T> exp_grad("exp_grad", x.Rows(), x.Cols());
MatStorageT<T> exp_grad = MakePacked<T>("exp_grad", x.Rows(), x.Cols());
const U inv_step = 1.0 / step;
for (size_t i = 0; i < x.NumElements(); ++i) {
const U x0 = std::real(x.At(i));
for (size_t r = 0; r < x.Rows(); ++r) {
std::complex<U>* x_row = x.Row(r);
T* exp_row = exp_grad.Row(r);
for (size_t c = 0; c < x.Cols(); ++c) {
const U x0 = std::real(x_row[c]);
const std::complex<U> x1 = std::complex<U>(x0, step);
x.At(i) = x1;
x_row[c] = x1;
const std::complex<U> f1 = func();
exp_grad.At(i) = std::imag(f1) * inv_step;
x.At(i) = x0;
exp_row[c] = std::imag(f1) * inv_step;
x_row[c] = x0;
}
}
TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
}

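TestGradient above relies on the complex-step derivative: for an analytic f, imag(f(x + i*h)) / h approximates df/dx to near machine precision without the subtractive cancellation of finite differences, which is why a tiny step is usable. A toy standalone example of the trick (not part of gemma.cpp):

#include <complex>
#include <cstdio>

int main() {
  const double h = 1e-20;
  const auto f = [](std::complex<double> z) { return z * z * z; };  // f(x) = x^3
  const double x = 2.0;
  // imag(f(x + i*h)) / h ~= f'(x) = 3 * x^2 = 12 at x = 2.
  const double dfdx = std::imag(f(std::complex<double>(x, h))) / h;
  std::printf("f'(%g) ~= %g\n", x, dfdx);
  return 0;
}
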
@ -146,6 +146,7 @@ cc_library(
":distortion",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)

@ -209,6 +210,7 @@ cc_library(
"//:allocator",
"//:basics",
"//:common",
"//:mat",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:stats",

@ -252,22 +254,6 @@ cc_library(
],
)

cc_binary(
name = "compress_weights",
srcs = ["compress_weights.cc"],
deps = [
":compress",
":io",
"//:allocator",
"//:args",
"//:common",
"//:tokenizer",
"//:weights",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

cc_binary(
name = "blob_compare",
srcs = ["blob_compare.cc"],

@ -277,9 +263,11 @@ cc_binary(
"//:allocator",
"//:basics",
"//:threading",
"//:threading_context",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)

@ -287,7 +275,6 @@ cc_binary(
name = "migrate_weights",
srcs = ["migrate_weights.cc"],
deps = [
"//:app",
"//:args",
"//:benchmark_helper",
"//:gemma_lib",

@ -25,8 +25,10 @@
#include "util/allocator.h"
#include "util/basics.h" // IndexRange
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"

namespace gcpp {

@ -202,15 +204,13 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
if (!CompareKeys(reader1, reader2)) return;

// Single allocation, avoid initializing the memory.
BoundedTopology topology;
Allocator::Init(topology);
NestedPools pools(topology);
const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
size_t pos = 0;
BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos);
BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos);

NestedPools& pools = ThreadingContext2::Get().pools;
ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools);

CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);

@ -29,6 +29,7 @@
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "gemma/configs.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

@ -379,7 +380,7 @@ struct CompressTraits<SfpStream> {
using Packed = SfpStream;

// Callers are responsible for scaling `raw` such that its magnitudes do not
// exceed `SfpStream::kMax`. See CompressedArray::scale().
// exceed `SfpStream::kMax`. See CompressedArray::Scale().
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
size_t num, CompressPerThread& tls,

@ -522,8 +523,7 @@ HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num,
CompressWorkingSet& work,
MatStorageT<Packed>& compressed,
hwy::ThreadPool& pool) {
Compress(raw, num, work,
MakeSpan(compressed.data(), compressed.NumElements()),
Compress(raw, num, work, compressed.Span(),
/*packed_ofs=*/0, pool);
}

@ -717,11 +717,9 @@ class Compressor {
template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
const float* HWY_RESTRICT weights) {
size_t num_weights = compressed->NumElements();
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
return;
size_t num_compressed = compressed->NumElements();
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
size_t num_weights = compressed->Extents().Area();
if (num_weights == 0 || weights == nullptr || !compressed->HasPtr()) return;
PackedSpan<Packed> packed = compressed->Span();
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
num_weights / (1000 * 1000));
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,

@ -17,6 +17,6 @@

namespace gcpp {

MatPtr::~MatPtr() {}
// TODO: move ScaleWeights here.

} // namespace gcpp

@ -41,7 +41,8 @@
// IWYU pragma: end_exports
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/per_target.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#if COMPRESS_STATS
#include "compression/distortion.h"
#include "hwy/stats.h"

@ -49,322 +50,6 @@
|
|||
|
||||
namespace gcpp {
|
||||
|
||||
// Base class for rank-1 or 2 tensors (vector or matrix).
|
||||
// Supports both dynamic and compile-time sizing.
|
||||
// Holds metadata and a non-owning pointer to the data, owned by the derived
|
||||
// MatStorageT class.
|
||||
// This class also provides easy conversion from/to a table of contents for a
|
||||
// BlobStore file, and a templated (compile-time) accessor for a 2-d array of
|
||||
// fixed inner dimension and type.
|
||||
// It is designed to be put in a vector, and has default copy and operator=, so
|
||||
// it is easy to read/write a blob_store file.
|
||||
class MatPtr : public IFields {
|
||||
public:
|
||||
// Full constructor for dynamic sizing.
|
||||
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
|
||||
size_t cols)
|
||||
: name_(name),
|
||||
type_(type),
|
||||
element_size_(element_size),
|
||||
num_elements_(rows * cols),
|
||||
rows_(rows),
|
||||
cols_(cols),
|
||||
ptr_(nullptr) {
|
||||
stride_ = cols;
|
||||
}
|
||||
// Default is to leave all fields default-initialized.
|
||||
MatPtr() = default;
|
||||
virtual ~MatPtr();
|
||||
|
||||
// Compatibility interface for CompressedArray.
|
||||
// TODO: remove.
|
||||
template <typename T>
|
||||
T* data() {
|
||||
return HWY_RCAST_ALIGNED(T*, ptr_);
|
||||
}
|
||||
template <typename T>
|
||||
const T* data() const {
|
||||
return HWY_RCAST_ALIGNED(const T*, ptr_);
|
||||
}
|
||||
|
||||
const void* Ptr() const { return ptr_; }
|
||||
void* Ptr() { return ptr_; }
|
||||
// Sets the pointer from another MatPtr.
|
||||
void SetPtr(const MatPtr& other) { ptr_ = other.ptr_; }
|
||||
|
||||
// Copying allowed as the metadata is small.
|
||||
MatPtr(const MatPtr& other) = default;
|
||||
MatPtr& operator=(const MatPtr& other) = default;
|
||||
|
||||
// Returns the name of the blob.
|
||||
const char* Name() const override { return name_.c_str(); }
|
||||
void SetName(const std::string& name) { name_ = name; }
|
||||
|
||||
// Returns the type of the blob.
|
||||
Type GetType() const { return type_; }
|
||||
|
||||
// Returns the size of each element in bytes.
|
||||
size_t ElementSize() const { return element_size_; }
|
||||
|
||||
// Returns the number of elements in the array.
|
||||
size_t NumElements() const { return num_elements_; }
|
||||
|
||||
// Returns the number of bytes in the array.
|
||||
size_t SizeBytes() const {
|
||||
if (this->GetType() == TypeEnum<NuqStream>()) {
|
||||
return NuqStream::PackedEnd(num_elements_);
|
||||
}
|
||||
return num_elements_ * element_size_;
|
||||
}
|
||||
|
||||
// Returns the number of rows in the 2-d array (outer dimension).
|
||||
size_t Rows() const { return rows_; }
|
||||
|
||||
// Returns the number of columns in the 2-d array (inner dimension).
|
||||
size_t Cols() const { return cols_; }
|
||||
|
||||
Extents2D Extents() const { return Extents2D(rows_, cols_); }
|
||||
|
||||
// Currently same as cols, but may differ in the future. This is the offset by
|
||||
// which to advance pointers to the next row.
|
||||
size_t Stride() const { return stride_; }
|
||||
|
||||
// Decoded elements should be multiplied by this to restore their original
|
||||
// range. This is required because SfpStream can only encode a limited range
|
||||
// of magnitudes.
|
||||
float scale() const { return scale_; }
|
||||
void set_scale(float scale) { scale_ = scale; }
|
||||
|
||||
std::string LayerName(int layer) const {
|
||||
std::string name = name_ + std::to_string(layer);
|
||||
HWY_ASSERT(name.size() <= sizeof(hwy::uint128_t));
|
||||
return name;
|
||||
}
|
||||
|
||||
// Sets all data to zero.
|
||||
void ZeroInit() {
|
||||
if (ptr_ == nullptr)
|
||||
HWY_ABORT("ptr_ is null on tensor %s\n", name_.c_str());
|
||||
hwy::ZeroBytes(ptr_, SizeBytes());
|
||||
}
|
||||
|
||||
void VisitFields(IFieldsVisitor& visitor) override {
|
||||
visitor(name_);
|
||||
visitor(type_);
|
||||
visitor(element_size_);
|
||||
visitor(num_elements_);
|
||||
visitor(rows_);
|
||||
visitor(cols_);
|
||||
visitor(scale_);
|
||||
visitor(stride_);
|
||||
}
|
||||
|
||||
// Calls func on the upcasted type. Since MatPtr by design is not templated,
|
||||
// here we provide a way to get to the derived type, provided that `Type()`
|
||||
// is one of the strings returned by `TypeName()`.
|
||||
template <class FuncT, typename... TArgs>
|
||||
decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args);
|
||||
|
||||
protected:
|
||||
// Arbitrary name for the array of preferably <= 16 characters.
|
||||
std::string name_;
|
||||
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
|
||||
Type type_;
|
||||
// sizeof(T)
|
||||
uint32_t element_size_ = 0;
|
||||
// Number of elements in the array.
|
||||
uint32_t num_elements_ = 0; // In element_size units.
|
||||
// Number of rows in the 2-d array (outer dimension).
|
||||
uint32_t rows_ = 0;
|
||||
// Number of columns in the 2-d array (inner dimension).
|
||||
uint32_t cols_ = 0;
|
||||
// Scaling to apply to each element.
|
||||
float scale_ = 1.0f;
|
||||
// Aligned data array. This is always a borrowed pointer. It should never be
|
||||
// freed. The underlying memory is owned by a subclass or some external class
|
||||
// and must outlive this object.
|
||||
void* ptr_ = nullptr;
|
||||
|
||||
uint32_t stride_;
|
||||
};
|
||||
|
||||
// MatPtrT adds a single template argument to MatPtr for an explicit type.
|
||||
// Use this class as a function argument where the type needs to be known.
|
||||
// Use MatPtr where the type does not need to be known.
|
||||
template <typename MatT>
|
||||
class MatPtrT : public MatPtr {
|
||||
public:
|
||||
// Full constructor for dynamic sizing.
|
||||
MatPtrT(const std::string& name, size_t rows, size_t cols)
|
||||
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
|
||||
// Construction from TensorIndex entry to remove duplication of sizes.
|
||||
MatPtrT(const std::string& name, const TensorIndex& tensor_index)
|
||||
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
|
||||
MatPtrT(const std::string& name, const TensorInfo* tensor)
|
||||
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
|
||||
if (tensor == nullptr) {
|
||||
cols_ = 0;
|
||||
rows_ = 0;
|
||||
} else {
|
||||
cols_ = tensor->shape.back();
|
||||
rows_ = 1;
|
||||
if (tensor->cols_take_extra_dims) {
|
||||
// The columns eat the extra dimensions.
|
||||
rows_ = tensor->shape[0];
|
||||
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
|
||||
cols_ *= tensor->shape[i];
|
||||
}
|
||||
} else {
|
||||
// The rows eat the extra dimensions.
|
||||
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
|
||||
rows_ *= tensor->shape[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
stride_ = cols_;
|
||||
num_elements_ = rows_ * cols_;
|
||||
}
|
||||
|
||||
// Copying allowed as the metadata is small.
|
||||
MatPtrT(const MatPtr& other) : MatPtr(other) {}
|
||||
MatPtrT& operator=(const MatPtr& other) {
|
||||
MatPtr::operator=(other);
|
||||
return *this;
|
||||
}
|
||||
MatPtrT(const MatPtrT& other) = default;
|
||||
MatPtrT& operator=(const MatPtrT& other) = default;
|
||||
|
||||
std::string CacheName(int layer = -1, char separator = ' ',
|
||||
int index = -1) const {
|
||||
// Already used/retired: s, S, n, 1
|
||||
const char prefix = hwy::IsSame<MatT, float>() ? 'F'
|
||||
: hwy::IsSame<MatT, BF16>() ? 'B'
|
||||
: hwy::IsSame<MatT, SfpStream>() ? '$'
|
||||
: hwy::IsSame<MatT, NuqStream>() ? '2'
|
||||
: '?';
|
||||
std::string name = std::string(1, prefix) + name_;
|
||||
if (layer >= 0 || index >= 0) {
|
||||
name += '_';
|
||||
if (layer >= 0) name += std::to_string(layer);
|
||||
if (index >= 0) {
|
||||
name += separator + std::to_string(index);
|
||||
}
|
||||
}
|
||||
return name;
|
||||
}
|
||||
|
||||
// Sets the number of elements in the array. For use when the number of
|
||||
// elements is != rows * cols ONLY.
|
||||
void SetNumElements(size_t num_elements) {
|
||||
num_elements_ = CompressedArrayElements<MatT>(num_elements);
|
||||
}
|
||||
|
||||
// 2-d Accessor for a specific type but with a dynamic inner dimension.
|
||||
template <typename T = MatT>
|
||||
const T& At(size_t row, size_t col) const {
|
||||
size_t index = row * cols_ + col;
|
||||
HWY_DASSERT(index < num_elements_);
|
||||
return HWY_RCAST_ALIGNED(const T*, ptr_)[index];
|
||||
}
|
||||
|
||||
// 1-d Accessor for a specific type.
|
||||
// TODO: replace this with a Foreach(), or at least a ForEachRow().
|
||||
const MatT& At(size_t index) const {
|
||||
HWY_DASSERT(index < num_elements_);
|
||||
return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index];
|
||||
}
|
||||
MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; }
|
||||
|
||||
// Compatibility interface for CompressedArray.
|
||||
// TODO: remove
|
||||
template <typename T = MatT>
|
||||
T* data() {
|
||||
return HWY_RCAST_ALIGNED(T*, ptr_);
|
||||
}
|
||||
template <typename T = MatT>
|
||||
const T* data() const {
|
||||
return HWY_RCAST_ALIGNED(const T*, ptr_);
|
||||
}
|
||||
// The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
|
||||
// calling it means "I am sure the scale is 1 and therefore ignore the scale".
|
||||
// A scale of 0 indicates that the scale has likely never been set, so is
|
||||
// "implicitly 1".
|
||||
const MatT* data_scale1() const {
|
||||
HWY_ASSERT(scale() == 1.f);
|
||||
return HWY_RCAST_ALIGNED(const MatT*, ptr_);
|
||||
}
|
||||
};
|
||||
|
||||
template <class FuncT, typename... TArgs>
|
||||
decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
|
||||
if (type_ == TypeEnum<float>()) {
|
||||
return func(dynamic_cast<MatPtrT<float>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeEnum<BF16>()) {
|
||||
return func(dynamic_cast<MatPtrT<BF16>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeEnum<SfpStream>()) {
|
||||
return func(dynamic_cast<MatPtrT<SfpStream>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else if (type_ == TypeEnum<NuqStream>()) {
|
||||
return func(dynamic_cast<MatPtrT<NuqStream>*>(this),
|
||||
std::forward<TArgs>(args)...);
|
||||
} else {
|
||||
HWY_ABORT("Type %d unknown.", type_);
|
||||
}
|
||||
}
|
||||
|
||||
// MatStorageT adds the actual data storage to MatPtrT.
|
||||
// TODO: use Extents2D instead of rows and cols.
|
||||
template <typename MatT>
|
||||
class MatStorageT : public MatPtrT<MatT> {
|
||||
public:
|
||||
// Full constructor for dynamic sizing.
|
||||
MatStorageT(const std::string& name, size_t rows, size_t cols)
|
||||
: MatPtrT<MatT>(name, rows, cols) {
|
||||
Allocate();
|
||||
}
|
||||
// Can copy the metadata, from a MatPtr, and allocate later.
|
||||
MatStorageT(const MatPtr& other) : MatPtrT<MatT>(other) {}
|
||||
~MatStorageT() = default;
|
||||
|
||||
// Move-only because this contains a unique_ptr.
|
||||
MatStorageT(const MatStorageT& other) = delete;
|
||||
MatStorageT& operator=(const MatStorageT& other) = delete;
|
||||
MatStorageT(MatStorageT&& other) = default;
|
||||
MatStorageT& operator=(MatStorageT&& other) = default;
|
||||
|
||||
// Allocate the memory and copy the pointer to the MatPtr.
|
||||
// num_elements is in elements. In the default (zero) case, it is computed
|
||||
// from the current num_elements_ which was set by the constructor from the
|
||||
// rows and cols.
|
||||
void Allocate(size_t num_elements = 0) {
|
||||
if (num_elements == 0) {
|
||||
num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT));
|
||||
} else {
|
||||
this->num_elements_ = num_elements;
|
||||
}
|
||||
// Pad to allow overrunning the last row by 2 BF16 vectors, hence at most
|
||||
// `2 * VectorBytes / sizeof(BF16)` elements of MatT.
|
||||
const size_t padding = hwy::VectorBytes();
|
||||
data_ = Allocator::Alloc<MatT>(num_elements + padding);
|
||||
hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT));
|
||||
this->ptr_ = data_.get();
|
||||
}
|
||||
|
||||
// Zeros the content.
|
||||
void ZeroInit() {
|
||||
HWY_ASSERT(data_ != nullptr);
|
||||
hwy::ZeroBytes(data_.get(), this->SizeBytes());
|
||||
}
|
||||
|
||||
private:
|
||||
AlignedPtr<MatT> data_;
|
||||
};
|
||||
|
||||
// MatStorage allows heterogeneous tensors to be stored in a single vector.
|
||||
using MatStorage = MatStorageT<hwy::uint128_t>;
|
||||
|
||||
// Table of contents for a blob store file. Full metadata, but not actual data.
|
||||
class BlobToc {
|
||||
public:
|
||||
|
|
@ -389,7 +74,7 @@ class BlobToc {
|
|||
blob.Read(hwy::Span<const uint32_t>(toc), consumed);
|
||||
prev_consumed = consumed;
|
||||
consumed = result.pos;
|
||||
if (blob.NumElements() > 0) {
|
||||
if (!blob.IsEmpty()) {
|
||||
AddToToc(blob);
|
||||
}
|
||||
}
|
||||
|
|
@ -503,10 +188,11 @@ class WriteToBlobStore {
|
|||
explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
|
||||
|
||||
template <typename Packed>
|
||||
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name) {
|
||||
if (compressed->Ptr() == nullptr) return;
|
||||
writer_.Add(MakeKey(decorated_name), compressed->Ptr(),
|
||||
compressed->SizeBytes());
|
||||
void operator()(MatPtrT<Packed>* compressed,
|
||||
const char* decorated_name) const {
|
||||
if (!compressed->HasPtr()) return;
|
||||
writer_.Add(MakeKey(decorated_name), compressed->Packed(),
|
||||
compressed->PackedBytes());
|
||||
MatPtr renamed_tensor(*compressed);
|
||||
renamed_tensor.SetName(decorated_name);
|
||||
renamed_tensor.AppendTo(toc_);
|
||||
|
|
@ -519,9 +205,8 @@ class WriteToBlobStore {
|
|||
|
||||
void AddScales(const float* scales, size_t len) {
|
||||
if (len) {
|
||||
MatPtrT<float> scales_ptr("scales", 0, 1);
|
||||
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
|
||||
len * sizeof(scales[0]));
|
||||
MatPtrT<float> scales_ptr("scales", Extents2D(0, 1));
|
||||
writer_.Add(MakeKey(scales_ptr.Name()), scales, len * sizeof(scales[0]));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -554,9 +239,9 @@ class WriteToBlobStore {
|
|||
hwy::ThreadPool& pool_;
|
||||
|
||||
private:
|
||||
std::vector<uint32_t> toc_;
|
||||
BlobWriter writer_;
|
||||
std::vector<uint32_t> config_buffer_;
|
||||
mutable std::vector<uint32_t> toc_;
|
||||
mutable BlobWriter writer_;
|
||||
mutable std::vector<uint32_t> config_buffer_;
|
||||
};
|
||||
|
||||
// Functor called for each tensor, which loads them and their scaling factors
|
||||
|
|
@ -613,6 +298,7 @@ class ReadFromBlobStore {
|
|||
// Called for each tensor, enqueues read requests.
|
||||
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
if (file_toc_.Empty() || file_toc_.Contains(name)) {
|
||||
HWY_ASSERT(tensors[0]);
|
||||
model_toc_.push_back(tensors[0]);
|
||||
file_keys_.push_back(name);
|
||||
}
|
||||
|
|
@ -622,15 +308,15 @@ class ReadFromBlobStore {
|
|||
for (size_t i = 0; i < len; ++i) {
|
||||
scales[i] = 1.0f;
|
||||
}
|
||||
MatPtrT<float> scales_ptr("scales", 0, 1);
|
||||
auto key = MakeKey(scales_ptr.CacheName().c_str());
|
||||
MatPtrT<float> scales_ptr("scales", Extents2D(0, 1));
|
||||
auto key = MakeKey(scales_ptr.Name());
|
||||
if (reader_.BlobSize(key) == 0) return 0;
|
||||
return reader_.Enqueue(key, scales, len * sizeof(scales[0]));
|
||||
}
|
||||
|
||||
// Returns whether all tensors are successfully loaded from cache.
|
||||
BlobError ReadAll(hwy::ThreadPool& pool,
|
||||
std::vector<MatStorage>& model_memory) {
|
||||
std::vector<MatOwner>& model_memory) {
|
||||
// reader_ invalid or any Enqueue failed
|
||||
if (err_ != 0) return err_;
|
||||
// Setup the model_memory.
|
||||
|
|
@ -650,26 +336,27 @@ class ReadFromBlobStore {
|
|||
}
|
||||
std::string name = blob->Name();
|
||||
*blob = *toc_blob;
|
||||
blob->SetName(name);
|
||||
blob->SetName(name.c_str());
|
||||
}
|
||||
model_memory.emplace_back(*blob);
|
||||
model_memory.back().SetName(file_key);
|
||||
model_memory.push_back(MatOwner());
|
||||
}
|
||||
// Allocate in parallel using the pool.
|
||||
pool.Run(0, model_memory.size(),
|
||||
[this, &model_memory](uint64_t task, size_t /*thread*/) {
|
||||
model_memory[task].Allocate();
|
||||
model_toc_[task]->SetPtr(model_memory[task]);
|
||||
model_memory[task].AllocateFor(*model_toc_[task],
|
||||
MatPadding::kPacked);
|
||||
});
|
||||
// Enqueue the read requests.
|
||||
for (auto& blob : model_memory) {
|
||||
err_ =
|
||||
reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes());
|
||||
for (size_t b = 0; b < model_toc_.size(); ++b) {
|
||||
err_ = reader_.Enqueue(MakeKey(file_keys_[b].c_str()),
|
||||
model_toc_[b]->RowT<uint8_t>(0),
|
||||
model_toc_[b]->PackedBytes());
|
||||
if (err_ != 0) {
|
||||
fprintf(stderr,
|
||||
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
|
||||
blob.Name(), err_, blob.Rows(), blob.Cols(),
|
||||
blob.ElementSize());
|
||||
fprintf(
|
||||
stderr,
|
||||
"Failed to read blob %s (error %d) of size %zu x %zu, type %d\n",
|
||||
file_keys_[b].c_str(), err_, model_toc_[b]->Rows(),
|
||||
model_toc_[b]->Cols(), static_cast<int>(model_toc_[b]->GetType()));
|
||||
return err_;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,286 +0,0 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
// Command line tool to create compressed weights.
|
||||
|
||||
// Compiles this file for multiple architectures via "foreach_target.h", to
|
||||
// which we pass the filename via macro 'argument'.
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
"compression/compress_weights.cc" // NOLINT
|
||||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/tokenizer.h"
|
||||
|
||||
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
#define GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::clamp
|
||||
#include <cstdlib>
|
||||
#include <iostream>
|
||||
#include <string>
|
||||
#include <thread> // NOLINT
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "compression/shared.h" // PromptWrapping
|
||||
#include "gemma/common.h" // Model
|
||||
#include "gemma/weights.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/args.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
namespace {
|
||||
|
||||
} // namespace
|
||||
|
||||
struct Args : public ArgsBase<Args> {
|
||||
static constexpr size_t kDefaultNumThreads = ~size_t{0};
|
||||
|
||||
void ChooseNumThreads() {
|
||||
if (num_threads == kDefaultNumThreads) {
|
||||
// This is a rough heuristic, replace with something better in the future.
|
||||
num_threads = static_cast<size_t>(std::clamp(
|
||||
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
|
||||
}
|
||||
}
|
||||
|
||||
public:
|
||||
Args(int argc, char* argv[]) {
|
||||
InitAndParse(argc, argv);
|
||||
ChooseNumThreads();
|
||||
}
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() {
|
||||
if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
|
||||
prompt_wrapping_)) {
|
||||
return err;
|
||||
}
|
||||
if (const char* err = ParseType(weight_type_str, weight_type_)) {
|
||||
return err;
|
||||
}
|
||||
if (weights.path.empty()) {
|
||||
return "Missing --weights flag, a file for the uncompressed model.";
|
||||
}
|
||||
if (compressed_weights.path.empty()) {
|
||||
return "Missing --compressed_weights flag, a file for the compressed "
|
||||
"model.";
|
||||
}
|
||||
if (!weights.Exists()) {
|
||||
return "Can't open file specified with --weights flag.";
|
||||
}
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
Path weights; // uncompressed weights file location
|
||||
Path compressed_weights; // compressed weights file location
|
||||
std::string model_type_str;
|
||||
std::string weight_type_str;
|
||||
size_t num_threads;
|
||||
// If non-empty, whether to include the config and TOC in the output file, as
|
||||
// well as the tokenizer.
|
||||
Path tokenizer;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(weights, "weights", Path(),
|
||||
"Path to model weights (.bin) file.\n"
|
||||
" Required argument.");
|
||||
visitor(model_type_str, "model", std::string(),
|
||||
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
|
||||
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
|
||||
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
|
||||
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
|
||||
"gr2b-pt = griffin 2B parameters, pretrained\n "
|
||||
" Required argument.");
|
||||
visitor(weight_type_str, "weight_type", std::string("sfp"),
|
||||
"Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
|
||||
" Required argument.");
|
||||
visitor(compressed_weights, "compressed_weights", Path(),
|
||||
"Path name where compressed weights (.sbs) file will be written.\n"
|
||||
" Required argument.");
|
||||
visitor(num_threads, "num_threads",
|
||||
kDefaultNumThreads, // see ChooseNumThreads
|
||||
"Number of threads to use.\n Default = Estimate of the "
|
||||
"number of supported concurrent threads.",
|
||||
2);
|
||||
visitor(tokenizer, "tokenizer", Path(),
|
||||
"Path to tokenizer file. If given, the config and TOC are also "
|
||||
"added to the output file.");
|
||||
}
|
||||
|
||||
// Uninitialized before Validate, must call after that.
|
||||
gcpp::Model ModelType() const { return model_type_; }
|
||||
gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
|
||||
gcpp::Type WeightType() const { return weight_type_; }
|
||||
|
||||
private:
|
||||
Model model_type_;
|
||||
PromptWrapping prompt_wrapping_;
|
||||
Type weight_type_;
|
||||
};
|
||||
|
||||
void ShowHelp(gcpp::Args& args) {
|
||||
std::cerr
|
||||
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
|
||||
" --model <model type> --compressed_weights <output path>\n";
|
||||
std::cerr << "\n*Arguments*\n\n";
|
||||
args.Help();
|
||||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
#endif // GEMMA_COMPRESS_WEIGHTS_ONCE
|
||||
|
||||
// SIMD code, compiled once per target.
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
template <typename T>
|
||||
void CompressWeights(const Path& weights_path,
|
||||
const Path& compressed_weights_path, Model model_type,
|
||||
Type weight_type, PromptWrapping wrapping,
|
||||
const Path& tokenizer_path, hwy::ThreadPool& pool) {
|
||||
if (!weights_path.Exists()) {
|
||||
HWY_ABORT("The model weights file '%s' does not exist.",
|
||||
weights_path.path.c_str());
|
||||
}
|
||||
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
|
||||
compressed_weights_path.path.c_str());
|
||||
ModelConfig config = ConfigFromModel(model_type);
|
||||
config.weight = weight_type;
|
||||
config.wrapping = wrapping;
|
||||
std::vector<MatStorage> model_storage;
|
||||
ModelWeightsPtrs<T> c_weights(config);
|
||||
c_weights.Allocate(model_storage, pool);
|
||||
ModelWeightsPtrs<float> uc_weights(config);
|
||||
uc_weights.Allocate(model_storage, pool);
|
||||
// Get uncompressed weights, compress, and store.
|
||||
FILE* fptr = fopen(weights_path.path.c_str(), "rb");
|
||||
if (fptr == nullptr) {
|
||||
HWY_ABORT("Failed to open model file %s - does it exist?",
|
||||
weights_path.path.c_str());
|
||||
}
|
||||
bool ok = true;
|
||||
uint64_t total_size = 0;
|
||||
ModelWeightsPtrs<float>::ForEachTensor(
|
||||
{&uc_weights}, ForEachType::kLoadNoToc,
|
||||
[&](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
fprintf(stderr, "Loading Parameters (size %zu): %s\n",
|
||||
tensors[0]->SizeBytes(), name);
|
||||
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
|
||||
total_size += tensors[0]->SizeBytes();
|
||||
});
|
||||
if (!tokenizer_path.path.empty()) {
|
||||
uc_weights.AllocAndCopyWithTranspose(pool, model_storage);
|
||||
}
|
||||
const bool scale_for_compression = config.num_tensor_scales > 0;
|
||||
std::vector<float> scales;
|
||||
if (scale_for_compression) {
|
||||
uc_weights.GetOrApplyScales(scales);
|
||||
}
|
||||
Compressor compressor(pool);
|
||||
ModelWeightsPtrs<T>::ForEachTensor(
|
||||
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
|
||||
tokenizer_path.path.empty() ? ForEachType::kLoadNoToc
|
||||
: ForEachType::kLoadWithToc,
|
||||
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
tensors[1]->CallUpcasted(
|
||||
compressor, name,
|
||||
reinterpret_cast<const float*>(tensors[0]->Ptr()));
|
||||
});
|
||||
if (!tokenizer_path.path.empty()) {
|
||||
std::string tokenizer_proto = ReadFileToString(tokenizer_path);
|
||||
compressor.AddTokenizer(tokenizer_proto);
|
||||
} else {
|
||||
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
|
||||
}
|
||||
compressor.WriteAll(compressed_weights_path,
|
||||
tokenizer_path.path.empty() ? nullptr : &config);
|
||||
}
|
||||
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
HWY_AFTER_NAMESPACE();
|
||||
|
||||
#if HWY_ONCE
|
||||
namespace gcpp {
|
||||
|
||||
void Run(Args& args) {
|
||||
hwy::ThreadPool pool(args.num_threads);
|
||||
if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
|
||||
HWY_ABORT("PaliGemma is not supported in compress_weights.");
|
||||
}
|
||||
const Model model_type = args.ModelType();
|
||||
const Type weight_type = args.WeightType();
|
||||
switch (weight_type) {
|
||||
case Type::kF32:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kBF16:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kSFP:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
case Type::kNUQ:
|
||||
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
|
||||
(args.weights, args.compressed_weights, model_type, weight_type,
|
||||
args.PromptWrappingType(), args.tokenizer, pool);
|
||||
break;
|
||||
default:
|
||||
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
gcpp::Args args(argc, argv);
|
||||
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
gcpp::ShowHelp(args);
|
||||
return 0;
|
||||
}
|
||||
|
||||
if (const char* error = args.Validate()) {
|
||||
gcpp::ShowHelp(args);
|
||||
HWY_ABORT("\nInvalid args: %s", error);
|
||||
}
|
||||
|
||||
gcpp::Run(args);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
#endif // HWY_ONCE
|
||||
|
|
@ -57,6 +57,6 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
gcpp::GemmaEnv env(argc, argv);
|
||||
hwy::ThreadPool pool(0);
|
||||
env.GetModel()->Save(args.output_weights, pool);
|
||||
env.GetGemma()->Save(args.output_weights, pool);
|
||||
return 0;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,7 +16,9 @@ cc_library(
|
|||
deps = [
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"//:common",
|
||||
"//:mat",
|
||||
"//:tokenizer",
|
||||
"//:weights",
|
||||
"//compression:compress",
|
||||
"//compression:io",
|
||||
"@highway//:hwy",
|
||||
|
|
@ -30,7 +32,6 @@ pybind_extension(
|
|||
deps = [
|
||||
":compression_clif_aux",
|
||||
"@abseil-cpp//absl/types:span",
|
||||
"//:common",
|
||||
"//compression:sfp",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -22,7 +22,8 @@
|
|||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "gemma/weights.h"
|
||||
#include "util/mat.h"
|
||||
|
||||
#undef HWY_TARGET_INCLUDE
|
||||
#define HWY_TARGET_INCLUDE \
|
||||
|
|
@ -81,30 +82,23 @@ class SbsWriterImpl : public WriterInterface {
|
|||
template <typename Packed>
|
||||
void AllocateAndCompress(const std::string& name,
|
||||
absl::Span<const float> weights) {
|
||||
MatPtrT<Packed> storage(name, 1, weights.size());
|
||||
model_memory_.push_back(storage);
|
||||
model_memory_.back().Allocate();
|
||||
storage.SetPtr(model_memory_.back());
|
||||
std::string decorated_name = storage.CacheName();
|
||||
MatPtrT<Packed> storage(name.c_str(), Extents2D(1, weights.size()));
|
||||
model_memory_.push_back(MatOwner());
|
||||
model_memory_.back().AllocateFor(storage, MatPadding::kPacked);
|
||||
std::string decorated_name = CacheName(storage);
|
||||
compressor_(&storage, decorated_name.c_str(), weights.data());
|
||||
}
|
||||
template <typename Packed>
|
||||
void AllocateWithShape(const std::string& name,
|
||||
absl::Span<const float> weights,
|
||||
const TensorInfo& tensor_info, float scale) {
|
||||
MatPtrT<Packed> storage(name, &tensor_info);
|
||||
storage.set_scale(scale);
|
||||
MatPtrT<Packed> storage(name.c_str(), &tensor_info);
|
||||
storage.SetScale(scale);
|
||||
|
||||
// Don't reset num_elements for NUQ.
|
||||
if (!hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
|
||||
storage.SetNumElements(CompressedArrayElements<Packed>(weights.size()));
|
||||
}
|
||||
|
||||
model_memory_.push_back(storage);
|
||||
model_memory_.push_back(MatOwner());
|
||||
if (mode_ == CompressorMode::kTEST_ONLY) return;
|
||||
model_memory_.back().Allocate();
|
||||
storage.SetPtr(model_memory_.back());
|
||||
std::string decorated_name = storage.CacheName();
|
||||
model_memory_.back().AllocateFor(storage, MatPadding::kPacked);
|
||||
std::string decorated_name = CacheName(storage);
|
||||
compressor_(&storage, decorated_name.c_str(), weights.data());
|
||||
}
|
||||
|
||||
|
|
@ -176,7 +170,7 @@ class SbsWriterImpl : public WriterInterface {
|
|||
hwy::ThreadPool pool_;
|
||||
Compressor compressor_;
|
||||
CompressWorkingSet working_set_;
|
||||
std::vector<MatStorage> model_memory_;
|
||||
std::vector<MatOwner> model_memory_;
|
||||
std::vector<float> scales_;
|
||||
CompressorMode mode_;
|
||||
};
|
||||
|
|
|
|||
|
|
@ -201,15 +201,24 @@ inline bool EnumValid(PromptWrapping type) {
|
|||
|
||||
// Tensor types for loading weights. Note that not all types are supported as
|
||||
// weights for a model, but can be used for other purposes, such as types for
|
||||
// ModelWeightsPtrs. When adding a new type that is supported, also
|
||||
// `WeightsPtrs`. When adding a new type that is supported, also
|
||||
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
|
||||
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
|
||||
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
|
||||
"nuq", "f64", "c64", "u128"};
|
||||
static constexpr const char* kTypeStrings[] = {
|
||||
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"};
|
||||
static constexpr size_t kNumTypes =
|
||||
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
|
||||
static constexpr size_t kTypeBits[] = {0,
|
||||
8 * sizeof(float),
|
||||
8 * sizeof(BF16),
|
||||
8 * sizeof(SfpStream),
|
||||
4 /* NuqStream, actually 4.5 */,
|
||||
8 * sizeof(double),
|
||||
8 * sizeof(std::complex<double>),
|
||||
8 * sizeof(hwy::uint128_t)};
|
||||
|
||||
inline bool EnumValid(Type type) {
|
||||
return static_cast<int>(type) >= 0 &&
|
||||
static_cast<int>(type) <= static_cast<int>(Type::kU128);
|
||||
static inline bool EnumValid(Type type) {
|
||||
return static_cast<size_t>(type) < kNumTypes;
|
||||
}
|
||||
|
||||
// Returns a Type enum for the type of the template parameter.
|
||||
|
|
@ -236,10 +245,16 @@ Type TypeEnum() {
|
|||
}
|
||||
}
|
||||
|
||||
// Returns a string name for the type of the template parameter.
|
||||
static inline size_t TypeBits(Type type) {
|
||||
return kTypeBits[static_cast<int>(type)];
|
||||
}
|
||||
|
||||
static inline const char* TypeName(Type type) {
|
||||
return kTypeStrings[static_cast<int>(type)];
|
||||
}
|
||||
template <typename PackedT>
|
||||
const char* TypeName() {
|
||||
return kTypeStrings[static_cast<int>(TypeEnum<PackedT>())];
|
||||
return TypeName(TypeEnum<PackedT>());
|
||||
}
|
||||
|
||||
template <typename Packed>
|
||||
|
|
@ -248,7 +263,9 @@ constexpr bool IsCompressed() {
|
|||
}
|
||||
|
||||
// Returns the number of `MatT` elements required to store `capacity` values,
|
||||
// which must not be zero.
|
||||
// which must not be zero. This is only intended to support the extra tables
|
||||
// required for NUQ. `capacity` includes any padding and is `rows * stride`.
|
||||
// Deprecated, replaced by fixup within `MatPtr`. Only used by tests.
|
||||
template <typename Packed>
|
||||
constexpr size_t CompressedArrayElements(size_t capacity) {
|
||||
if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
|
||||
|
|
|
|||
|
|
@ -18,10 +18,13 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/compress.h"
|
||||
#include "compression/distortion.h"
|
||||
#include "util/mat.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
|
||||
|
||||
// Include guard for (potentially) SIMD code.
|
||||
|
|
@ -62,6 +65,52 @@ void ForeachPackedAndRawType() {
|
|||
ForeachRawType<NuqStream, TestT>();
|
||||
}
|
||||
|
||||
// Generates inputs: deterministic, within max SfpStream range.
|
||||
template <typename MatT>
|
||||
MatStorageT<MatT> GenerateMat(const Extents2D& extents, hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("mat", extents, MatPadding::kPacked);
|
||||
const float scale = SfpStream::kMax / extents.Area();
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
float* HWY_RESTRICT row = raw.Row(r);
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(r * extents.cols + c) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
row[c] = f;
|
||||
}
|
||||
});
|
||||
|
||||
Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
|
||||
return compressed;
|
||||
}
|
||||
|
||||
// `extents` describes the transposed matrix.
|
||||
template <typename MatT>
|
||||
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
|
||||
MatStorageT<MatT> compressed("trans", extents, MatPadding::kPacked);
|
||||
const float scale = SfpStream::kMax / extents.Area();
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
float* HWY_RESTRICT row = raw.Row(r);
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(c * extents.rows + r) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
row[c] = f;
|
||||
}
|
||||
});
|
||||
|
||||
Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
|
||||
/*packed_ofs=*/0, pool);
|
||||
// Arbitrary value, different from 1, must match `GenerateMat`.
|
||||
compressed.SetScale(0.6f);
|
||||
return compressed;
|
||||
}
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
} // namespace HWY_NAMESPACE
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -128,10 +128,10 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
|
|||
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
|
||||
std::vector<int> prompt_slice(prompt.begin() + pos,
|
||||
prompt.begin() + pos + num_tokens);
|
||||
KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(),
|
||||
KVCache kv_cache = KVCache::Create(env.GetGemma()->GetModelConfig(),
|
||||
env.MutableConfig().prefill_tbatch_size);
|
||||
float entropy = ComputeCrossEntropy(
|
||||
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
|
||||
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
|
||||
total_entropy += entropy;
|
||||
LogSpeedStats(time_start, pos + num_tokens);
|
||||
std::string text_slice = env.StringFromTokens(prompt_slice);
|
||||
|
|
@ -186,8 +186,8 @@ int main(int argc, char** argv) {
|
|||
if (!benchmark_args.goldens.Empty()) {
|
||||
const std::string golden_path =
|
||||
benchmark_args.goldens.path + "/" +
|
||||
gcpp::ModelString(env.GetModel()->Info().model,
|
||||
env.GetModel()->Info().wrapping) +
|
||||
gcpp::ModelString(env.GetGemma()->Info().model,
|
||||
env.GetGemma()->Info().wrapping) +
|
||||
".txt";
|
||||
return BenchmarkGoldens(env, golden_path);
|
||||
} else if (!benchmark_args.summarize_text.Empty()) {
|
||||
|
|
|
|||
|
|
@ -18,27 +18,20 @@
|
|||
#include <stdio.h>
|
||||
#include <time.h>
|
||||
|
||||
#include <cstdio>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
#include <ostream>
|
||||
#include <random>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "compression/compress.h" // TypeName
|
||||
#include "compression/shared.h" // TypeName
|
||||
#include "evals/cross_entropy.h"
|
||||
#include "gemma/common.h" // StringFromType
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/kv_cache.h"
|
||||
#include "util/app.h"
|
||||
#include "util/args.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/topology.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/highway.h"
|
||||
#include "hwy/per_target.h" // VectorBytes
|
||||
#include "hwy/per_target.h" // DispatchedTarget
|
||||
#include "hwy/timer.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -54,11 +47,9 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
|
|||
}
|
||||
}
|
||||
|
||||
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
const AppArgs& app)
|
||||
: topology_(CreateTopology(app)),
|
||||
pools_(CreatePools(topology_, app)),
|
||||
env_(topology_, pools_) {
|
||||
GemmaEnv::GemmaEnv(const ThreadingArgs& threading_args,
|
||||
const LoaderArgs& loader, const InferenceArgs& inference)
|
||||
: env_(MakeMatMulEnv(threading_args)) {
|
||||
InferenceArgs mutable_inference = inference;
|
||||
AbortIfInvalidArgs(mutable_inference);
|
||||
LoaderArgs mutable_loader = loader;
|
||||
|
|
@ -67,10 +58,10 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
|||
fprintf(stderr, "Skipping model load because: %s\n", err);
|
||||
} else {
|
||||
fprintf(stderr, "Loading model...\n");
|
||||
model_ = AllocateGemma(mutable_loader, env_);
|
||||
gemma_ = AllocateGemma(mutable_loader, env_);
|
||||
// Only allocate one for starters because GenerateBatch might not be called.
|
||||
kv_caches_.resize(1);
|
||||
kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
|
||||
kv_caches_[0] = KVCache::Create(gemma_->GetModelConfig(),
|
||||
inference.prefill_tbatch_size);
|
||||
}
|
||||
InitGenerator(inference, gen_);
|
||||
|
|
@ -78,24 +69,13 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
|||
.max_generated_tokens = inference.max_generated_tokens,
|
||||
.temperature = inference.temperature,
|
||||
.gen = &gen_,
|
||||
.verbosity = app.verbosity,
|
||||
.verbosity = inference.verbosity,
|
||||
};
|
||||
}
|
||||
|
||||
// Internal init must run before the GemmaEnv ctor above, hence it cannot occur
|
||||
// in the argv ctor below because its body runs *after* the delegating ctor.
|
||||
// This helper function takes care of the init, and could be applied to any of
|
||||
// the *Args classes, it does not matter which.
|
||||
static AppArgs MakeAppArgs(int argc, char** argv) {
|
||||
{ // So that indentation matches expectations.
|
||||
// Placeholder for internal init, do not modify.
|
||||
}
|
||||
return AppArgs(argc, argv);
|
||||
}
|
||||
|
||||
GemmaEnv::GemmaEnv(int argc, char** argv)
|
||||
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
|
||||
MakeAppArgs(argc, argv)) {}
|
||||
: GemmaEnv(ThreadingArgs(argc, argv), LoaderArgs(argc, argv),
|
||||
InferenceArgs(argc, argv)) {}
|
||||
|
||||
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
||||
QueryResult result;
|
||||
|
|
@ -117,7 +97,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
|
|||
}
|
||||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
return result;
|
||||
}
|
||||
|
|
@ -127,7 +107,7 @@ void GemmaEnv::QueryModel(
|
|||
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
|
||||
const StreamFunc previous_stream_token = runtime_config_.stream_token;
|
||||
runtime_config_.stream_token = stream_token;
|
||||
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
|
||||
timing_info);
|
||||
runtime_config_.stream_token = previous_stream_token;
|
||||
}
|
||||
|
|
@ -142,7 +122,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
int token, float) {
|
||||
std::string token_text;
|
||||
HWY_ASSERT(
|
||||
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
gemma_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
res[query_index].response.append(token_text);
|
||||
res[query_index].tokens_generated += 1;
|
||||
if (res[query_index].tokens_generated ==
|
||||
|
|
@ -164,7 +144,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
}
|
||||
for (size_t i = 1; i < num_queries; ++i) {
|
||||
if (kv_caches_[i].seq_len == 0) {
|
||||
kv_caches_[i] = KVCache::Create(model_->GetModelConfig(),
|
||||
kv_caches_[i] = KVCache::Create(gemma_->GetModelConfig(),
|
||||
runtime_config_.prefill_tbatch_size);
|
||||
}
|
||||
}
|
||||
|
|
@ -172,7 +152,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
|
||||
runtime_config_.batch_stream_token = batch_stream_token;
|
||||
std::vector<size_t> queries_pos(num_queries, 0);
|
||||
model_->GenerateBatch(runtime_config_, queries_prompt,
|
||||
gemma_->GenerateBatch(runtime_config_, queries_prompt,
|
||||
QueriesPos(queries_pos.data(), num_queries),
|
||||
KVCaches(&kv_caches_[0], num_queries), timing_info);
|
||||
return res;
|
||||
|
|
@ -203,7 +183,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
|
|||
float GemmaEnv::CrossEntropy(const std::string& input) {
|
||||
std::vector<int> prompt = Tokenize(input);
|
||||
prompt.insert(prompt.begin(), BOS_ID);
|
||||
return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
|
||||
return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
|
||||
MutableKVCache(),
|
||||
/*verbosity=*/0) /
|
||||
static_cast<int>(input.size());
|
||||
|
|
@ -236,17 +216,36 @@ std::string CacheString() {
|
|||
return buf;
|
||||
}
|
||||
|
||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
||||
const BoundedTopology& topology, NestedPools& pools) {
|
||||
loader.Print(app.verbosity);
|
||||
inference.Print(app.verbosity);
|
||||
app.Print(app.verbosity);
|
||||
static constexpr const char* CompiledConfig() {
|
||||
if constexpr (HWY_IS_ASAN) {
|
||||
return "asan";
|
||||
} else if constexpr (HWY_IS_MSAN) {
|
||||
return "msan";
|
||||
} else if constexpr (HWY_IS_TSAN) {
|
||||
return "tsan";
|
||||
} else if constexpr (HWY_IS_HWASAN) {
|
||||
return "hwasan";
|
||||
} else if constexpr (HWY_IS_UBSAN) {
|
||||
return "ubsan";
|
||||
} else if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
return "dbg";
|
||||
} else {
|
||||
return "opt";
|
||||
}
|
||||
}
|
||||
|
||||
if (app.verbosity >= 2) {
|
||||
void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference) {
|
||||
threading.Print(inference.verbosity);
|
||||
loader.Print(inference.verbosity);
|
||||
inference.Print(inference.verbosity);
|
||||
|
||||
if (inference.verbosity >= 2) {
|
||||
time_t now = time(nullptr);
|
||||
char* dt = ctime(&now); // NOLINT
|
||||
char cpu100[100] = "unknown";
|
||||
(void)hwy::platform::GetCpuString(cpu100);
|
||||
const ThreadingContext2& ctx = ThreadingContext2::Get();
|
||||
|
||||
fprintf(stderr,
|
||||
"Date & Time : %s" // dt includes \n
|
||||
|
|
@ -254,16 +253,18 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
|||
"CPU topology : %s, %s, %s\n"
|
||||
"Instruction set : %s (%zu bits)\n"
|
||||
"Compiled config : %s\n"
|
||||
"Weight Type : %s\n"
|
||||
"EmbedderInput Type : %s\n",
|
||||
dt, cpu100, topology.TopologyString(), pools.PinString(),
|
||||
"Memory MiB : %4zu, %4zu free\n"
|
||||
"Weight Type : %s\n",
|
||||
dt, cpu100, ctx.topology.TopologyString(), ctx.pools.PinString(),
|
||||
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
|
||||
hwy::VectorBytes() * 8, CompiledConfig(),
|
||||
StringFromType(loader.Info().weight), TypeName<EmbedderInputT>());
|
||||
ctx.allocator.VectorBytes() * 8, CompiledConfig(),
|
||||
ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB(),
|
||||
StringFromType(loader.Info().weight));
|
||||
}
|
||||
}
|
||||
|
||||
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
||||
void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference) {
|
||||
std::cerr
|
||||
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
|
||||
"==========================================================\n\n"
|
||||
|
|
@ -272,16 +273,16 @@ void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
|
|||
" --tokenizer\n"
|
||||
" --weights\n"
|
||||
" --model,\n"
|
||||
" or with the newer weights format, specify just:\n"
|
||||
" or with the single-file weights format, specify just:\n"
|
||||
" --weights\n";
|
||||
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
|
||||
"--weights 2b-it-sfp.sbs --model 2b-it\n";
|
||||
std::cerr << "\n*Threading Arguments*\n\n";
|
||||
threading.Help();
|
||||
std::cerr << "\n*Model Loading Arguments*\n\n";
|
||||
loader.Help();
|
||||
std::cerr << "\n*Inference Arguments*\n\n";
|
||||
inference.Help();
|
||||
std::cerr << "\n*Application Arguments*\n\n";
|
||||
app.Help();
|
||||
std::cerr << "\n";
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -24,9 +24,9 @@
|
|||
#include <vector>
|
||||
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "util/app.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
|
@ -46,8 +46,10 @@ class GemmaEnv {
|
|||
public:
|
||||
// Calls the other constructor with *Args arguments initialized from argv.
|
||||
GemmaEnv(int argc, char** argv);
|
||||
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
|
||||
const AppArgs& app);
|
||||
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
|
||||
const InferenceArgs& inference);
|
||||
|
||||
MatMulEnv& Env() { return env_; }
|
||||
|
||||
size_t MaxGeneratedTokens() const {
|
||||
return runtime_config_.max_generated_tokens;
|
||||
|
|
@ -58,7 +60,7 @@ class GemmaEnv {
|
|||
|
||||
std::vector<int> Tokenize(const std::string& input) const {
|
||||
std::vector<int> tokens;
|
||||
HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
|
||||
HWY_ASSERT(gemma_->Tokenizer().Encode(input, &tokens));
|
||||
return tokens;
|
||||
}
|
||||
|
||||
|
|
@ -69,13 +71,13 @@ class GemmaEnv {
|
|||
}
|
||||
|
||||
std::vector<int> WrapAndTokenize(std::string& input) const {
|
||||
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(),
|
||||
model_->Info(), 0, input);
|
||||
return gcpp::WrapAndTokenize(gemma_->Tokenizer(), gemma_->ChatTemplate(),
|
||||
gemma_->Info(), 0, input);
|
||||
}
|
||||
|
||||
std::string StringFromTokens(const std::vector<int>& tokens) const {
|
||||
std::string string;
|
||||
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
|
||||
HWY_ASSERT(gemma_->Tokenizer().Decode(tokens, &string));
|
||||
return string;
|
||||
}
|
||||
|
||||
|
|
@ -99,7 +101,7 @@ class GemmaEnv {
|
|||
float CrossEntropy(const std::string& input);
|
||||
|
||||
// Returns nullptr if the model failed to load.
|
||||
Gemma* GetModel() const { return model_.get(); }
|
||||
Gemma* GetGemma() const { return gemma_.get(); }
|
||||
|
||||
int Verbosity() const { return runtime_config_.verbosity; }
|
||||
RuntimeConfig& MutableConfig() { return runtime_config_; }
|
||||
|
|
@ -107,11 +109,9 @@ class GemmaEnv {
|
|||
KVCache& MutableKVCache() { return kv_caches_[0]; }
|
||||
|
||||
private:
|
||||
BoundedTopology topology_;
|
||||
NestedPools pools_; // Thread pool.
|
||||
MatMulEnv env_;
|
||||
std::mt19937 gen_; // Random number generator.
|
||||
std::unique_ptr<Gemma> model_;
|
||||
std::unique_ptr<Gemma> gemma_;
|
||||
std::vector<KVCache> kv_caches_; // Same number as query batch.
|
||||
RuntimeConfig runtime_config_;
|
||||
};
|
||||
|
|
@ -119,9 +119,10 @@ class GemmaEnv {
|
|||
// Logs the inference speed in tokens/sec.
|
||||
void LogSpeedStats(double time_start, size_t total_tokens);
|
||||
|
||||
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
|
||||
const BoundedTopology& topology, NestedPools& pools);
|
||||
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
|
||||
void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference);
|
||||
void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader,
|
||||
InferenceArgs& inference);
|
||||
|
||||
} // namespace gcpp
|
||||
|
||||
|
|
|
|||
|
|
@ -51,8 +51,8 @@ class GemmaTest : public ::testing::Test {
|
|||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
|
|
@ -76,7 +76,7 @@ class GemmaTest : public ::testing::Test {
|
|||
}
|
||||
|
||||
void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
|
||||
std::vector<std::string> inputs;
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
|
|
|
|||
|
|
@ -50,8 +50,8 @@ class GemmaTest : public ::testing::Test {
|
|||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
|
||||
std::string mutable_prompt = prompt;
|
||||
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
|
||||
return result.response;
|
||||
|
|
@ -71,8 +71,8 @@ class GemmaTest : public ::testing::Test {
|
|||
// Using the turn structure worsens results sometimes.
|
||||
// However, some models need the turn structure to work.
|
||||
// It would be good to make these tests more consistent.
|
||||
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
|
||||
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
|
||||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
|
||||
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
|
||||
replies.push_back(result.response);
|
||||
}
|
||||
|
|
@ -96,7 +96,7 @@ class GemmaTest : public ::testing::Test {
|
|||
}
|
||||
|
||||
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
if (batch) {
|
||||
std::vector<std::string> inputs;
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
|
|
@ -155,8 +155,8 @@ TEST_F(GemmaTest, Arithmetic) {
|
|||
}
|
||||
|
||||
TEST_F(GemmaTest, Multiturn) {
|
||||
Gemma* model = s_env->GetModel();
|
||||
ASSERT_NE(model, nullptr);
|
||||
Gemma* model = s_env->GetGemma();
|
||||
HWY_ASSERT(model != nullptr);
|
||||
size_t abs_pos = 0;
|
||||
std::string response;
|
||||
auto stream_token = [&](int token, float) {
|
||||
|
|
@ -239,12 +239,12 @@ static const char kGettysburg[] = {
|
|||
"people, for the people, shall not perish from the earth.\n"};
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropySmall) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
static const char kSmall[] =
|
||||
"The capital of Hungary is Budapest which is located in Europe.";
|
||||
float entropy = s_env->CrossEntropy(kSmall);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetModel()->Info().model) {
|
||||
switch (s_env->GetGemma()->Info().model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 2.6f, 0.2f);
|
||||
|
|
@ -272,10 +272,10 @@ TEST_F(GemmaTest, CrossEntropySmall) {
|
|||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
float entropy = s_env->CrossEntropy(kJingleBells);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetModel()->Info().model) {
|
||||
switch (s_env->GetGemma()->Info().model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 1.9f, 0.2f);
|
||||
|
|
@ -303,10 +303,10 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
|
|||
}
|
||||
|
||||
TEST_F(GemmaTest, CrossEntropyGettysburg) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
HWY_ASSERT(s_env->GetGemma() != nullptr);
|
||||
float entropy = s_env->CrossEntropy(kGettysburg);
|
||||
fprintf(stderr, "per-token entropy: %f\n", entropy);
|
||||
switch (s_env->GetModel()->Info().model) {
|
||||
switch (s_env->GetGemma()->Info().model) {
|
||||
case gcpp::Model::GEMMA_2B:
|
||||
// 2B v.1 and v.1.1 produce slightly different results.
|
||||
EXPECT_NEAR(entropy, 1.1f, 0.1f);
|
||||
|
|
|
|||
|
|
@ -89,7 +89,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
|
|||
"A", "B", "C", "D", //
|
||||
" A", " B", " C", " D", //
|
||||
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
|
||||
const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
|
||||
const TokenSet accept_set(env.GetGemma()->Tokenizer(), accept_strings);
|
||||
|
||||
for (auto sample : json_data["samples"]) {
|
||||
const int id = sample["i"];
|
||||
|
|
@ -131,7 +131,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
|
|||
.verbosity = env.Verbosity(),
|
||||
.stream_token = stream_token,
|
||||
};
|
||||
env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
|
||||
env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
|
||||
env.MutableKVCache(), timing_info);
|
||||
|
||||
std::string output_string = env.StringFromTokens(predicted_token_ids);
|
||||
|
|
|
|||
|
|
@ -10,13 +10,11 @@ cc_binary(
|
|||
name = "hello_world",
|
||||
srcs = ["run.cc"],
|
||||
deps = [
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//:app",
|
||||
"//:args",
|
||||
"//:gemma_args",
|
||||
"//:gemma_lib",
|
||||
"//:threading",
|
||||
"//:threading_context",
|
||||
"//:tokenizer",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
|
|||
example:
|
||||
|
||||
```sh
|
||||
./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
|
||||
./hello_world --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
|
||||
```
|
||||
|
||||
Should print a greeting to the terminal:
|
||||
|
|
|
|||
|
|
@ -23,23 +23,17 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "gemma/gemma.h"
|
||||
#include "gemma/gemma_args.h" // LoaderArgs
|
||||
#include "gemma/tokenizer.h"
|
||||
#include "util/app.h" // LoaderArgs
|
||||
#include "util/args.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
{
|
||||
// Placeholder for internal init, do not modify.
|
||||
}
|
||||
|
||||
gcpp::ThreadingArgs threading(argc, argv);
|
||||
gcpp::LoaderArgs loader(argc, argv);
|
||||
gcpp::InferenceArgs inference(argc, argv);
|
||||
gcpp::AppArgs app(argc, argv);
|
||||
if (gcpp::HasHelp(argc, argv)) {
|
||||
loader.Help();
|
||||
return 0;
|
||||
|
|
@ -53,14 +47,14 @@ int main(int argc, char** argv) {
|
|||
for (int arg = 0; arg < argc; ++arg) {
|
||||
// Find a --reject flag and consume everything after it.
|
||||
if (strcmp(argv[arg], "--reject") == 0) {
|
||||
while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
|
||||
while (++arg < argc) {
|
||||
reject_tokens.insert(atoi(argv[arg])); // NOLINT
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Instantiate model and KV Cache
|
||||
gcpp::BoundedTopology topology(gcpp::CreateTopology(app));
|
||||
gcpp::NestedPools pools = gcpp::CreatePools(topology, app);
|
||||
gcpp::MatMulEnv env(topology, pools);
|
||||
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
|
||||
gcpp::Gemma model = gcpp::CreateGemma(loader, env);
|
||||
gcpp::KVCache kv_cache =
|
||||
gcpp::KVCache::Create(model.GetModelConfig(),
|
||||
|
|
|
|||
|
|
@ -10,10 +10,10 @@ cc_library(
|
|||
name = "gemma",
|
||||
hdrs = ["gemma.hpp"],
|
||||
deps = [
|
||||
"//:app",
|
||||
"//:gemma_args",
|
||||
"//:gemma_lib",
|
||||
"//:ops",
|
||||
"//:threading",
|
||||
"//:threading_context",
|
||||
"//:tokenizer",
|
||||
"@highway//:hwy",
|
||||
],
|
||||
|
|
@ -24,15 +24,6 @@ cc_binary(
|
|||
srcs = ["run.cc"],
|
||||
deps = [
|
||||
":gemma",
|
||||
# Placeholder for internal dep, do not remove.,
|
||||
"//:app",
|
||||
"//:args",
|
||||
"//:common",
|
||||
"//:gemma_lib",
|
||||
"//:ops",
|
||||
"//:threading",
|
||||
"//:tokenizer",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
"//:gemma_args",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
|
|
@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
|
|||
example:
|
||||
|
||||
```sh
|
||||
./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
|
||||
./simplified_gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
|
||||
```
|
||||
|
||||
Should print a greeting to the terminal:
|
||||
|
|
|
|||
|
|
@ -24,39 +24,22 @@
|
|||
#include <vector>
|
||||
|
||||
#include "third_party/gemma_cpp/gemma/gemma.h"
|
||||
#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs
|
||||
#include "third_party/gemma_cpp/gemma/tokenizer.h"
|
||||
#include "third_party/gemma_cpp/ops/matmul.h"
|
||||
#include "third_party/gemma_cpp/util/app.h" // LoaderArgs
|
||||
#include "third_party/gemma_cpp/util/threading.h"
|
||||
#include "third_party/gemma_cpp/util/threading_context.h"
|
||||
#include "third_party/highway/hwy/base.h"
|
||||
|
||||
class SimplifiedGemma {
|
||||
public:
|
||||
SimplifiedGemma(const gcpp::LoaderArgs& loader,
|
||||
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(),
|
||||
const gcpp::AppArgs& app = gcpp::AppArgs())
|
||||
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
|
||||
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
|
||||
: loader_(loader),
|
||||
threading_(threading),
|
||||
inference_(inference),
|
||||
app_(app),
|
||||
topology_(gcpp::CreateTopology(app_)),
|
||||
pools_(gcpp::CreatePools(topology_, app_)),
|
||||
env_(topology_, pools_),
|
||||
env_(MakeMatMulEnv(threading_)),
|
||||
model_(gcpp::CreateGemma(loader_, env_)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
SimplifiedGemma(int argc, char** argv)
|
||||
: loader_(argc, argv, /*validate=*/true),
|
||||
inference_(argc, argv),
|
||||
app_(argc, argv),
|
||||
topology_(gcpp::CreateTopology(app_)),
|
||||
pools_(gcpp::CreatePools(topology_, app_)),
|
||||
env_(topology_, pools_),
|
||||
model_(gcpp::CreateGemma(loader_, env_)) {
|
||||
Init();
|
||||
}
|
||||
|
||||
void Init() {
|
||||
// Instantiate model and KV Cache
|
||||
kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(),
|
||||
inference_.prefill_tbatch_size);
|
||||
|
|
@ -66,6 +49,11 @@ class SimplifiedGemma {
|
|||
gen_.seed(rd());
|
||||
}
|
||||
|
||||
SimplifiedGemma(int argc, char** argv)
|
||||
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv, /*validate=*/true),
|
||||
gcpp::ThreadingArgs(argc, argv),
|
||||
gcpp::InferenceArgs(argc, argv)) {}
|
||||
|
||||
void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
|
||||
float temperature = 0.7,
|
||||
const std::set<int>& reject_tokens = {}) {
|
||||
|
|
@ -107,10 +95,8 @@ class SimplifiedGemma {
|
|||
|
||||
private:
|
||||
gcpp::LoaderArgs loader_;
|
||||
gcpp::ThreadingArgs threading_;
|
||||
gcpp::InferenceArgs inference_;
|
||||
gcpp::AppArgs app_;
|
||||
gcpp::BoundedTopology topology_;
|
||||
gcpp::NestedPools pools_;
|
||||
gcpp::MatMulEnv env_;
|
||||
gcpp::Gemma model_;
|
||||
gcpp::KVCache kv_cache_;
|
||||
|
|
|
|||
|
|
@ -17,15 +17,10 @@
|
|||
|
||||
#include <string>
|
||||
|
||||
// Placeholder for internal header, do not modify.
|
||||
#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
|
||||
#include "util/app.h" // LoaderArgs
|
||||
#include "gemma/gemma_args.h" // LoaderArgs
|
||||
|
||||
int main(int argc, char** argv) {
|
||||
{
|
||||
// Placeholder for internal init, do not modify.
|
||||
}
|
||||
|
||||
// Standard usage: LoaderArgs takes argc and argv as input, then parses
|
||||
// necessary flags.
|
||||
gcpp::LoaderArgs loader(argc, argv, /*validate=*/true);
|
||||
|
|
@ -35,12 +30,12 @@ int main(int argc, char** argv) {
|
|||
// gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
|
||||
// "model_identifier");
|
||||
|
||||
// Optional: InferenceArgs and AppArgs can be passed in as well. If not
|
||||
// Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
|
||||
// specified, default values will be used.
|
||||
//
|
||||
// gcpp::InferenceArgs inference(argc, argv);
|
||||
// gcpp::AppArgs app(argc, argv);
|
||||
// SimplifiedGemma gemma(loader, inference, app);
|
||||
// gcpp::ThreadingArgs threading(argc, argv);
|
||||
// SimplifiedGemma gemma(loader, threading, inference);
|
||||
|
||||
SimplifiedGemma gemma(loader);
|
||||
std::string prompt = "Write a greeting to the world.";
|
||||
|
|
|
|||
|
|
@ -18,14 +18,12 @@
|
|||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "compression/shared.h" // BF16
|
||||
#include "gemma/configs.h"
|
||||
#include "gemma/configs.h" // ModelConfig
|
||||
#include "ops/matmul.h" // MatMulEnv
|
||||
#include "ops/ops.h" // CreateInvTimescale
|
||||
#include "util/allocator.h" // RowVectorBatch
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h" // HWY_DASSERT
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "util/allocator.h" // Allocator
|
||||
#include "util/basics.h" // BF16
|
||||
#include "util/mat.h" // RowVectorBatch
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
|
|
@ -74,6 +72,8 @@ struct Activations {
|
|||
size_t cache_pos_size = 0;
|
||||
|
||||
void Allocate(size_t batch_size, MatMulEnv* env) {
|
||||
const Allocator2& allocator = env->ctx.allocator;
|
||||
|
||||
post_qk = layer_config.post_qk;
|
||||
const size_t model_dim = weights_config.model_dim;
|
||||
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
|
||||
|
|
@ -81,36 +81,45 @@ struct Activations {
|
|||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
const size_t heads = layer_config.heads;
|
||||
|
||||
x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
x = RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||
q = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, heads * layer_config.QStride()));
|
||||
allocator, Extents2D(batch_size, heads * layer_config.QStride()));
|
||||
if (vocab_size > 0) {
|
||||
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
|
||||
logits =
|
||||
RowVectorBatch<float>(allocator, Extents2D(batch_size, vocab_size));
|
||||
}
|
||||
|
||||
pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
pre_att_rms_out =
|
||||
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||
att = RowVectorBatch<float>(
|
||||
Extents2D(batch_size, heads * weights_config.seq_len));
|
||||
att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
|
||||
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
allocator, Extents2D(batch_size, heads * weights_config.seq_len));
|
||||
att_out = RowVectorBatch<float>(allocator,
|
||||
Extents2D(batch_size, heads * qkv_dim));
|
||||
att_sums =
|
||||
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||
|
||||
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
|
||||
C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
||||
C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
|
||||
ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
bf_pre_ffw_rms_out =
|
||||
RowVectorBatch<BF16>(allocator, Extents2D(batch_size, model_dim));
|
||||
C1 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
|
||||
C2 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
|
||||
ffw_out =
|
||||
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||
|
||||
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
|
||||
griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
griffin_x =
|
||||
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||
griffin_y =
|
||||
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||
griffin_gate_x =
|
||||
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||
griffin_multiplier =
|
||||
RowVectorBatch<float>(Extents2D(batch_size, model_dim));
|
||||
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
|
||||
}
|
||||
|
||||
inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
|
||||
inv_timescale = CreateInvTimescale(allocator, layer_config.qkv_dim,
|
||||
post_qk == PostQKType::HalfRope);
|
||||
inv_timescale_global =
|
||||
CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
|
||||
inv_timescale_global = CreateInvTimescale(
|
||||
allocator, qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
|
||||
|
||||
this->env = env;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -17,13 +17,13 @@
|
|||
|
||||
#include <math.h> // sqrtf
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm> // std::min
|
||||
#include <cstdio>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "gemma/activations.h"
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
|
|
@ -32,11 +32,9 @@
|
|||
#include "gemma/weights.h"
|
||||
#include "paligemma/image.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/bit_set.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
|
|
@ -82,7 +80,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
|||
const KVCaches& kv_caches) {
|
||||
PROFILER_ZONE("Gen.Griffin");
|
||||
KVCache& kv_cache = kv_caches[0];
|
||||
hwy::ThreadPool& pool = activations.env->parallel.Pools().Pool(0);
|
||||
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
using D = hn::ScalableTag<float>;
|
||||
const size_t model_dim = layer_weights->layer_config.model_dim;
|
||||
|
|
@ -96,8 +94,8 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
|||
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
|
||||
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
|
||||
activations.pre_att_rms_out.Batch(batch_idx),
|
||||
/*add0=*/layer_weights->griffin.linear_x_biases.data_scale1(),
|
||||
/*add1=*/layer_weights->griffin.linear_y_biases.data_scale1(),
|
||||
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
|
||||
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
|
||||
/*out0=*/x, /*out1=*/y, pool);
|
||||
Gelu(y, model_dim);
|
||||
}
|
||||
|
|
@ -121,15 +119,15 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
|||
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
|
||||
auto xv = hn::Load(df, x + i);
|
||||
auto accum0 =
|
||||
hn::Load(df, layer_weights->griffin.conv_biases.data_scale1() + i);
|
||||
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
|
||||
auto accum1 = hn::Zero(df);
|
||||
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
|
||||
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
|
||||
auto wv0 =
|
||||
hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
|
||||
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
|
||||
(conv_1d_width - 1 - 2 * l) * model_dim + i);
|
||||
auto wv1 =
|
||||
hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
|
||||
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
|
||||
(conv_1d_width - 2 - 2 * l) * model_dim + i);
|
||||
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
|
||||
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
|
||||
|
|
@ -156,9 +154,9 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
|||
TwoOfsMatVecAddLoop(
|
||||
layer_weights->griffin.gate_w, kMatrixSize * head,
|
||||
kMatrixSize * (heads + head), kHeadDim, kHeadDim, x + head_offset,
|
||||
/*add0=*/layer_weights->griffin.gate_biases.data_scale1() +
|
||||
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
|
||||
head_offset,
|
||||
/*add1=*/layer_weights->griffin.gate_biases.data_scale1() +
|
||||
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
|
||||
model_dim + head_offset,
|
||||
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
|
||||
Sigmoid(gate_x + head_offset, kHeadDim);
|
||||
|
|
@ -166,7 +164,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
|||
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
|
||||
HWY_ATTR { return hn::Mul(x, gate_x); };
|
||||
hn::Transform1(D(), a + head_offset, kHeadDim,
|
||||
layer_weights->griffin.a.data_scale1() + head_offset,
|
||||
layer_weights->griffin.a.PackedScale1() + head_offset,
|
||||
fn_mul);
|
||||
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
|
||||
fn_mul);
|
||||
|
|
@ -198,7 +196,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
|
|||
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
|
||||
float* out_ptr = activations.att_sums.Batch(batch_idx);
|
||||
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
|
||||
layer_weights->griffin.linear_out_biases.data_scale1(), out_ptr,
|
||||
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
|
||||
pool);
|
||||
}
|
||||
}
|
||||
|
|
@ -253,7 +251,7 @@ class GemmaAttention {
|
|||
|
||||
const auto pre_att_rms_out =
|
||||
ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
|
||||
auto w_q1 = layer_weights_.qkv_einsum_w.data()
|
||||
auto w_q1 = layer_weights_.qkv_einsum_w.HasPtr()
|
||||
? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
|
||||
: ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
|
||||
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
|
||||
|
|
@ -265,15 +263,20 @@ class GemmaAttention {
|
|||
const size_t w1_rows = heads * layer_config_.QStride();
|
||||
w_q1.ShrinkRows(w1_rows);
|
||||
MatMul(pre_att_rms_out, w_q1,
|
||||
/*add=*/nullptr, *activations_.env, RowPtrFromBatch(activations_.q));
|
||||
/*add=*/nullptr, *activations_.env,
|
||||
RowPtrFromBatch(allocator_, activations_.q));
|
||||
|
||||
if (is_mha_) {
|
||||
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
|
||||
} else {
|
||||
auto w_q2 = layer_weights_.qkv_einsum_w.data()
|
||||
? ConstMatFromWeights(layer_weights_.qkv_einsum_w,
|
||||
w1_rows * model_dim)
|
||||
: ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
|
||||
decltype(w_q1) w_q2;
|
||||
if (layer_weights_.qkv_einsum_w.HasPtr()) {
|
||||
w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w);
|
||||
// Skip first half of the matrix.
|
||||
w_q2.ofs = w_q2.Row(w1_rows);
|
||||
} else {
|
||||
w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
|
||||
}
|
||||
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
|
||||
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
|
||||
w_q2.ShrinkRows(w_rows_kv_cols);
|
||||
|
|
@ -285,7 +288,7 @@ class GemmaAttention {
|
|||
const size_t kv_ofs =
|
||||
queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
|
||||
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
|
||||
RowPtrF kv_rows(kv, w_rows_kv_cols);
|
||||
RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols);
|
||||
kv_rows.SetStride(cache_pos_size_);
|
||||
MatMul(pre_att_rms_out, w_q2,
|
||||
/*add=*/nullptr, *activations_.env, kv_rows);
|
||||
|
|
@ -302,7 +305,7 @@ class GemmaAttention {
|
|||
const size_t kv_offset =
|
||||
cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
|
||||
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
|
||||
if (layer_weights_.qkv_einsum_w.data()) {
|
||||
if (layer_weights_.qkv_einsum_w.HasPtr()) {
|
||||
MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim,
|
||||
w_rows_kv_cols, model_dim, x, kv, pool_);
|
||||
} else {
|
||||
|
|
@ -336,8 +339,8 @@ class GemmaAttention {
|
|||
}
|
||||
|
||||
// Apply further processing to K.
|
||||
if (layer_weights_.key_norm_scale.data()) {
|
||||
RMSNormInplace(layer_weights_.key_norm_scale.data(), kv,
|
||||
if (layer_weights_.key_norm_scale.HasPtr()) {
|
||||
RMSNormInplace(layer_weights_.key_norm_scale.Row(0), kv,
|
||||
qkv_dim);
|
||||
}
|
||||
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
|
||||
|
|
@ -427,8 +430,8 @@ class GemmaAttention {
|
|||
|
||||
// Apply rope and scaling to Q.
|
||||
const size_t pos = queries_pos_[query_idx] + batch_idx;
|
||||
if (layer_weights_.query_norm_scale.data()) {
|
||||
RMSNormInplace(layer_weights_.query_norm_scale.data(), q,
|
||||
if (layer_weights_.query_norm_scale.HasPtr()) {
|
||||
RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q,
|
||||
qkv_dim);
|
||||
}
|
||||
PositionalEncodingQK(q, pos, layer_, query_scale);
|
||||
|
|
@ -473,17 +476,18 @@ class GemmaAttention {
|
|||
HWY_DASSERT(layer_config_.model_dim > 0);
|
||||
HWY_DASSERT(layer_config_.heads > 0);
|
||||
HWY_DASSERT(layer_config_.qkv_dim > 0);
|
||||
HWY_DASSERT(layer_weights_.att_weights.data() != nullptr);
|
||||
HWY_DASSERT(layer_weights_.att_weights.HasPtr());
|
||||
HWY_DASSERT(activations_.att_out.All() != nullptr);
|
||||
HWY_DASSERT(activations_.att_sums.All() != nullptr);
|
||||
|
||||
const float* add =
|
||||
layer_weights_.layer_config.softmax_attn_output_biases
|
||||
? layer_weights_.attention_output_biases.data_scale1()
|
||||
? layer_weights_.attention_output_biases.PackedScale1()
|
||||
: nullptr;
|
||||
MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
|
||||
ConstMatFromWeights(layer_weights_.att_weights), add,
|
||||
*activations_.env, RowPtrFromBatch(activations_.att_sums));
|
||||
*activations_.env,
|
||||
RowPtrFromBatch(allocator_, activations_.att_sums));
|
||||
}
|
||||
|
||||
public:
|
||||
|
|
@ -533,7 +537,8 @@ class GemmaAttention {
|
|||
layer_weights_(*layer_weights),
|
||||
div_seq_len_(div_seq_len),
|
||||
kv_caches_(kv_caches),
|
||||
pool_(activations.env->parallel.Pools().Pool(0)) {
|
||||
allocator_(activations.env->ctx.allocator),
|
||||
pool_(activations.env->ctx.pools.Pool(0)) {
|
||||
HWY_DASSERT(num_queries_ <= kv_caches_.size());
|
||||
HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
|
||||
"query heads must be a multiple of key-value heads");
|
||||
|
|
@ -562,6 +567,7 @@ class GemmaAttention {
|
|||
const LayerWeightsPtrs<T>& layer_weights_;
|
||||
const hwy::Divisor& div_seq_len_;
|
||||
const KVCaches& kv_caches_;
|
||||
const Allocator2& allocator_;
|
||||
hwy::ThreadPool& pool_;
|
||||
};
|
||||
|
||||
|
|
@ -606,8 +612,8 @@ class VitAttention {
|
|||
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
|
||||
MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
|
||||
ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
|
||||
layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env,
|
||||
RowPtrFromBatch(qkv));
|
||||
layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
|
||||
RowPtrFromBatch(allocator_, qkv));
|
||||
}
|
||||
|
||||
// TODO(philculliton): transition fully to MatMul.
|
||||
|
|
@ -621,10 +627,10 @@ class VitAttention {
|
|||
|
||||
// Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents)
|
||||
RowVectorBatch<float> Q =
|
||||
AllocateAlignedRows<float>(Extents2D(num_tokens_, qkv_dim));
|
||||
AllocateAlignedRows<float>(allocator_, Extents2D(num_tokens_, qkv_dim));
|
||||
RowVectorBatch<float> K =
|
||||
AllocateAlignedRows<float>(Extents2D(seq_len, qkv_dim));
|
||||
RowVectorBatch<float> C(Extents2D(num_tokens_, seq_len));
|
||||
AllocateAlignedRows<float>(allocator_, Extents2D(seq_len, qkv_dim));
|
||||
RowVectorBatch<float> C(allocator_, Extents2D(num_tokens_, seq_len));
|
||||
|
||||
// Initialize att_out to zero prior to head loop.
|
||||
hwy::ZeroBytes(activations_.att_out.All(),
|
||||
|
|
@ -650,7 +656,7 @@ class VitAttention {
|
|||
// this produces C, a (num_tokens_, seq_len) matrix of dot products
|
||||
MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
|
||||
ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
|
||||
RowPtrFromBatch(C));
|
||||
RowPtrFromBatch(allocator_, C));
|
||||
|
||||
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
|
||||
float* HWY_RESTRICT c = C.Batch(task);
|
||||
|
|
@ -712,13 +718,13 @@ class VitAttention {
|
|||
// head_dim (`qkv_dim`) into output (`att_sums`).
|
||||
HWY_NOINLINE void SumHeads() {
|
||||
PROFILER_ZONE("Gen.VitAttention.SumHeads");
|
||||
auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
|
||||
auto* bias = layer_weights_.vit.attn_out_b.PackedScale1();
|
||||
// att_weights and att_out are concatenated heads, each of length
|
||||
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
|
||||
// matmul output is the sum over heads.
|
||||
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
|
||||
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
|
||||
auto att_sums = RowPtrFromBatch(activations_.att_sums);
|
||||
auto att_sums = RowPtrFromBatch(allocator_, activations_.att_sums);
|
||||
MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
|
||||
}
|
||||
|
||||
|
|
@ -730,7 +736,8 @@ class VitAttention {
|
|||
activations_(activations),
|
||||
layer_weights_(*layer_weights),
|
||||
layer_config_(layer_weights->layer_config),
|
||||
pool_(activations.env->parallel.Pools().Pool(0)) {}
|
||||
allocator_(activations.env->ctx.allocator),
|
||||
pool_(activations.env->ctx.pools.Pool(0)) {}
|
||||
|
||||
HWY_INLINE void operator()() {
|
||||
ComputeQKV();
|
||||
|
|
@ -748,6 +755,7 @@ class VitAttention {
|
|||
Activations& activations_;
|
||||
const LayerWeightsPtrs<T>& layer_weights_;
|
||||
const LayerConfig& layer_config_;
|
||||
const Allocator2& allocator_;
|
||||
hwy::ThreadPool& pool_;
|
||||
};
|
||||
|
||||
|
|
@ -779,32 +787,35 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
|
|||
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
const float* bias1 =
|
||||
add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr;
|
||||
add_bias ? layer_weights->ffw_gating_biases.PackedScale1() : nullptr;
|
||||
const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
|
||||
const float* output_bias =
|
||||
add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr;
|
||||
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x =
|
||||
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
|
||||
|
||||
auto hidden_activations = RowPtrFromBatch(activations.C1);
|
||||
auto multiplier = RowPtrFromBatch(activations.C2);
|
||||
auto ffw_out = RowPtrFromBatch(activations.ffw_out);
|
||||
const Allocator2& allocator = activations.env->ctx.allocator;
|
||||
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
|
||||
auto multiplier = RowPtrFromBatch(allocator, activations.C2);
|
||||
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
|
||||
|
||||
// gating_einsum_w holds two half-matrices. We plan to change the importer to
|
||||
// avoid this confusion by splitting into gating_einsum_w1 and
|
||||
// gating_einsum_w2.
|
||||
const bool split = !!layer_weights->gating_einsum_w.data();
|
||||
const bool split = layer_weights->gating_einsum_w.HasPtr();
|
||||
auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
|
||||
: ConstMatFromWeights(layer_weights->gating_einsum_w1);
|
||||
auto w2 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w,
|
||||
model_dim * ffh_hidden_dim)
|
||||
: ConstMatFromWeights(layer_weights->gating_einsum_w2);
|
||||
decltype(w1) w2;
|
||||
if (split) {
|
||||
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w);
|
||||
w2.ofs = w2.Row(ffh_hidden_dim);
|
||||
// Ensure that B.Extents().row matches C.Cols() because MatMul checks that.
|
||||
w1.ShrinkRows(ffh_hidden_dim);
|
||||
w2.ShrinkRows(ffh_hidden_dim);
|
||||
} else {
|
||||
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w2);
|
||||
}
|
||||
auto w_output = ConstMatFromWeights(layer_weights->linear_w);
|
||||
|
||||
|
|
@ -835,16 +846,17 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
|
|||
|
||||
const bool add_bias = layer_weights->layer_config.ff_biases;
|
||||
const float* bias1 =
|
||||
add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr;
|
||||
add_bias ? layer_weights->vit.linear_0_b.PackedScale1() : nullptr;
|
||||
const float* output_bias =
|
||||
add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr;
|
||||
add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
|
||||
|
||||
// Define slightly more readable names for the weights and activations.
|
||||
const auto x =
|
||||
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
|
||||
|
||||
auto hidden_activations = RowPtrFromBatch(activations.C1);
|
||||
auto ffw_out = RowPtrFromBatch(activations.ffw_out);
|
||||
const Allocator2& allocator = activations.env->ctx.allocator;
|
||||
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
|
||||
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
|
||||
|
||||
auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
|
||||
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
|
||||
|
|
@ -853,7 +865,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
|
|||
MatMul(x, w1, bias1, *activations.env, hidden_activations);
|
||||
|
||||
// Activation (Gelu), store in act.
|
||||
RowPtrF multiplier = RowPtrF(nullptr, 0);
|
||||
RowPtrF multiplier = RowPtrF(allocator, nullptr, 0);
|
||||
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
|
||||
multiplier.Row(0), ff_hidden_dim * num_interleaved);
|
||||
|
||||
|
|
@ -905,11 +917,9 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
|
|||
HWY_DASSERT(token < static_cast<int>(vocab_size));
|
||||
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(
|
||||
df,
|
||||
MakeSpan(weights.embedder_input_embedding.data(), vocab_size * model_dim),
|
||||
DecompressAndZeroPad(df, weights.embedder_input_embedding.Span(),
|
||||
token * model_dim, x.Batch(batch_idx), model_dim);
|
||||
MulByConst(emb_scaling * weights.embedder_input_embedding.scale(),
|
||||
MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
|
||||
x.Batch(batch_idx), model_dim);
|
||||
if (weights.weights_config.absolute_pe) {
|
||||
AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos);
|
||||
|
|
@ -943,9 +953,10 @@ HWY_NOINLINE void ResidualConnection(
|
|||
template <typename WeightT, typename InOutT>
|
||||
void PostNorm(PostNormType post_norm, size_t num_interleaved,
|
||||
const WeightT& weights, InOutT* inout) {
|
||||
HWY_DASSERT(weights.Rows() == 1);
|
||||
if (post_norm == PostNormType::Scale) {
|
||||
RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout,
|
||||
weights.NumElements());
|
||||
RMSNormInplaceBatched(num_interleaved, weights.PackedScale1(), inout,
|
||||
weights.Cols());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -962,7 +973,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
|
|||
auto type = layer_weights->layer_config.type;
|
||||
|
||||
RMSNormBatched(num_interleaved, activations.x.All(),
|
||||
layer_weights->pre_attention_norm_scale.data_scale1(),
|
||||
layer_weights->pre_attention_norm_scale.PackedScale1(),
|
||||
activations.pre_att_rms_out.All(), model_dim);
|
||||
|
||||
Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx,
|
||||
|
|
@ -976,7 +987,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
|
|||
activations.x.All(), layer_weights, /*is_attention=*/true);
|
||||
|
||||
RMSNormBatched(num_interleaved, activations.x.All(),
|
||||
layer_weights->pre_ffw_norm_scale.data_scale1(),
|
||||
layer_weights->pre_ffw_norm_scale.PackedScale1(),
|
||||
activations.bf_pre_ffw_rms_out.All(), model_dim);
|
||||
|
||||
if (layer_weights->layer_config.type == LayerAttentionType::kVit) {
|
||||
|
|
@ -1014,8 +1025,8 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
|
|||
// y = nn.LayerNorm()(x)
|
||||
// y ~ pre_att_rms_out
|
||||
LayerNormBatched(num_tokens, x.All(),
|
||||
layer_weights->vit.layer_norm_0_scale.data_scale1(),
|
||||
layer_weights->vit.layer_norm_0_bias.data_scale1(),
|
||||
layer_weights->vit.layer_norm_0_scale.PackedScale1(),
|
||||
layer_weights->vit.layer_norm_0_bias.PackedScale1(),
|
||||
activations.pre_att_rms_out.All(), model_dim);
|
||||
|
||||
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
|
||||
|
|
@ -1028,8 +1039,8 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
|
|||
// y = nn.LayerNorm()(x)
|
||||
// y ~ bf_pre_ffw_rms_out
|
||||
LayerNormBatched(num_tokens, x.All(),
|
||||
layer_weights->vit.layer_norm_1_scale.data_scale1(),
|
||||
layer_weights->vit.layer_norm_1_bias.data_scale1(),
|
||||
layer_weights->vit.layer_norm_1_scale.PackedScale1(),
|
||||
layer_weights->vit.layer_norm_1_bias.PackedScale1(),
|
||||
activations.bf_pre_ffw_rms_out.All(), model_dim);
|
||||
|
||||
// y = out["mlp"] = MlpBlock(...)(y)
|
||||
|
|
@ -1161,8 +1172,8 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
const size_t patch_width = weights.weights_config.vit_config.patch_width;
|
||||
const size_t seq_len = weights.weights_config.vit_config.seq_len;
|
||||
const size_t patch_size = patch_width * patch_width * 3;
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
|
||||
patch_size * model_dim);
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
|
||||
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
|
||||
HWY_DASSERT(activations.x.Cols() == model_dim);
|
||||
std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
|
|
@ -1178,20 +1189,20 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
|
|||
// MatMul(
|
||||
// MatFromBatch(kVitSeqLen, image_patches),
|
||||
// MatFromWeights(weights.vit_img_embedding_kernel),
|
||||
// weights.vit_img_embedding_bias.data_scale1(), *activations.env,
|
||||
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
|
||||
// RowPtrF(activations.x.All(), kVitModelDim));
|
||||
// However, MatMul currently requires that
|
||||
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
|
||||
// which is not the case here. We should relax that requirement on MatMul and
|
||||
// then use the above. For now, we rely on MatVecAdd instead.
|
||||
for (size_t i = 0; i < seq_len; ++i) {
|
||||
MatVecAdd(
|
||||
weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
|
||||
image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(),
|
||||
activations.x.Batch(i), activations.env->parallel.Pools().Pool(0));
|
||||
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
|
||||
image_patches[i].get(),
|
||||
weights.vit_img_embedding_bias.PackedScale1(),
|
||||
activations.x.Batch(i), activations.env->ctx.pools.Pool(0));
|
||||
}
|
||||
// Add position embeddings.
|
||||
AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
|
||||
AddFrom(weights.vit_img_pos_embedding.PackedScale1(), activations.x.All(),
|
||||
seq_len * model_dim);
|
||||
}
|
||||
|
||||
|
|
@ -1216,23 +1227,23 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
|
|||
}
|
||||
// Final Layernorm.
|
||||
LayerNormBatched(num_tokens, activations.x.All(),
|
||||
weights.vit_encoder_norm_scale.data_scale1(),
|
||||
weights.vit_encoder_norm_bias.data_scale1(),
|
||||
weights.vit_encoder_norm_scale.PackedScale1(),
|
||||
weights.vit_encoder_norm_bias.PackedScale1(),
|
||||
activations.x.All(), vit_model_dim);
|
||||
|
||||
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
|
||||
activations.x = AvgPool4x4(activations.x);
|
||||
|
||||
// Apply soft embedding norm before input projection.
|
||||
RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
|
||||
RMSNormInplace(weights.mm_embed_norm.PackedScale1(), activations.x.All(),
|
||||
vit_model_dim);
|
||||
}
|
||||
|
||||
// Apply head embedding into image_tokens of size of the LLM kModelDim.
|
||||
MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
|
||||
ConstMatFromWeights(weights.vit_img_head_kernel),
|
||||
weights.vit_img_head_bias.data_scale1(), *activations.env,
|
||||
RowPtrFromBatch(image_tokens));
|
||||
weights.vit_img_head_bias.PackedScale1(), *activations.env,
|
||||
RowPtrFromBatch(activations.env->ctx.allocator, image_tokens));
|
||||
}
|
||||
|
||||
// Generates one token for each query. `queries_token` is the previous token
|
||||
|
|
@ -1274,7 +1285,7 @@ HWY_NOINLINE void Transformer(
|
|||
}
|
||||
}
|
||||
|
||||
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
|
||||
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.PackedScale1(),
|
||||
activations.x.All(), model_dim);
|
||||
|
||||
if (activations_observer) {
|
||||
|
|
@ -1374,7 +1385,7 @@ bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
|
|||
MatMul(ConstMatFromBatch(num_queries, activations.x),
|
||||
ConstMatFromWeights(weights.embedder_input_embedding),
|
||||
/*add=*/nullptr, *activations.env,
|
||||
RowPtrFromBatch(activations.logits));
|
||||
RowPtrFromBatch(activations.env->ctx.allocator, activations.logits));
|
||||
}
|
||||
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
|
||||
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
|
||||
|
|
|
|||
|
|
@@ -27,22 +27,33 @@
#include <utility>  // std::move
#include <vector>

#include "compression/io.h"  // Path
// Placeholder for internal header, do not modify.
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "ops/ops-inl.h"
#include "ops/matmul.h"
#include "paligemma/image.h"
#include "util/threading.h"
#include "hwy/highway.h"
#include "util/threading_context.h"
#include "hwy/base.h"

namespace gcpp {

// Internal init must run before I/O; calling it from `GemmaEnv()` is too late.
// This helper function takes care of the internal init plus calling `SetArgs`.
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
  // Placeholder for internal init, do not modify.

  ThreadingContext2::SetArgs(threading_args);
  return MatMulEnv(ThreadingContext2::Get());
}
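A sketch of how a frontend is expected to call this helper, following the updated run.cc later in this change; CreateGemma is assumed to come from evals/benchmark_helper.h (as included by run.cc), and the argument wiring here is illustrative rather than part of this commit:

// Sketch only: build args, then the env, then the model.
int FrontendMain(int argc, char** argv) {
  gcpp::ThreadingArgs threading(argc, argv);
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::MatMulEnv env = gcpp::MakeMatMulEnv(threading);  // internal init + SetArgs
  gcpp::Gemma model = gcpp::CreateGemma(loader, env);    // as in run.cc
  return 0;
}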
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
             const ModelInfo& info, MatMulEnv& env)
    : env_(env), tokenizer_(tokenizer_path) {
  model_.Load(weights, info.model, info.weight, info.wrapping,
              env_.parallel.Pools().Pool(0),
              env_.ctx.pools.Pool(0),
              /*tokenizer_proto=*/nullptr);
  chat_template_.Init(tokenizer_, model_.Config().model);
}
@ -50,7 +61,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
|
|||
Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
|
||||
std::string tokenizer_proto;
|
||||
model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
|
||||
env_.parallel.Pools().Pool(0), &tokenizer_proto);
|
||||
env_.ctx.pools.Pool(0), &tokenizer_proto);
|
||||
tokenizer_.Deserialize(tokenizer_proto);
|
||||
chat_template_.Init(tokenizer_, model_.Config().model);
|
||||
}
|
||||
|
|
@ -60,7 +71,7 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
|
|||
tokenizer_(std::move(tokenizer)),
|
||||
chat_template_(tokenizer_, info.model) {
|
||||
HWY_ASSERT(info.weight == Type::kF32);
|
||||
model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
|
||||
model_.Allocate(info.model, info.weight, env_.ctx.pools.Pool(0));
|
||||
}
|
||||
|
||||
Gemma::~Gemma() {
|
||||
|
|
@ -130,12 +141,12 @@ struct GenerateImageTokensT {
|
|||
void Gemma::Generate(const RuntimeConfig& runtime_config,
|
||||
const PromptTokens& prompt, size_t pos, size_t prefix_end,
|
||||
KVCache& kv_cache, TimingInfo& timing_info) {
|
||||
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
|
||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateSingleT>(
|
||||
runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info);
|
||||
|
||||
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
|
||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
||||
|
|
@ -152,23 +163,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
|
|||
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
|
||||
}
|
||||
|
||||
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
|
||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateBatchT>(
|
||||
runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
|
||||
kv_caches, &env_, timing_info);
|
||||
|
||||
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
|
||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
|
||||
const Image& image, ImageTokens& image_tokens) {
|
||||
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
|
||||
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
|
||||
|
||||
model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
|
||||
image_tokens, &env_);
|
||||
|
||||
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
|
||||
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
|
||||
}
|
||||
|
||||
// Non-template functions moved from gemma-inl.h to avoid ODR violations.
|
||||
|
|
|
|||
|
|
@@ -16,6 +16,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_

#include <stdio.h>

#include <functional>
#include <random>
#include <string>

@@ -31,8 +33,9 @@
#include "gemma/weights.h"
#include "ops/matmul.h"  // MatMulEnv
#include "paligemma/image.h"
#include "util/allocator.h"  // RowVectorBatch
#include "util/basics.h"  // TokenAndProb
#include "util/mat.h"  // RowVectorBatch
#include "util/threading_context.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h"  // Span
@@ -193,6 +196,10 @@ struct TimingInfo {
  size_t tokens_generated = 0;
};

// Internal init must run before I/O; calling it from GemmaEnv() is too late.
// This helper function takes care of the internal init plus calling `SetArgs`.
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);

class Gemma {
 public:
  // Reads old format weights file and tokenizer file.

@@ -206,7 +213,9 @@ class Gemma {
  Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env);
  ~Gemma();

  MatMulEnv& Env() const { return env_; }
  const ModelConfig& GetModelConfig() const { return model_.Config(); }
  // DEPRECATED
  ModelInfo Info() const {
    return ModelInfo({.model = model_.Config().model,
                      .wrapping = model_.Config().wrapping,
@@ -15,8 +15,8 @@

// Shared between various frontends.

#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

#include <stddef.h>
#include <stdio.h>
@ -31,103 +31,10 @@
|
|||
#include "ops/matmul.h"
|
||||
#include "util/args.h"
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h" // HWY_IS_ASAN
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static inline const char* CompiledConfig() {
|
||||
if (HWY_IS_ASAN) {
|
||||
return "asan";
|
||||
} else if (HWY_IS_MSAN) {
|
||||
return "msan";
|
||||
} else if (HWY_IS_TSAN) {
|
||||
return "tsan";
|
||||
} else if (HWY_IS_HWASAN) {
|
||||
return "hwasan";
|
||||
} else if (HWY_IS_UBSAN) {
|
||||
return "ubsan";
|
||||
} else if (HWY_IS_DEBUG_BUILD) {
|
||||
return "dbg";
|
||||
} else {
|
||||
return "opt";
|
||||
}
|
||||
}
|
||||
|
||||
class AppArgs : public ArgsBase<AppArgs> {
|
||||
public:
|
||||
AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
AppArgs() { Init(); };
|
||||
|
||||
int verbosity;
|
||||
|
||||
size_t max_threads; // divided among the detected clusters
|
||||
Tristate pin; // pin threads?
|
||||
Tristate spin; // use spin waits?
|
||||
|
||||
// For BoundedSlice:
|
||||
size_t skip_packages;
|
||||
size_t max_packages;
|
||||
size_t skip_clusters;
|
||||
size_t max_clusters;
|
||||
size_t skip_lps;
|
||||
size_t max_lps;
|
||||
|
||||
std::string eot_line;
|
||||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
|
||||
// The exact meaning is more subtle: see the comment at NestedPools ctor.
|
||||
visitor(max_threads, "num_threads", size_t{0},
|
||||
"Maximum number of threads to use; default 0 = unlimited.", 2);
|
||||
visitor(pin, "pin", Tristate::kDefault,
|
||||
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
visitor(spin, "spin", Tristate::kDefault,
|
||||
"Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
|
||||
// These can be used to partition CPU sockets/packages and their
|
||||
// clusters/CCXs across several program instances. The default is to use
|
||||
// all available resources.
|
||||
visitor(skip_packages, "skip_packages", size_t{0},
|
||||
"Index of the first socket to use; default 0 = unlimited.", 2);
|
||||
visitor(max_packages, "max_packages", size_t{0},
|
||||
"Maximum number of sockets to use; default 0 = unlimited.", 2);
|
||||
visitor(skip_clusters, "skip_clusters", size_t{0},
|
||||
"Index of the first CCX to use; default 0 = unlimited.", 2);
|
||||
visitor(max_clusters, "max_clusters", size_t{0},
|
||||
"Maximum number of CCXs to use; default 0 = unlimited.", 2);
|
||||
// These are only used when CPU topology is unknown.
|
||||
visitor(skip_lps, "skip_lps", size_t{0},
|
||||
"Index of the first LP to use; default 0 = unlimited.", 2);
|
||||
visitor(max_lps, "max_lps", size_t{0},
|
||||
"Maximum number of LPs to use; default 0 = unlimited.", 2);
|
||||
|
||||
visitor(
|
||||
eot_line, "eot_line", std::string(""),
|
||||
"End of turn line. "
|
||||
"When you specify this, the prompt will be all lines "
|
||||
"before the line where only the given string appears.\n Default = "
|
||||
"When a newline is encountered, that signals the end of the turn.",
|
||||
2);
|
||||
}
|
||||
};
|
||||
|
||||
static inline BoundedTopology CreateTopology(const AppArgs& app) {
|
||||
return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages),
|
||||
BoundedSlice(app.skip_clusters, app.max_clusters),
|
||||
BoundedSlice(app.skip_lps, app.max_lps));
|
||||
}
|
||||
static inline NestedPools CreatePools(const BoundedTopology& topology,
|
||||
const AppArgs& app) {
|
||||
Allocator::Init(topology);
|
||||
return NestedPools(topology, app.max_threads, app.pin);
|
||||
}
|
||||
|
||||
struct LoaderArgs : public ArgsBase<LoaderArgs> {
|
||||
LoaderArgs(int argc, char* argv[], bool validate = true) {
|
||||
InitAndParse(argc, argv);
|
||||
|
|
@@ -154,15 +61,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {

  // Returns error string or nullptr if OK.
  const char* Validate() {
    if (!compressed_weights.path.empty()) {
      if (weights.path.empty()) {
        weights = compressed_weights;
      } else {
        return "Only one of --weights and --compressed_weights can be "
               "specified. To create compressed weights use the "
               "compress_weights tool.";
      }
    }
    if (weights.path.empty()) {
      return "Missing --weights flag, a file for the model weights.";
    }
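With this branch removed, only --weights remains: an invocation now passes a single flag such as --weights /path/to/model.sbs (path illustrative), and the old --compressed_weights spelling appears to no longer be accepted at all rather than being folded into --weights.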
@ -250,6 +148,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
|
||||
InferenceArgs() { Init(); };
|
||||
|
||||
int verbosity;
|
||||
|
||||
size_t max_generated_tokens;
|
||||
|
||||
size_t prefill_tbatch_size;
|
||||
|
|
@ -261,6 +161,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
bool multiturn;
|
||||
Path image_file;
|
||||
|
||||
std::string eot_line;
|
||||
|
||||
// Returns error string or nullptr if OK.
|
||||
const char* Validate() const {
|
||||
if (max_generated_tokens > gcpp::kSeqLen) {
|
||||
|
|
@ -272,6 +174,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
|
||||
template <class Visitor>
|
||||
void ForEach(const Visitor& visitor) {
|
||||
visitor(verbosity, "verbosity", 1,
|
||||
"Show verbose developer information\n 0 = only print generation "
|
||||
"output\n 1 = standard user-facing terminal ui\n 2 = show "
|
||||
"developer/debug info).\n Default = 1.",
|
||||
2);
|
||||
|
||||
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
|
||||
"Maximum number of tokens to generate.");
|
||||
|
||||
|
|
@ -291,6 +199,14 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
" Default : 0 (conversation "
|
||||
"resets every turn)");
|
||||
visitor(image_file, "image_file", Path(), "Image file to load.");
|
||||
|
||||
visitor(
|
||||
eot_line, "eot_line", std::string(""),
|
||||
"End of turn line. "
|
||||
"When you specify this, the prompt will be all lines "
|
||||
"before the line where only the given string appears.\n Default = "
|
||||
"When a newline is encountered, that signals the end of the turn.",
|
||||
2);
|
||||
}
|
||||
|
||||
void CopyTo(RuntimeConfig& runtime_config) const {
|
||||
|
|
@ -317,4 +233,4 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
|
|||
|
||||
} // namespace gcpp
|
||||
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
|
||||
84 gemma/run.cc
@@ -15,23 +15,23 @@

// Command line text interface to gemma.

#include <stdio.h>

#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <vector>

// Placeholder for internal header, do not modify.
#include "compression/shared.h"  // PromptWrapping
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h"  // Gemma
#include "gemma/gemma_args.h"  // LoaderArgs
#include "ops/matmul.h"  // MatMulEnv
#include "paligemma/image.h"
#include "util/app.h"
#include "util/args.h"  // HasHelp
#include "util/threading.h"
#include "hwy/base.h"
#include "util/threading_context.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
@ -78,35 +78,37 @@ std::string GetPrompt(std::istream& input, int verbosity,
|
|||
}
|
||||
|
||||
// The main Read-Eval-Print Loop.
|
||||
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
||||
const InferenceArgs& args, const AcceptFunc& accept_token,
|
||||
std::string& eot_line) {
|
||||
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
|
||||
Gemma& model, KVCache& kv_cache) {
|
||||
PROFILER_ZONE("Gen.misc");
|
||||
size_t abs_pos = 0; // across turns
|
||||
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
|
||||
size_t prompt_size = 0;
|
||||
|
||||
std::mt19937 gen;
|
||||
InitGenerator(args, gen);
|
||||
InitGenerator(inference, gen);
|
||||
|
||||
const bool have_image = !args.image_file.path.empty();
|
||||
const bool have_image = !inference.image_file.path.empty();
|
||||
Image image;
|
||||
ImageTokens image_tokens;
|
||||
if (have_image) {
|
||||
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
|
||||
image_tokens = ImageTokens(Extents2D(
|
||||
model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim),
|
||||
image_tokens =
|
||||
ImageTokens(model.Env().ctx.allocator,
|
||||
Extents2D(model.GetModelConfig().vit_config.seq_len /
|
||||
(pool_dim * pool_dim),
|
||||
model.GetModelConfig().model_dim));
|
||||
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
|
||||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
|
||||
HWY_ASSERT(image.ReadPPM(args.image_file.path));
|
||||
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
|
||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
||||
image.Resize(image_size, image_size);
|
||||
RuntimeConfig runtime_config = {
|
||||
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = inference.verbosity,
|
||||
.use_spinning = threading.spin};
|
||||
double image_tokens_start = hwy::platform::Now();
|
||||
model.GenerateImageTokens(runtime_config, image, image_tokens);
|
||||
if (app.verbosity >= 1) {
|
||||
if (inference.verbosity >= 1) {
|
||||
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
|
||||
fprintf(stderr,
|
||||
"\n\n[ Timing info ] Image token generation took: %d ms\n",
|
||||
|
|
@ -121,12 +123,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
const bool first_response_token = tokens_generated_this_turn == prompt_size;
|
||||
++tokens_generated_this_turn;
|
||||
if (in_prompt) {
|
||||
if (app.verbosity >= 1) {
|
||||
if (inference.verbosity >= 1) {
|
||||
std::cerr << "." << std::flush;
|
||||
}
|
||||
return true;
|
||||
} else if (model.GetModelConfig().IsEOS(token)) {
|
||||
if (app.verbosity >= 2) {
|
||||
if (inference.verbosity >= 2) {
|
||||
std::cout << "\n[ End ]\n";
|
||||
}
|
||||
return true;
|
||||
|
|
@ -135,7 +137,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
|
||||
if (first_response_token) {
|
||||
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
|
||||
if (app.verbosity >= 1) {
|
||||
if (inference.verbosity >= 1) {
|
||||
std::cout << "\n\n";
|
||||
}
|
||||
}
|
||||
|
|
@ -147,7 +149,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
tokens_generated_this_turn = 0;
|
||||
|
||||
// Read prompt and handle special commands.
|
||||
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
|
||||
std::string prompt_string =
|
||||
GetPrompt(std::cin, inference.verbosity, inference.eot_line);
|
||||
if (!std::cin) return;
|
||||
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
|
||||
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
|
||||
|
|
@ -163,13 +166,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
}
|
||||
|
||||
// Set up runtime config.
|
||||
TimingInfo timing_info = {.verbosity = app.verbosity};
|
||||
TimingInfo timing_info = {.verbosity = inference.verbosity};
|
||||
RuntimeConfig runtime_config = {.gen = &gen,
|
||||
.verbosity = app.verbosity,
|
||||
.verbosity = inference.verbosity,
|
||||
.stream_token = stream_token,
|
||||
.accept_token = accept_token,
|
||||
.use_spinning = app.spin};
|
||||
args.CopyTo(runtime_config);
|
||||
.use_spinning = threading.spin};
|
||||
inference.CopyTo(runtime_config);
|
||||
size_t prefix_end = 0;
|
||||
|
||||
std::vector<int> prompt;
|
||||
|
|
@ -197,7 +199,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
}
|
||||
|
||||
// Generate until EOS or max_generated_tokens.
|
||||
if (app.verbosity >= 1) {
|
||||
if (inference.verbosity >= 1) {
|
||||
std::cerr << "\n[ Reading prompt ] " << std::flush;
|
||||
}
|
||||
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
|
||||
|
|
@ -205,9 +207,10 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
|
|||
std::cout << "\n\n";
|
||||
|
||||
// Prepare for the next turn. Works only for PaliGemma.
|
||||
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
if (!inference.multiturn ||
|
||||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
|
||||
abs_pos = 0; // Start a new turn at position 0.
|
||||
InitGenerator(args, gen);
|
||||
InitGenerator(inference, gen);
|
||||
} else {
|
||||
// The last token was either EOS, then it should be ignored because it is
|
||||
// never part of the dialog, see Table 5 in the Gemma-2 paper:
|
||||
|
|
@@ -223,20 +226,19 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
  }
}

void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
void Run(ThreadingArgs& threading, LoaderArgs& loader,
         InferenceArgs& inference) {
  PROFILER_ZONE("Run.misc");

  // Note that num_threads is an upper bound; we also limit to the number of
  // detected and enabled cores.
  const BoundedTopology topology = CreateTopology(app);
  NestedPools pools = CreatePools(topology, app);
  MatMulEnv env(topology, pools);
  if (app.verbosity >= 2) env.print_best = true;
  MatMulEnv env(MakeMatMulEnv(threading));
  if (inference.verbosity >= 2) env.print_best = true;
  Gemma model = CreateGemma(loader, env);
  KVCache kv_cache =
      KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);

  if (app.verbosity >= 1) {
  if (inference.verbosity >= 1) {
    std::string instructions =
        "*Usage*\n"
        "  Enter an instruction and press enter (%C resets conversation, "

@@ -259,11 +261,11 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {

    std::cout << "\033[2J\033[1;1H"  // clear screen
              << kAsciiArtBanner << "\n\n";
    ShowConfig(loader, inference, app, topology, pools);
    ShowConfig(threading, loader, inference);
    std::cout << "\n" << instructions << "\n";
  }

  ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line);
  ReplGemma(threading, inference, model, kv_cache);
}

}  // namespace gcpp
@@ -272,31 +274,29 @@ int main(int argc, char** argv) {
  {
    PROFILER_ZONE("Startup.misc");

    // Placeholder for internal init, do not modify.

    gcpp::ThreadingArgs threading(argc, argv);
    gcpp::LoaderArgs loader(argc, argv);
    gcpp::InferenceArgs inference(argc, argv);
    gcpp::AppArgs app(argc, argv);

    if (gcpp::HasHelp(argc, argv)) {
      std::cerr << gcpp::kAsciiArtBanner;
      gcpp::ShowHelp(loader, inference, app);
      gcpp::ShowHelp(threading, loader, inference);
      return 0;
    }

    if (const char* error = loader.Validate()) {
      std::cerr << gcpp::kAsciiArtBanner;
      gcpp::ShowHelp(loader, inference, app);
      gcpp::ShowHelp(threading, loader, inference);
      HWY_ABORT("\nInvalid args: %s", error);
    }

    if (const char* error = inference.Validate()) {
      std::cerr << gcpp::kAsciiArtBanner;
      gcpp::ShowHelp(loader, inference, app);
      gcpp::ShowHelp(threading, loader, inference);
      HWY_ABORT("\nInvalid args: %s", error);
    }

    gcpp::Run(loader, inference, app);
    gcpp::Run(threading, loader, inference);
  }
  PROFILER_PRINT_RESULTS();  // Must call outside the zone above.
  return 0;
@@ -562,11 +562,12 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
  if (llm_layer_idx < 0 && img_layer_idx < 0) {
    tensors_ = ModelTensors(config);
  } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
             img_layer_idx < config.vit_config.layer_configs.size()) {
             img_layer_idx <
                 static_cast<int>(config.vit_config.layer_configs.size())) {
    const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
    tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
  } else if (0 <= llm_layer_idx &&
             llm_layer_idx < config.layer_configs.size()) {
             llm_layer_idx < static_cast<int>(config.layer_configs.size())) {
    const auto& layer_config = config.layer_configs[llm_layer_idx];
    tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
  }
@@ -54,6 +54,28 @@ struct TensorInfo {
  bool cols_take_extra_dims = false;
};

// Collapses/expands the tensor dims into 2D extents, which may be 0, 0 for
// not-present tensors such as ViT in a text-only model.
static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) {
  if (tensor == nullptr) return Extents2D(0, 0);

  size_t cols = tensor->shape.back();
  size_t rows = 1;
  if (tensor->cols_take_extra_dims) {
    rows = tensor->shape[0];
    for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
      cols *= tensor->shape[i];
    }
  } else {  // rows take extra dims
    for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
      rows *= tensor->shape[i];
    }
  }
  // Sometimes only one of rows or cols is zero; set both for consistency.
  if (rows == 0 || cols == 0) rows = cols = 0;
  return Extents2D(rows, cols);
}
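A worked example of the folding rule above (the shape values are illustrative, and shape is assumed to be a std::vector<size_t>):

TensorInfo t;
t.shape = {8, 2048, 256};           // e.g. [heads, model_dim, qkv_dim]
t.cols_take_extra_dims = false;     // rows absorb the leading dims
Extents2D e = ExtentsFromInfo(&t);  // rows = 8 * 2048 = 16384, cols = 256
t.cols_take_extra_dims = true;      // cols absorb the middle dims instead
e = ExtentsFromInfo(&t);            // rows = 8, cols = 2048 * 256 = 524288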
// Universal index of tensor information, which can be built for a specific
// layer_idx.
class TensorIndex {
@@ -96,6 +118,16 @@ class TensorIndex {
  std::unordered_map<std::string, size_t> name_map_;
};

static inline TensorIndex TensorIndexLLM(const ModelConfig& config,
                                         size_t llm_layer_idx) {
  return TensorIndex(config, static_cast<int>(llm_layer_idx), -1, false);
}

static inline TensorIndex TensorIndexImg(const ModelConfig& config,
                                         size_t img_layer_idx) {
  return TensorIndex(config, -1, static_cast<int>(img_layer_idx), false);
}

}  // namespace gcpp

#endif  // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_
@@ -29,6 +29,7 @@
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"  // HWY_ABORT
#include "hwy/contrib/thread_pool/thread_pool.h"

@@ -118,7 +119,7 @@ struct TensorSaver {
    weights.ForEachTensor(
        {&weights}, fet,
        [&writer](const char* name, hwy::Span<MatPtr*> tensors) {
          tensors[0]->CallUpcasted(writer, name);
          CallUpcasted(tensors[0]->GetType(), tensors[0], writer, name);
        });
  }
};

@@ -155,11 +156,11 @@ class WeightInitializer {
  WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}

  void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
    float* data = tensors[0]->data<float>();
    for (size_t i = 0; i < tensors[0]->NumElements(); ++i) {
    float* data = tensors[0]->RowT<float>(0);
    for (size_t i = 0; i < tensors[0]->Extents().Area(); ++i) {
      data[i] = dist_(gen_);
    }
    tensors[0]->set_scale(1.0f);
    tensors[0]->SetScale(1.0f);
  }

 private:
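The initializer above still writes through row 0 and Extents().Area(), which assumes a packed layout; for tensors with a row stride, the row accessors introduced in this change suggest a per-row loop instead. A sketch under that assumption (RowT(r) is taken to return a pointer to row r, as the other call sites in this diff imply):

// Sketch only: stride-safe random init using the row-based accessors.
void FillUniform(MatPtr& mat, std::mt19937& gen) {
  std::uniform_real_distribution<float> dist(0.0f, 1.0f);
  for (size_t r = 0; r < mat.Rows(); ++r) {
    float* row = mat.RowT<float>(r);
    for (size_t c = 0; c < mat.Cols(); ++c) row[c] = dist(gen);
  }
  mat.SetScale(1.0f);
}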
@ -226,11 +227,11 @@ void ModelWeightsStorage::LogWeightStats() {
|
|||
{float_weights_.get()}, ForEachType::kInitNoToc,
|
||||
[&total_weights](const char* name, hwy::Span<MatPtr*> tensors) {
|
||||
const MatPtr& tensor = *tensors[0];
|
||||
if (tensor.scale() != 1.0f) {
|
||||
printf("[scale=%f] ", tensor.scale());
|
||||
if (tensor.Scale() != 1.0f) {
|
||||
printf("[scale=%f] ", tensor.Scale());
|
||||
}
|
||||
LogVec(name, tensor.data<float>(), tensor.NumElements());
|
||||
total_weights += tensor.NumElements();
|
||||
LogVec(name, tensor.RowT<float>(0), tensor.Extents().Area());
|
||||
total_weights += tensor.Extents().Area();
|
||||
});
|
||||
printf("%-20s %12zu\n", "Total", total_weights);
|
||||
}
|
||||
|
|
@ -258,8 +259,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
|
|||
}
|
||||
|
||||
template <>
|
||||
void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
||||
if (attn_vec_einsum_w.data() == nullptr) return;
|
||||
void LayerWeightsPtrs<NuqStream>::Reshape(MatOwner* storage) {
|
||||
if (!attn_vec_einsum_w.HasPtr()) return;
|
||||
|
||||
const size_t model_dim = layer_config.model_dim;
|
||||
const size_t heads = layer_config.heads;
|
||||
|
|
@ -267,8 +268,7 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
|||
|
||||
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
|
||||
if (storage != nullptr) {
|
||||
storage->Allocate();
|
||||
att_weights.SetPtr(*storage);
|
||||
storage->AllocateFor(att_weights, MatPadding::kPacked);
|
||||
}
|
||||
|
||||
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
|
||||
|
|
@ -279,7 +279,7 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
|||
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
|
||||
|
||||
HWY_NAMESPACE::DecompressAndZeroPad(
|
||||
df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0,
|
||||
df, MakeSpan(attn_vec_einsum_w.Packed(), model_dim * heads * qkv_dim), 0,
|
||||
attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
|
||||
|
||||
for (size_t m = 0; m < model_dim; ++m) {
|
||||
|
|
@ -296,10 +296,10 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
|
|||
|
||||
HWY_NAMESPACE::Compress(
|
||||
att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
|
||||
MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
|
||||
MakeSpan(att_weights.Packed(), model_dim * heads * qkv_dim),
|
||||
/*packed_ofs=*/0, pool);
|
||||
|
||||
att_weights.set_scale(attn_vec_einsum_w.scale());
|
||||
att_weights.SetScale(attn_vec_einsum_w.Scale());
|
||||
}
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@@ -31,12 +31,32 @@
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/tensor_index.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

namespace gcpp {

static inline std::string CacheName(const MatPtr& mat, int layer = -1,
                                    char separator = ' ', int index = -1) {
  // Already used/retired: s, S, n, 1
  const char prefix = mat.GetType() == Type::kF32    ? 'F'
                      : mat.GetType() == Type::kBF16 ? 'B'
                      : mat.GetType() == Type::kSFP  ? '$'
                      : mat.GetType() == Type::kNUQ  ? '2'
                                                     : '?';
  std::string name = std::string(1, prefix) + mat.Name();
  if (layer >= 0 || index >= 0) {
    name += '_';
    if (layer >= 0) name += std::to_string(layer);
    if (index >= 0) {
      name += separator + std::to_string(index);
    }
  }
  return name;
}
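For example (names illustrative): an SFP tensor named "qkv_ein" in layer 2 maps to the key "$qkv_ein_2", and an F32 tensor named "bias" in layer 0 with separator '_' and index 3 yields "Fbias_0_3".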
// Different tensors need to appear in a ForEachTensor, according to what is
// happening.
enum class ForEachType {
@@ -181,10 +201,10 @@ struct LayerWeightsPtrs {
  // Initializes att_weights from attn_vec_einsum_w, hence this must be called
  // after loading weights via ForEachTensor.
  // TODO: update compression/convert_weights to bake this in.
  void Reshape(MatStorage* storage) {
  void Reshape(MatOwner* storage) {
    static_assert(!hwy::IsSame<Weight, NuqStream>());

    if (attn_vec_einsum_w.data() == nullptr) return;
    if (!attn_vec_einsum_w.HasPtr()) return;

    const size_t model_dim = layer_config.model_dim;
    const size_t heads = layer_config.heads;

@@ -192,18 +212,18 @@ struct LayerWeightsPtrs {

    // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
    if (storage != nullptr) {
      storage->Allocate();
      att_weights.SetPtr(*storage);
      storage->AllocateFor(att_weights, MatPadding::kPacked);
    }
    for (size_t m = 0; m < model_dim; ++m) {
      Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
      Weight* HWY_RESTRICT out_row =
          att_weights.template RowT<Weight>(0) + m * heads * qkv_dim;
      for (size_t h = 0; h < heads; ++h) {
        hwy::CopyBytes(
            attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
        hwy::CopyBytes(attn_vec_einsum_w.template RowT<Weight>(0) +
                           h * model_dim * qkv_dim + m * qkv_dim,
                       out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
      }
    }
    att_weights.set_scale(attn_vec_einsum_w.scale());
    att_weights.SetScale(attn_vec_einsum_w.Scale());
  }
ArrayT<WeightF32OrBF16> key_norm_scale;
|
||||
|
|
@ -215,8 +235,8 @@ struct LayerWeightsPtrs {
|
|||
for (int i = 0; i < ptrs.size(); ++i) { \
|
||||
tensors[i] = &ptrs[i]->member; \
|
||||
} \
|
||||
if (tensors[0]->Ptr() != nullptr || fet != ForEachType::kIgnoreNulls) { \
|
||||
func(ptrs[0]->member.CacheName(layer_idx, sep, sep_index).c_str(), \
|
||||
if (tensors[0]->HasPtr() || fet != ForEachType::kIgnoreNulls) { \
|
||||
func(CacheName(ptrs[0]->member, layer_idx, sep, sep_index).c_str(), \
|
||||
hwy::Span<MatPtr*>(tensors.data(), ptrs.size())); \
|
||||
} \
|
||||
}
|
||||
|
|
@ -307,19 +327,18 @@ struct LayerWeightsPtrs {

void ZeroInit(int layer_idx) {
ForEachTensor({this}, layer_idx, ForEachType::kIgnoreNulls,
[](const char*, hwy::Span<MatPtr*> tensors) {
tensors[0]->ZeroInit();
gcpp::ZeroInit(*tensors[0]);
});
}

// Allocates memory for all the tensors in the layer.
// Note that this is slow and only used for a stand-alone layer.
void Allocate(std::vector<MatStorage>& layer_storage) {
void Allocate(std::vector<MatOwner>& layer_storage) {
ForEachTensor(
{this}, /*layer_idx=*/0, ForEachType::kInitNoToc,
[&layer_storage](const char* name, hwy::Span<MatPtr*> tensors) {
layer_storage.emplace_back(*tensors[0]);
layer_storage.back().Allocate();
tensors[0]->SetPtr(layer_storage.back());
layer_storage.push_back(MatOwner());
layer_storage.back().AllocateFor(*tensors[0], MatPadding::kPacked);
});
}
};
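The refactor separates tensor metadata/views (MatPtr) from ownership of the bytes (MatOwner), so allocation becomes "create an owner, then AllocateFor(view)" as in the loop above. A minimal, self-contained sketch of that owner/view split with simplified stand-in classes (these are not the repo's actual MatPtr/MatOwner definitions):

#include <cstddef>
#include <memory>
#include <string>
#include <vector>

struct MatView {                      // stand-in for MatPtr: name, shape, pointer
  std::string name;
  size_t rows = 0, cols = 0;
  float* data = nullptr;
  bool HasPtr() const { return data != nullptr; }
};

class MatBacking {                    // stand-in for MatOwner: owns the bytes
 public:
  void AllocateFor(MatView& view) {   // "packed": stride == cols
    bytes_ = std::make_unique<float[]>(view.rows * view.cols);
    view.data = bytes_.get();
  }
 private:
  std::unique_ptr<float[]> bytes_;
};

// Usage mirroring LayerWeightsPtrs::Allocate: one owner per tensor view.
void AllocateAll(std::vector<MatView*>& tensors,
                 std::vector<MatBacking>& storage) {
  for (MatView* t : tensors) {
    storage.emplace_back();
    storage.back().AllocateFor(*t);
  }
}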
@ -393,11 +412,9 @@ struct ModelWeightsPtrs {
|
|||
|
||||
// Called by weights.cc after Loading, before att_w has been allocated.
|
||||
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool,
|
||||
std::vector<MatStorage>& model_storage) {
|
||||
std::vector<MatOwner>& model_storage) {
|
||||
size_t storage_index = model_storage.size();
|
||||
for (auto& layer : c_layers) {
|
||||
model_storage.emplace_back(layer.att_weights);
|
||||
}
|
||||
model_storage.resize(model_storage.size() + c_layers.size());
|
||||
pool.Run(0, c_layers.size(),
|
||||
[this, &model_storage, storage_index](uint64_t layer,
|
||||
size_t /*thread*/) {
|
||||
|
|
@ -412,8 +429,8 @@ struct ModelWeightsPtrs {
|
|||
}
|
||||
|
||||
void ZeroInit() {
|
||||
embedder_input_embedding.ZeroInit();
|
||||
final_norm_scale.ZeroInit();
|
||||
gcpp::ZeroInit(embedder_input_embedding);
|
||||
gcpp::ZeroInit(final_norm_scale);
|
||||
for (size_t i = 0; i < c_layers.size(); ++i) {
|
||||
c_layers[i].ZeroInit(i);
|
||||
}
|
||||
|
|
@ -430,21 +447,21 @@ struct ModelWeightsPtrs {
|
|||
return &vit_layers[layer];
|
||||
}
|
||||
|
||||
void Allocate(std::vector<MatStorage>& model_storage, hwy::ThreadPool& pool) {
|
||||
void Allocate(std::vector<MatOwner>& model_storage, hwy::ThreadPool& pool) {
|
||||
std::vector<MatPtr*> model_toc;
|
||||
ForEachTensor(
|
||||
{this}, ForEachType::kInitNoToc,
|
||||
[&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
|
||||
model_toc.push_back(tensors[0]);
|
||||
model_storage.emplace_back(*tensors[0]);
|
||||
model_storage.push_back(MatOwner());
|
||||
});
|
||||
// Allocate in parallel using the pool.
|
||||
pool.Run(0, model_toc.size(),
|
||||
[&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
|
||||
// model_storage may have had content before we started.
|
||||
size_t idx = task + model_storage.size() - model_toc.size();
|
||||
model_storage[idx].Allocate();
|
||||
model_toc[task]->SetPtr(model_storage[idx]);
|
||||
model_storage[idx].AllocateFor(*model_toc[task],
|
||||
MatPadding::kPacked);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -453,8 +470,7 @@ struct ModelWeightsPtrs {
|
|||
ForEachTensor({this, const_cast<ModelWeightsPtrs<Weight>*>(&other)},
|
||||
ForEachType::kIgnoreNulls,
|
||||
[](const char*, hwy::Span<MatPtr*> tensors) {
|
||||
hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(),
|
||||
tensors[1]->SizeBytes());
|
||||
CopyMat(*tensors[1], *tensors[0]);
|
||||
});
|
||||
}
|
||||
|
||||
|
|
@ -467,10 +483,10 @@ struct ModelWeightsPtrs {
|
|||
[&scales, &scale_pos, this](const char*, hwy::Span<MatPtr*> tensors) {
|
||||
if (this->scale_names.count(tensors[0]->Name())) {
|
||||
if (scale_pos < scales.size()) {
|
||||
tensors[0]->set_scale(scales[scale_pos]);
|
||||
tensors[0]->SetScale(scales[scale_pos]);
|
||||
} else {
|
||||
float scale = ScaleWeights(tensors[0]->data<float>(),
|
||||
tensors[0]->NumElements());
|
||||
float scale = ScaleWeights(tensors[0]->RowT<float>(0),
|
||||
tensors[0]->Extents().Area());
|
||||
scales.push_back(scale);
|
||||
}
|
||||
++scale_pos;
|
||||
|
|
@ -615,7 +631,7 @@ class ModelWeightsStorage {
|
|||
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
|
||||
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
|
||||
// Storage for all the matrices and vectors.
|
||||
std::vector<MatStorage> model_storage_;
|
||||
std::vector<MatOwner> model_storage_;
|
||||
};
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
|
|
@ -31,15 +31,12 @@
|
|||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/nanobenchmark.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
|
@ -53,8 +50,8 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "compression/test_util-inl.h"
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
@ -63,59 +60,6 @@ extern int64_t first_target;
|
|||
|
||||
namespace HWY_NAMESPACE {
|
||||
|
||||
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
||||
|
||||
template <typename MatT>
|
||||
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
|
||||
|
||||
// Generates inputs: deterministic, within max SfpStream range.
|
||||
template <typename MatT>
|
||||
MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat =
|
||||
std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
HWY_ASSERT(content);
|
||||
const float scale =
|
||||
SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(r * extents.cols + c) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
content[r * extents.cols + c] = f;
|
||||
}
|
||||
});
|
||||
|
||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
||||
mat->set_scale(0.6f); // Arbitrary value, different from 1.
|
||||
return mat;
|
||||
}
|
||||
|
||||
// extents describes the transposed matrix.
|
||||
template <typename MatT>
|
||||
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat =
|
||||
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
const float scale =
|
||||
SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(c * extents.rows + r) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
content[r * extents.cols + c] = f;
|
||||
}
|
||||
});
|
||||
|
||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
||||
// Arbitrary value, different from 1, must match GenerateMat.
|
||||
mat->set_scale(0.6f);
|
||||
return mat;
|
||||
}
|
||||
|
||||
void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
|
||||
std::vector<double>& times, MMPerKey* per_key) {
|
||||
std::sort(times.begin(), times.end());
|
||||
|
|
@ -135,7 +79,8 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
|
|||
// M = A rows, K = A cols, N = C cols.
|
||||
template <typename TA, typename TB = TA, typename TC = float>
|
||||
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
||||
hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
|
||||
const Allocator2& allocator = env.ctx.allocator;
|
||||
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
|
||||
if (env.print_config || env.print_measurement) {
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
|
@ -147,24 +92,23 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
|||
const Extents2D B_extents(N, K); // already transposed
|
||||
const Extents2D C_extents(M, N);
|
||||
|
||||
RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
|
||||
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
|
||||
RowVectorBatch<TC> c_slow_batch =
|
||||
AllocateAlignedRows<TC>(allocator, C_extents);
|
||||
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
|
||||
|
||||
std::unique_ptr<MatStorageT<float>> add_storage;
|
||||
MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
|
||||
if (add) {
|
||||
add_storage = GenerateMat<float>(Extents2D(1, N), pool);
|
||||
HWY_ASSERT(add_storage);
|
||||
add_storage->set_scale(1.0f);
|
||||
add_storage.SetScale(1.0f);
|
||||
}
|
||||
|
||||
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
|
||||
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
|
||||
HWY_ASSERT(a && b_trans);
|
||||
const auto A = ConstMatFromWeights(*a);
|
||||
const auto B = ConstMatFromWeights(*b_trans);
|
||||
MatStorageT<TA> a = GenerateMat<TA>(A_extents, pool);
|
||||
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
|
||||
const auto A = ConstMatFromWeights(a);
|
||||
const auto B = ConstMatFromWeights(b_trans);
|
||||
|
||||
const float* add_row = add ? add_storage->data_scale1() : nullptr;
|
||||
const RowPtr<TC> C = RowPtrFromBatch(c_batch);
|
||||
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
|
||||
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
|
||||
|
||||
// Fewer reps for large batch sizes, which take longer.
|
||||
const size_t num_samples = M < 32 ? 20 : 12;
|
||||
|
|
@ -174,11 +118,11 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
|||
// Ensure usage conditions are set before autotuning. Both binding and
|
||||
// spinning may materially affect the choice of config. No harm in calling
|
||||
// BindB/C if there is a single package: they will be a no-op.
|
||||
BindB(B_extents.rows, sizeof(TC), B, env.parallel);
|
||||
BindC(A_extents.rows, C, env.parallel);
|
||||
BindB(allocator, B_extents.rows, sizeof(TC), B, env.parallel);
|
||||
BindC(allocator, A_extents.rows, C, env.parallel);
|
||||
|
||||
Tristate use_spinning = Tristate::kDefault;
|
||||
env.parallel.Pools().MaybeStartSpinning(use_spinning);
|
||||
env.ctx.pools.MaybeStartSpinning(use_spinning);
|
||||
|
||||
// env.print_config = true;
|
||||
// env.print_measurement = true;
|
||||
|
|
@ -198,7 +142,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
|
|||
if (per_key->autotune.Best()) times.push_back(elapsed);
|
||||
}
|
||||
hwy::PreventElision(keep);
|
||||
env.parallel.Pools().MaybeStopSpinning(use_spinning);
|
||||
env.ctx.pools.MaybeStopSpinning(use_spinning);
|
||||
PrintSpeed(A_extents, B_extents, times, per_key);
|
||||
}
|
||||
|
||||
|
|
@ -216,17 +160,11 @@ void BenchAllMatMul() {
|
|||
return;
|
||||
}
|
||||
|
||||
const size_t max_threads = 0; // no limit
|
||||
const BoundedSlice package_slice; // all packages/sockets
|
||||
const BoundedSlice cluster_slice; // all clusters/CCX
|
||||
const BoundedSlice lp_slice; // default to all cores (per package).
|
||||
const BoundedTopology topology(package_slice, cluster_slice, lp_slice);
|
||||
Allocator::Init(topology, /*enable_bind=*/true);
|
||||
NestedPools pools(topology, max_threads, Tristate::kDefault);
|
||||
fprintf(stderr, "BenchAllMatMul %s %s\n", topology.TopologyString(),
|
||||
pools.PinString());
|
||||
ThreadingContext2& ctx = ThreadingContext2::Get();
|
||||
fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(),
|
||||
ctx.pools.PinString());
|
||||
|
||||
MatMulEnv env(topology, pools);
|
||||
MatMulEnv env(ctx);
|
||||
|
||||
for (size_t batch_size : {1, 4, 128, 512}) {
|
||||
constexpr bool kAdd = false;
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@

#include <stddef.h>

#include "compression/compress.h"
#include "util/mat.h"
#include "hwy/base.h"
#include "hwy/profiler.h"

@ -379,10 +380,7 @@ template <typename MatT, typename VT>

HWY_INLINE float Dot(const MatPtrT<MatT>& w, size_t w_ofs,
const VT* vec_aligned, size_t num) {
const hn::ScalableTag<VT> d;
return w.scale() * Dot(d,
MakeConstSpan(reinterpret_cast<const MatT*>(w.Ptr()),
w.NumElements()),
w_ofs, vec_aligned, num);
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
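The rewritten Dot above folds the stored per-tensor scale into a dot product over the packed span. As a reference for what it computes, here is a plain scalar version over float buffers (a sketch for clarity, not the SIMD path used above):

#include <cstddef>

// Returns scale * sum_i w[w_ofs + i] * vec[i], for i in [0, num).
float ScaledDotReference(const float* w, size_t w_ofs, float scale,
                         const float* vec, size_t num) {
  double sum = 0.0;  // accumulate in double to reduce rounding error
  for (size_t i = 0; i < num; ++i) {
    sum += static_cast<double>(w[w_ofs + i]) * vec[i];
  }
  return scale * static_cast<float>(sum);
}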
|
|
|||
|
|
@ -28,9 +28,8 @@
|
|||
|
||||
#include "compression/shared.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/app.h"
|
||||
#include "util/test_util.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
|
|
@ -805,7 +804,7 @@ class DotStats {
|
|||
ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_l1s[kAddTwoProd].Mean(), 8E-4);
|
||||
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.1E-3f);
|
||||
// Updating Kahan's FastTwoSums to TwoSums does help a bit.
|
||||
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.2E-4);
|
||||
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.8E-4);
|
||||
|
||||
ASSERT_INSIDE(kPairwise, 4.5E-4, s_l1s[kPairwise].Mean(), 4E-3);
|
||||
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_l1s[kPairwise].Max(), 1E-2f);
|
||||
|
|
@ -1000,9 +999,7 @@ struct TestShortDotsT {
|
|||
const size_t N = hn::Lanes(d);
|
||||
const hn::ScalableTag<float> df; // for CallDot
|
||||
|
||||
const AppArgs app;
|
||||
BoundedTopology topology(CreateTopology(app));
|
||||
NestedPools pools = CreatePools(topology, app);
|
||||
const Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
|
||||
CompressWorkingSet work;
|
||||
std::mt19937 rng;
|
||||
rng.seed(12345);
|
||||
|
|
@ -1014,14 +1011,14 @@ struct TestShortDotsT {
|
|||
// hence they require padding to one vector.
|
||||
const size_t padded_num = hwy::RoundUpTo(num, N);
|
||||
const size_t packed_num = CompressedArrayElements<Packed>(num);
|
||||
RowVectorBatch<float> raw_w(Extents2D(1, padded_num));
|
||||
RowVectorBatch<float> raw_v(Extents2D(1, padded_num));
|
||||
RowVectorBatch<Packed> weights(Extents2D(1, packed_num));
|
||||
RowVectorBatch<float> raw_w(allocator, Extents2D(1, padded_num));
|
||||
RowVectorBatch<float> raw_v(allocator, Extents2D(1, padded_num));
|
||||
RowVectorBatch<Packed> weights(allocator, Extents2D(1, packed_num));
|
||||
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
|
||||
RowVectorBatch<T> vectors(Extents2D(1, num));
|
||||
RowVectorBatch<T> vectors(allocator, Extents2D(1, num));
|
||||
const PackedSpan<T> v(vectors.Batch(0), num);
|
||||
|
||||
RowVectorBatch<double> bufs(Extents2D(1, num));
|
||||
RowVectorBatch<double> bufs(allocator, Extents2D(1, num));
|
||||
double* HWY_RESTRICT buf = bufs.Batch(0);
|
||||
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
||||
|
|
@ -1099,10 +1096,21 @@ void TestAllDot() {

return;
}

constexpr size_t kMaxWorkers = 15;

// Reset with cap on workers because we only support `kMaxWorkers`.
ThreadingContext2::ThreadHostileInvalidate();
ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.max_lps = kMaxWorkers - 1;
ThreadingContext2::SetArgs(threading_args);
ThreadingContext2& ctx = ThreadingContext2::Get();
const Allocator2& allocator = ctx.allocator;

{ // ensure no profiler zones are active
const hn::ScalableTag<float> df;

constexpr size_t kMaxWorkers = 15;
std::mt19937 rngs[kMaxWorkers];
for (size_t i = 0; i < kMaxWorkers; ++i) {
rngs[i].seed(12345 + 65537 * i);
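The pattern above (invalidate, set args, fetch the singleton) is the per-test configuration flow for the new ThreadingContext2. A condensed usage sketch, using only the calls and fields visible in this diff; treat exact signatures as approximate:

#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

// Sketch: configure the process-wide ThreadingContext2 before running a test.
void ConfigureSingleClusterContext(size_t max_workers) {
  ThreadingContext2::ThreadHostileInvalidate();  // reset any prior instance

  ThreadingArgs args;
  args.max_packages = 1;            // single socket
  args.max_clusters = 1;            // single cluster/CCX
  args.max_lps = max_workers - 1;   // cap logical processors

  ThreadingContext2::SetArgs(args);
  ThreadingContext2& ctx = ThreadingContext2::Get();

  // Pools and allocator now reflect the capped topology.
  hwy::ThreadPool& pool = ctx.pools.Pool(0);
  const Allocator2& allocator = ctx.allocator;
  (void)pool;
  (void)allocator;
}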
@ -1110,15 +1118,13 @@ void TestAllDot() {
|
|||
|
||||
constexpr size_t kReps = hn::AdjustedReps(40);
|
||||
const size_t num = 24 * 1024;
|
||||
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1),
|
||||
BoundedSlice());
|
||||
NestedPools pools(topology, kMaxWorkers - 1, /*pin=*/Tristate::kDefault);
|
||||
RowVectorBatch<float> a(Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<float> b(Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<double> bufs(Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<float> a(allocator, Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<float> b(allocator, Extents2D(kMaxWorkers, num));
|
||||
RowVectorBatch<double> bufs(allocator, Extents2D(kMaxWorkers, num));
|
||||
std::array<DotStats, kMaxWorkers> all_stats;
|
||||
|
||||
pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
|
||||
ctx.pools.Cluster(0, 0).Run(
|
||||
0, kReps, [&](const uint32_t rep, size_t thread) {
|
||||
float* HWY_RESTRICT pa = a.Batch(thread);
|
||||
float* HWY_RESTRICT pb = b.Batch(thread);
|
||||
double* HWY_RESTRICT buf = bufs.Batch(thread);
|
||||
|
|
@ -1136,7 +1142,8 @@ void TestAllDot() {
|
|||
std::array<double, kTimeReps> elapsed;
|
||||
for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
|
||||
const double start = hwy::platform::Now();
|
||||
dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
|
||||
dots[variant] +=
|
||||
CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
|
||||
hwy::PreventElision(*pa);
|
||||
elapsed[time_rep] = hwy::platform::Now() - start;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@
|
|||
#include <cmath> // std::abs
|
||||
#include <memory>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/aligned_allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
|
@ -37,6 +37,7 @@
|
|||
#include "hwy/foreach_target.h" // IWYU pragma: keep
|
||||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "ops/matvec-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
|
|
@ -48,18 +49,18 @@ using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
|||
|
||||
FloatPtr SimpleMatVecAdd(const MatStorageT<float>& mat, const FloatPtr& vec,
|
||||
const FloatPtr& add) {
|
||||
FloatPtr raw_mat = hwy::AllocateAligned<float>(mat.NumElements());
|
||||
const size_t num = mat.Rows() * mat.Cols();
|
||||
FloatPtr raw_mat = hwy::AllocateAligned<float>(num);
|
||||
FloatPtr out = hwy::AllocateAligned<float>(mat.Rows());
|
||||
HWY_ASSERT(raw_mat && out);
|
||||
const hn::ScalableTag<float> df;
|
||||
DecompressAndZeroPad(df, MakeSpan(mat.data(), mat.NumElements()), 0,
|
||||
raw_mat.get(), mat.NumElements());
|
||||
DecompressAndZeroPad(df, mat.Span(), 0, raw_mat.get(), num);
|
||||
for (size_t idx_row = 0; idx_row < mat.Rows(); idx_row++) {
|
||||
out[idx_row] = 0.0f;
|
||||
for (size_t idx_col = 0; idx_col < mat.Cols(); idx_col++) {
|
||||
out[idx_row] += raw_mat[mat.Cols() * idx_row + idx_col] * vec[idx_col];
|
||||
}
|
||||
out[idx_row] *= mat.scale();
|
||||
out[idx_row] *= mat.Scale();
|
||||
out[idx_row] += add[idx_row];
|
||||
}
|
||||
return out;
|
||||
|
|
@ -69,8 +70,10 @@ template <typename MatT, size_t kOuter, size_t kInner>
|
|||
std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat = std::make_unique<MatStorageT<float>>("TestMat", kOuter, kInner);
|
||||
FloatPtr raw_mat = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
const Extents2D extents(kOuter, kInner);
|
||||
auto mat = std::make_unique<MatStorageT<float>>("TestMat", extents,
|
||||
MatPadding::kPacked);
|
||||
FloatPtr raw_mat = hwy::AllocateAligned<float>(extents.Area());
|
||||
HWY_ASSERT(raw_mat);
|
||||
const float scale = 1.0f / kInner;
|
||||
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
|
||||
|
|
@ -80,8 +83,8 @@ std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
|
|||
}
|
||||
});
|
||||
|
||||
CompressScaled(raw_mat.get(), mat->NumElements(), ws, *mat, pool);
|
||||
mat->set_scale(1.9f); // Arbitrary value, different from 1.
|
||||
CompressScaled(raw_mat.get(), extents.Area(), ws, *mat, pool);
|
||||
mat->SetScale(1.9f); // Arbitrary value, different from 1.
|
||||
return mat;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <vector>
|
||||
|
||||
|
|
@ -22,9 +23,8 @@
|
|||
#include "ops/matmul.h" // IWYU pragma: export
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/profiler.h"
|
||||
#include "hwy/timer.h"
|
||||
|
||||
|
|
@ -866,6 +866,8 @@ class MMPerPackage {
|
|||
const IndexRange& range_np)
|
||||
: args_(args),
|
||||
pkg_idx_(pkg_idx),
|
||||
// May be overwritten with a view of A, if already BF16.
|
||||
A_(args_.env->storage.A(args.env->ctx.allocator, pkg_idx, A.Extents())),
|
||||
range_np_(range_np),
|
||||
mr_(config.MR()),
|
||||
ranges_mc_(config.RangesOfMC(A.Extents().rows)),
|
||||
|
|
@ -873,15 +875,12 @@ class MMPerPackage {
|
|||
ranges_nc_(config.RangesOfNC(range_np)),
|
||||
order_(config.Order()),
|
||||
inner_tasks_(config.InnerTasks()),
|
||||
out_(config.Out()) {
|
||||
// May be overwritten with a view of A, if already BF16.
|
||||
A_ = args_.env->storage.A(pkg_idx, A.Extents());
|
||||
{
|
||||
out_(config.Out()),
|
||||
line_bytes_(args.env->ctx.allocator.LineBytes()) {
|
||||
MMZone zone;
|
||||
zone.MaybeEnter("MM.DecompressA", args_);
|
||||
A_ = DecompressA(A);
|
||||
}
|
||||
}
|
||||
|
||||
// B is decompressed several call layers lower, but not all member functions
|
||||
// depend on TB, so pass it as an argument instead of templating the class.
|
||||
|
|
@ -909,14 +908,14 @@ class MMPerPackage {
|
|||
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
|
||||
// allocation avoids passing a worker index.
|
||||
static constexpr size_t B_stride_max_ =
|
||||
StrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
|
||||
MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
|
||||
static constexpr size_t B_storage_max_ =
|
||||
kNR * B_stride_max_ + Allocator::MaxQuantumBytes() / sizeof(BF16);
|
||||
kNR * B_stride_max_ + Allocator2::MaxQuantum<BF16>();
|
||||
|
||||
// Granularity of `ForNP`. B rows produce C columns, so we
|
||||
// want a multiple of the line size to prevent false sharing.
|
||||
static size_t MultipleNP(size_t sizeof_TC) {
|
||||
return HWY_MAX(kNR, Allocator::LineBytes() / sizeof_TC);
|
||||
size_t MultipleNP(size_t sizeof_TC) const {
|
||||
return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
|
||||
}
|
||||
|
||||
// Single M and K, parallel N. Fills all of C directly.
|
||||
|
|
@ -931,14 +930,16 @@ class MMPerPackage {
|
|||
const IndexRange& range_K = ranges_kc_.Range(0);
|
||||
const size_t K = range_K.Num();
|
||||
const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K);
|
||||
const size_t B_stride = StrideForCyclicOffsets<BF16>(K);
|
||||
const size_t B_stride =
|
||||
StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
|
||||
|
||||
// Similar to `loop_nc` below, but here we hoisted `A_view`.
|
||||
args_.env->parallel.ForNP(
|
||||
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
|
||||
[&](const IndexRange& range_nc) HWY_ATTR {
|
||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||
const RowPtrBF B_view(B_storage, K, B_stride);
|
||||
const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
|
||||
B_stride);
|
||||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
|
|
@ -972,7 +973,9 @@ class MMPerPackage {
|
|||
auto out_tag) HWY_ATTR {
|
||||
const size_t kc = range_kc.Num();
|
||||
const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
|
||||
const RowPtrBF B_view(B_storage, kc, StrideForCyclicOffsets<BF16>(kc));
|
||||
const RowPtrBF B_view(
|
||||
args_.env->ctx.allocator, B_storage, kc,
|
||||
StrideForCyclicOffsets(kc, args_.env->ctx.allocator.Quantum<BF16>()));
|
||||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
|
|
@ -1027,7 +1030,8 @@ class MMPerPackage {
|
|||
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
|
||||
const IndexRange& range_K = ranges_kc_.Range(0);
|
||||
const size_t K = range_K.Num();
|
||||
const size_t B_stride = StrideForCyclicOffsets<BF16>(K);
|
||||
const size_t B_stride =
|
||||
StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
|
||||
|
||||
// Sequential loop over NC/MC/KC, similar to `loop_nc` below
|
||||
// except for the profiler strings and `out_tag`.
|
||||
|
|
@ -1036,7 +1040,8 @@ class MMPerPackage {
|
|||
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
|
||||
const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
|
||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||
const RowPtrBF B_view(B_storage, K, B_stride);
|
||||
const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
|
||||
B_stride);
|
||||
|
||||
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
|
||||
row_b += kNR) {
|
||||
|
|
@ -1062,7 +1067,8 @@ class MMPerPackage {
|
|||
zone.MaybeEnter("MM.NT_MT_K", args_);
|
||||
const size_t kc_max = ranges_kc_.TaskSize();
|
||||
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
|
||||
const size_t B_stride = StrideForCyclicOffsets<BF16>(kc_max);
|
||||
const size_t B_stride = StrideForCyclicOffsets(
|
||||
kc_max, args_.env->ctx.allocator.Quantum<BF16>());
|
||||
// Sequential loop over NC/MC/KC, for when the M/N loops are
|
||||
// already parallel. This is B3A2C0 in MOMMS terminology: we read
|
||||
// `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
|
||||
|
|
@ -1088,7 +1094,8 @@ class MMPerPackage {
|
|||
ranges_mc_, ranges_nc_, pkg_idx_,
|
||||
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
|
||||
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
|
||||
const RowPtrBF B_view(B_storage, kc_max, B_stride);
|
||||
const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, kc_max,
|
||||
B_stride);
|
||||
|
||||
// Peel off the first iteration of the kc loop: avoid
|
||||
// zero-initializing `partial` by writing into it.
|
||||
|
|
@ -1151,8 +1158,7 @@ class MMPerPackage {
|
|||
// At least one vector, otherwise DecompressAndZeroPad will add
|
||||
// padding, which might overwrite neighboring tasks. Also a whole cache
|
||||
// line to avoid false sharing.
|
||||
const size_t multiple_K =
|
||||
HWY_MAX(NBF, Allocator::LineBytes() / sizeof(BF16));
|
||||
const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
|
||||
|
||||
args_.env->parallel.ForNP(
|
||||
all_K, multiple_K, inner_tasks, pkg_idx_,
|
||||
|
|
@ -1170,6 +1176,7 @@ class MMPerPackage {
|
|||
// Autotuning wrapper for `DoDecompressA`.
|
||||
template <typename TA>
|
||||
HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const {
|
||||
const Allocator2& allocator = args_.env->ctx.allocator;
|
||||
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
|
||||
// If already BF16, maybe return a view:
|
||||
if constexpr (hwy::IsSame<TA, BF16>()) {
|
||||
|
|
@ -1177,7 +1184,8 @@ class MMPerPackage {
|
|||
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
|
||||
if (HWY_LIKELY(A.extents.cols % NBF == 0)) {
|
||||
const BF16* pos = A.ptr + A.Row(0);
|
||||
return RowPtrBF(const_cast<BF16*>(pos), A.extents.cols, A.Stride());
|
||||
return RowPtrBF(allocator, const_cast<BF16*>(pos), A.extents.cols,
|
||||
A.Stride());
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1251,6 +1259,7 @@ class MMPerPackage {
|
|||
const MMOrder order_;
|
||||
const size_t inner_tasks_;
|
||||
const MMOut out_;
|
||||
const size_t line_bytes_;
|
||||
}; // MMPerPackage
|
||||
|
||||
// Stateless, wraps member functions.
|
||||
|
|
@ -1308,6 +1317,7 @@ template <typename TA, typename TB, typename TC>
|
|||
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||
const RowPtr<TC>& C) {
|
||||
const Allocator2& allocator = env.ctx.allocator;
|
||||
const size_t M = A.Extents().rows;
|
||||
const size_t K = A.Extents().cols;
|
||||
const size_t N = B.Extents().rows;
|
||||
|
|
@ -1315,11 +1325,11 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
|||
intptr_t index = MMImpl::IndexOfKey(key, env.keys);
|
||||
// First time we see this shape/key.
|
||||
if (HWY_UNLIKELY(index < 0)) {
|
||||
env.keys.Append(key);
|
||||
env.keys.Append(key, allocator);
|
||||
|
||||
size_t max_packages = MMParallel::kMaxPackages;
|
||||
// For low-batch, multiple sockets only help if binding is enabled.
|
||||
if (!Allocator::ShouldBind() && M <= 4) {
|
||||
if (!allocator.ShouldBind() && M <= 4) {
|
||||
max_packages = 1;
|
||||
}
|
||||
|
||||
|
|
@ -1351,8 +1361,9 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
|||
HWY_ASSERT(N % kNR == 0);
|
||||
|
||||
// Negligible CPU time.
|
||||
tuner.SetCandidates(MMCandidates(M, K, N, sizeof(TC), MMKernel::kMaxMR, kNR,
|
||||
per_key.ranges_np, env.print_config));
|
||||
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC),
|
||||
MMKernel::kMaxMR, kNR, per_key.ranges_np,
|
||||
env.print_config));
|
||||
}
|
||||
|
||||
const MMConfig& cfg = tuner.NextConfig();
|
||||
|
|
|
|||
|
|
@ -60,10 +60,11 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
|
|||
// and holds most of their arguments in member variables.
|
||||
class GenerateCandidates {
|
||||
public:
|
||||
GenerateCandidates(size_t M, size_t K, size_t N, size_t sizeof_TC,
|
||||
size_t max_mr, size_t nr,
|
||||
GenerateCandidates(const Allocator2& allocator, size_t M, size_t K, size_t N,
|
||||
size_t sizeof_TC, size_t max_mr, size_t nr,
|
||||
const IndexRangePartition& ranges_np, bool print_config)
|
||||
: M_(M),
|
||||
: allocator_(allocator),
|
||||
M_(M),
|
||||
K_(K),
|
||||
N_(N),
|
||||
sizeof_TC_(sizeof_TC),
|
||||
|
|
@ -73,8 +74,8 @@ class GenerateCandidates {
|
|||
// `RangesOf*`. Must be a vector multiple. The previous/next cache line
|
||||
// is likely still in L1, but we expect K > 1000 and might as well round
|
||||
// up to the line size.
|
||||
kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))),
|
||||
nc_multiple_(Allocator::StepBytes() / sizeof_TC),
|
||||
kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
|
||||
nc_multiple_(allocator.StepBytes() / sizeof_TC),
|
||||
ranges_np_(ranges_np),
|
||||
print_config_(print_config) {}
|
||||
|
||||
|
|
@ -172,7 +173,7 @@ class GenerateCandidates {
|
|||
// subtract the output and buf, and allow using more than the actual L1
|
||||
// size. This results in an overestimate, and the loop below will propose
|
||||
// the next few smaller values for the autotuner to evaluate.
|
||||
const size_t bytes_ab = Allocator::L1Bytes() * 3;
|
||||
const size_t bytes_ab = allocator_.L1Bytes() * 3;
|
||||
const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
|
||||
size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
|
||||
kc_max =
|
||||
|
|
@ -220,8 +221,8 @@ class GenerateCandidates {
|
|||
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
|
||||
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
|
||||
// partial.
|
||||
const size_t bytes_per_mc = kc * sizeof(BF16) + Allocator::LineBytes();
|
||||
size_t mc_max = hwy::DivCeil(Allocator::L2Bytes() - bytes_b, bytes_per_mc);
|
||||
const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
|
||||
size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
|
||||
mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
|
||||
HWY_DASSERT(mc_max != 0);
|
||||
mc_max = HWY_MIN(mc_max, M_);
|
||||
|
|
@ -264,7 +265,7 @@ class GenerateCandidates {
|
|||
// Otherwise, leave it unbounded.
|
||||
if (M_ > mr) {
|
||||
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes);
|
||||
nc_max = hwy::DivCeil(Allocator::L3Bytes(), bytes_per_nc);
|
||||
nc_max = hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc);
|
||||
nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max);
|
||||
}
|
||||
HWY_DASSERT(nc_max != 0);
|
||||
|
|
@ -351,6 +352,7 @@ class GenerateCandidates {
|
|||
}
|
||||
}
|
||||
|
||||
const Allocator2& allocator_;
|
||||
const size_t M_;
|
||||
const size_t K_;
|
||||
const size_t N_;
|
||||
|
|
@ -370,25 +372,26 @@

} // namespace

// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config) {
return GenerateCandidates(M, K, N, sizeof_TC, max_mr, nr, ranges_np,
print_config)();
return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
ranges_np, print_config)();
}

// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package
// rows for that.
static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,
size_t num_packages) {
size_t np_multiple = Allocator::QuantumBytes() / sizeof_TC;
static size_t NPMultiple(const Allocator2& allocator, size_t N,
size_t sizeof_TC, size_t nr, size_t num_packages) {
size_t np_multiple = allocator.QuantumBytes() / sizeof_TC;
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
// `N` < 4096, this can cause significant load imbalance. If split unevenly,
// choose a smaller multiple.
if (N % (np_multiple * num_packages)) {
const size_t min_multiple = Allocator::LineBytes() / sizeof_TC;
const size_t min_multiple = allocator.LineBytes() / sizeof_TC;
np_multiple =
PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
if (HWY_UNLIKELY(np_multiple == 0)) {
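To make the load-balance fallback concrete: with illustrative values QuantumBytes = 4096 and sizeof_TC = 4, the initial np_multiple is 1024; if N = 5120 and num_packages = 2, then N is not a multiple of 2048, so the code searches downward for a smaller, line-sized divisor. The helper below is a simplified stand-in for that downward search (the exact contract of PrevDivisor is not shown in this hunk):

#include <cstddef>

// Simplified: largest multiple of `step` in (0, upper] that divides `dim`,
// or 0 if none exists.
size_t LargestDividingMultiple(size_t step, size_t upper, size_t dim) {
  for (size_t m = (upper / step) * step; m >= step; m -= step) {
    if (dim % m == 0) return m;
  }
  return 0;
}

// Example: step = 64/4 = 16, upper = 1024, dim = 5120/2 = 2560 -> returns 640,
// so each package gets B-row ranges in multiples of 640 columns of C.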
@ -408,16 +411,14 @@ static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,

IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr) const {
const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages());
return StaticPartition(IndexRange(0, N), num_packages,
NPMultiple(N, sizeof_TC, nr, num_packages));
const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages());
return StaticPartition(
IndexRange(0, N), num_packages,
NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
}

MatMulEnv::MatMulEnv(const BoundedTopology& topology, NestedPools& pools)
: parallel(topology, pools), storage(parallel) {
// Ensure Allocator:Init was called.
HWY_ASSERT(Allocator::LineBytes() != 0 && Allocator::VectorBytes() != 0);

MatMulEnv::MatMulEnv(ThreadingContext2& ctx)
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
}
130
ops/matmul.h
|
|
@ -24,11 +24,9 @@
|
|||
#include <vector>
|
||||
|
||||
// IWYU pragma: begin_exports
|
||||
#include "compression/compress.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "util/topology.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/aligned_allocator.h" // Span
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/bit_set.h"
|
||||
|
|
@ -51,28 +49,25 @@ class MMParallel {
|
|||
public:
|
||||
static constexpr size_t kMaxPackages = 4;
|
||||
|
||||
// Both references must outlive this object.
|
||||
MMParallel(const BoundedTopology& topology, NestedPools& pools)
|
||||
: topology_(topology), pools_(pools) {
|
||||
HWY_DASSERT(pools_.NumPackages() <= kMaxPackages);
|
||||
// `ctx` must outlive this object.
|
||||
MMParallel(ThreadingContext2& ctx) : ctx_(ctx) {
|
||||
HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
|
||||
}
|
||||
|
||||
// Used by tests.
|
||||
NestedPools& Pools() { return pools_; }
|
||||
|
||||
// Initial static partitioning of B rows across packages.
|
||||
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
|
||||
size_t sizeof_TC, size_t nr) const;
|
||||
|
||||
// For `BindB` and `BindC`.
|
||||
size_t Node(size_t pkg_idx) const {
|
||||
return topology_.GetCluster(pkg_idx, 0).Node();
|
||||
return ctx_.topology.GetCluster(pkg_idx, 0).Node();
|
||||
}
|
||||
|
||||
// Calls `func(pkg_idx)` for each package in parallel.
|
||||
template <class Func>
|
||||
void ForPkg(const size_t max_packages, const Func& func) {
|
||||
pools_.AllPackages().Run(0, HWY_MIN(max_packages, pools_.NumPackages()),
|
||||
ctx_.pools.AllPackages().Run(
|
||||
0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
|
||||
[&](uint64_t task, size_t pkg_idx) {
|
||||
HWY_DASSERT(task == pkg_idx);
|
||||
(void)task;
|
||||
|
|
@ -87,10 +82,10 @@ class MMParallel {
|
|||
size_t pkg_idx, const Func& func) {
|
||||
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
|
||||
// Single cluster: parallel-for over static partition of `range_np`.
|
||||
hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
|
||||
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
|
||||
const size_t num_clusters = all_clusters.NumWorkers();
|
||||
if (num_clusters == 1) {
|
||||
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, 0);
|
||||
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0);
|
||||
const IndexRangePartition worker_ranges = StaticPartition(
|
||||
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
|
||||
return ParallelizeOneRange(
|
||||
|
|
@ -106,7 +101,7 @@ class MMParallel {
|
|||
ParallelizeOneRange(
|
||||
nx_ranges, all_clusters,
|
||||
[&](const IndexRange& nx_range, const size_t cluster_idx) {
|
||||
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
|
||||
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
|
||||
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
|
||||
const IndexRangePartition worker_ranges = StaticPartition(
|
||||
nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
|
||||
|
|
@ -122,14 +117,14 @@ class MMParallel {
|
|||
void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
|
||||
const IndexRangePartition& ranges_nc, size_t pkg_idx,
|
||||
const Func& func) {
|
||||
hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
|
||||
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
|
||||
// `all_clusters` is a pool with one worker per cluster in a package.
|
||||
const size_t num_clusters = all_clusters.NumWorkers();
|
||||
// Single (big) cluster: collapse two range indices into one parallel-for
|
||||
// to reduce the number of fork-joins.
|
||||
if (num_clusters == 1) {
|
||||
const size_t cluster_idx = 0;
|
||||
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
|
||||
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
|
||||
// Low-batch: avoid Divide/Remainder.
|
||||
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
|
||||
return ParallelizeOneRange(
|
||||
|
|
@ -150,7 +145,7 @@ class MMParallel {
|
|||
ParallelizeOneRange(
|
||||
ranges_nc, all_clusters,
|
||||
[&](const IndexRange range_nc, size_t cluster_idx) {
|
||||
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
|
||||
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
|
||||
ParallelizeOneRange(
|
||||
ranges_mc, cluster,
|
||||
[&](const IndexRange& range_mc, size_t /*thread*/) {
|
||||
|
|
@ -163,32 +158,32 @@ class MMParallel {
|
|||
template <class Func>
|
||||
void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
|
||||
const Func& func) {
|
||||
pools_.Pool(pkg_idx).Run(
|
||||
ctx_.pools.Pool(pkg_idx).Run(
|
||||
range_mc.begin(), range_mc.end(),
|
||||
[&](uint64_t row_a, size_t /*thread*/) { func(row_a); });
|
||||
}
|
||||
|
||||
private:
|
||||
const BoundedTopology& topology_;
|
||||
NestedPools& pools_;
|
||||
ThreadingContext2& ctx_;
|
||||
};
|
||||
|
||||
template <typename TC> // BF16/float for C, double for partial
|
||||
void BindC(size_t M, const RowPtr<TC>& C, MMParallel& parallel) {
|
||||
if (!Allocator::ShouldBind()) return;
|
||||
void BindC(const Allocator2& allocator, size_t M, const RowPtr<TC>& C,
|
||||
MMParallel& parallel) {
|
||||
if (!allocator.ShouldBind()) return;
|
||||
|
||||
const IndexRangePartition ranges_np =
|
||||
parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR);
|
||||
const size_t quantum = Allocator::QuantumBytes() / sizeof(TC);
|
||||
const size_t quantum = allocator.Quantum<TC>();
|
||||
bool ok = true;
|
||||
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
|
||||
const IndexRange& cols_c = ranges_np.Range(pkg_idx);
|
||||
const size_t node = parallel.Node(pkg_idx);
|
||||
for (size_t im = 0; im < M; ++im) {
|
||||
// BindRowsToPackageNodes may not be page-aligned.
|
||||
// `BindMemory` requires page alignment.
|
||||
const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
|
||||
const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
|
||||
ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
|
||||
ok &= allocator.BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
|
||||
node);
|
||||
}
|
||||
}
|
||||
|
|
@ -212,38 +207,42 @@ class MMStorage {
|
|||
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
|
||||
static constexpr size_t kMaxKC = 8 * 1024;
|
||||
|
||||
explicit MMStorage(MMParallel& parallel) {
|
||||
MMStorage(const Allocator2& allocator, MMParallel& parallel)
|
||||
// Per-worker copies of `partial` would be wasteful. We instead allocate
|
||||
// one instance of the maximum matrix extents because threads write at
|
||||
// false-sharing-free granularity.
|
||||
: partial_storage_(
|
||||
AllocateAlignedRows<double>(allocator, Extents2D(kMaxM, kMaxN))),
|
||||
// Same stride independent of the actual C.Cols() so we can pre-bind.
|
||||
partial_(allocator, partial_storage_.All(), kMaxN,
|
||||
StrideForCyclicOffsets(kMaxN, allocator.Quantum<double>())) {
|
||||
// Per-package allocation so each can decompress A into its own copy.
|
||||
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
|
||||
pkg_A_[pkg_idx] = AllocateAlignedRows<BF16>(Extents2D(kMaxM, kMaxK));
|
||||
pkg_A_[pkg_idx] =
|
||||
AllocateAlignedRows<BF16>(allocator, Extents2D(kMaxM, kMaxK));
|
||||
|
||||
if (Allocator::ShouldBind()) {
|
||||
if (allocator.ShouldBind()) {
|
||||
const size_t node = parallel.Node(pkg_idx);
|
||||
if (!Allocator::BindMemory(pkg_A_[pkg_idx].All(),
|
||||
if (!allocator.BindMemory(pkg_A_[pkg_idx].All(),
|
||||
pkg_A_[pkg_idx].NumBytes(), node)) {
|
||||
HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Per-worker copies of `partial` would be wasteful. We instead allocate
|
||||
// one instance of the maximum matrix extents because threads write at
|
||||
// false-sharing-free granularity.
|
||||
partial_storage_ = AllocateAlignedRows<double>(Extents2D(kMaxM, kMaxN));
|
||||
// Same stride independent of the actual C.Cols() so we can pre-bind.
|
||||
partial_ = RowPtrD(partial_storage_.All(), kMaxN,
|
||||
StrideForCyclicOffsets<double>(kMaxN));
|
||||
// Avoid cross-package accesses.
|
||||
BindC(kMaxM, partial_, parallel);
|
||||
BindC(allocator, kMaxM, partial_, parallel);
|
||||
}
|
||||
|
||||
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is
|
||||
// non-const, because `RowPtr` requires a non-const pointer.
|
||||
RowPtrBF A(size_t pkg_idx, const Extents2D& extents) {
|
||||
RowPtrBF A(const Allocator2& allocator, size_t pkg_idx,
|
||||
const Extents2D& extents) {
|
||||
HWY_DASSERT(extents.rows <= kMaxM);
|
||||
HWY_DASSERT(extents.cols <= kMaxK);
|
||||
const size_t stride = StrideForCyclicOffsets<BF16>(extents.cols);
|
||||
return RowPtrBF(pkg_A_[pkg_idx].All(), extents.cols, stride);
|
||||
const size_t stride =
|
||||
StrideForCyclicOffsets(extents.cols, allocator.Quantum<BF16>());
|
||||
return RowPtrBF(allocator, pkg_A_[pkg_idx].All(), extents.cols, stride);
|
||||
}
|
||||
|
||||
RowPtrD Partial() const { return partial_; }
|
||||
|
|
@ -431,13 +430,15 @@ class MMConfig {
|
|||
static_assert(sizeof(MMConfig) == 32); // for faster indexing
|
||||
#pragma pack(pop)
|
||||
|
||||
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
|
||||
size_t sizeof_TC, size_t max_mr, size_t nr,
|
||||
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
|
||||
size_t K, size_t N, size_t sizeof_TC,
|
||||
size_t max_mr, size_t nr,
|
||||
const IndexRangePartition& ranges_np,
|
||||
bool print_config);
|
||||
|
||||
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
|
||||
// main MatMul autotuner.
|
||||
// TODO: replace with hwy/auto_tune.h.
|
||||
template <typename TConfig>
|
||||
class MMAutoTune {
|
||||
public:
|
||||
|
|
@ -560,11 +561,11 @@ class MMKeys {

}

// Must only be called if not already present in `Keys()`.
void Append(Key key) {
void Append(Key key, const Allocator2& allocator) {
// Dynamic allocation because the test checks many more dimensions than
// would be reasonable to pre-allocate. DIY for alignment and padding.
if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
const size_t NU64 = Allocator::VectorBytes() / sizeof(Key);
const size_t NU64 = allocator.VectorBytes() / sizeof(Key);
// Start at one vector so the size is always a multiple of N.
if (HWY_UNLIKELY(capacity_ == 0)) {
capacity_ = hwy::DivCeil(NU64, 2); // will be doubled below
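The capacity logic above keeps the key array a whole number of SIMD vectors: it starts at half a vector so that the first doubling lands exactly on one vector, then doubles from there. A small sketch of that growth sequence with plain integers (8 keys per vector is an illustrative value, not the real VectorBytes() / sizeof(Key)):

#include <cstddef>
#include <cstdio>

static size_t DivCeil(size_t a, size_t b) { return (a + b - 1) / b; }

int main() {
  const size_t keys_per_vector = 8;  // illustrative
  size_t capacity = 0;
  for (size_t appended = 1; appended <= 40; ++appended) {
    if (appended > capacity) {
      if (capacity == 0) capacity = DivCeil(keys_per_vector, 2);
      capacity *= 2;  // 8, 16, 32, 64, ... always a vector multiple
    }
  }
  printf("final capacity: %zu\n", capacity);  // prints 64
  return 0;
}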
@ -604,10 +605,12 @@ struct MMPerKey {

MMAutoTune<MMParA> autotune_par_a[MMParallel::kMaxPackages];
};

// Stores state shared across MatMul calls. Non-copyable.
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
// `MatMulEnv`.
struct MatMulEnv {
explicit MatMulEnv(const BoundedTopology& topology, NestedPools& pools);
explicit MatMulEnv(ThreadingContext2& ctx);

ThreadingContext2& ctx;
bool have_timer_stop = false;

// Enable binding: disabled in Gemma until tensors support it, enabled in

@ -684,8 +687,9 @@ struct MMZone {

// `ofs` required for compressed T.
template <typename T>
struct ConstMat {
ConstMat(const T* ptr, Extents2D extents, size_t stride, size_t ofs = 0)
: ptr(ptr), extents(extents), stride(stride), ofs(ofs) {
ConstMat() = default;
ConstMat(const T* ptr, Extents2D extents, size_t stride)
: ptr(ptr), extents(extents), stride(stride), ofs(0) {
HWY_DASSERT(ptr != nullptr);
HWY_DASSERT(stride >= extents.cols);
}
@ -717,15 +721,17 @@ struct ConstMat {

float scale = 1.0f;

// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
// pointer arithmetic. This is in units of weights, and does not have anything
// to do with the interleaved NUQ tables. It should be computed via `Row()`
// to take into account the stride.
size_t ofs;
};

// For deducing T.
template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, size_t stride,
size_t ofs = 0) {
return ConstMat<T>(ptr, extents, stride, ofs);
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
size_t stride) {
return ConstMat<T>(ptr, extents, stride);
}

// For A argument to MatMul (activations).

@ -739,21 +745,21 @@ ConstMat<T> ConstMatFromBatch(size_t batch_size,

}

template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
ConstMat<T> mat =
MakeConstMat(const_cast<T*>(m.data()), m.Extents(), m.Stride(), ofs);
mat.scale = m.scale();
MakeConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride());
mat.scale = m.Scale();
return mat;
}
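Since `ofs` is now documented as a row offset in element units derived from the stride, the addressing rule is the usual row-major one with padded rows: element (r, c) lives at ptr[r * stride + c], with stride >= cols. A tiny sketch of that arithmetic (plain indices, not the repo's Row() accessor):

#include <cstddef>

// Flat index of element (r, c) in a row-major matrix whose rows are
// `stride` elements apart (stride >= cols when rows are padded).
size_t FlatIndex(size_t r, size_t c, size_t stride) { return r * stride + c; }

// Example: cols = 100 padded to stride = 128 -> element (2, 5) is at index 261.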
||||
template <typename TB>
|
||||
void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
|
||||
MMParallel& parallel) {
|
||||
if (!Allocator::ShouldBind()) return;
|
||||
void BindB(const Allocator2& allocator, size_t N, size_t sizeof_TC,
|
||||
const ConstMat<TB>& B, MMParallel& parallel) {
|
||||
if (!allocator.ShouldBind()) return;
|
||||
|
||||
const IndexRangePartition ranges_np =
|
||||
parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
|
||||
const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
|
||||
const size_t quantum = allocator.Quantum<TB>();
|
||||
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
|
||||
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
|
||||
const size_t node = parallel.Node(pkg_idx);
|
||||
|
|
@ -765,7 +771,7 @@ void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
|
|||
begin = hwy::RoundUpTo(begin, quantum);
|
||||
end = hwy::RoundDownTo(end, quantum);
|
||||
if (HWY_LIKELY(begin != end)) {
|
||||
Allocator::BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
|
||||
allocator.BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,29 +15,25 @@
|
|||
|
||||
// End to end test of MatMul, comparing against a reference implementation.
|
||||
|
||||
#include "hwy/detect_compiler_arch.h"
|
||||
#include "hwy/detect_compiler_arch.h" // IWYU pragma: keep
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
|
||||
// double-precision support.
|
||||
// double-precision support, and older x86 to speed up builds.
|
||||
#if HWY_ARCH_ARM_V7
|
||||
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
|
||||
#else
|
||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SSSE3 | HWY_SSE4)
|
||||
#endif
|
||||
#endif
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "compression/shared.h"
|
||||
#include "ops/matmul.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h"
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
// clang-format off
|
||||
|
|
@ -48,9 +44,9 @@
|
|||
#include "hwy/highway.h"
|
||||
// After highway.h
|
||||
#include "compression/compress-inl.h"
|
||||
#include "compression/test_util-inl.h"
|
||||
#include "ops/dot-inl.h"
|
||||
#include "ops/matmul-inl.h"
|
||||
#include "hwy/tests/test_util-inl.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
namespace gcpp {
|
||||
|
|
@ -60,57 +56,6 @@ extern int64_t first_target;
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
|
||||
|
||||
template <typename MatT>
|
||||
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
|
||||
|
||||
// Generates inputs: deterministic, within max SfpStream range.
|
||||
template <typename MatT>
|
||||
MatStoragePtr<MatT> GenerateMat(const Extents2D& extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat =
|
||||
std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
HWY_ASSERT(content);
|
||||
const float scale = SfpStream::kMax / (mat->NumElements());
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(r * extents.cols + c) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
content[r * extents.cols + c] = f;
|
||||
}
|
||||
});
|
||||
|
||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
||||
mat->set_scale(0.6f); // Arbitrary value, different from 1.
|
||||
return mat;
|
||||
}
|
||||
|
||||
// extents describes the transposed matrix.
|
||||
template <typename MatT>
|
||||
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
|
||||
hwy::ThreadPool& pool) {
|
||||
gcpp::CompressWorkingSet ws;
|
||||
auto mat =
|
||||
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
|
||||
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
|
||||
const float scale = SfpStream::kMax / (mat->NumElements());
|
||||
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
|
||||
for (size_t c = 0; c < extents.cols; c++) {
|
||||
float f = static_cast<float>(c * extents.rows + r) * scale;
|
||||
if ((r + c) & 1) f = -f; // Also generate some negative values.
|
||||
content[r * extents.cols + c] = f;
|
||||
}
|
||||
});
|
||||
|
||||
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
|
||||
// Arbitrary value, different from 1, must match GenerateMat.
|
||||
mat->set_scale(0.6f);
|
||||
return mat;
|
||||
}
|
||||
|
||||
// Returns 1-norm, used for estimating tolerable numerical differences.
|
||||
double MaxRowAbsSum(const RowVectorBatch<float>& a) {
|
||||
double max_row_abs_sum = 0.0;
|
||||
|
|
@ -141,16 +86,19 @@ float MaxAbs(const RowVectorBatch<float>& a) {
|
|||
template <typename TA, typename TB, typename TC>
|
||||
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
|
||||
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
|
||||
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
||||
const hn::ScalableTag<float> df;
|
||||
const size_t cols = A.extents.cols;
|
||||
const size_t B_rows = B.extents.rows;
|
||||
// Round up for DecompressAndZeroPad.
|
||||
RowVectorBatch<float> a_batch = AllocateAlignedRows<float>(A.extents);
|
||||
RowVectorBatch<float> b_trans_batch = AllocateAlignedRows<float>(B.extents);
|
||||
RowVectorBatch<float> a_batch =
|
||||
AllocateAlignedRows<float>(allocator, A.extents);
|
||||
RowVectorBatch<float> b_trans_batch =
|
||||
AllocateAlignedRows<float>(allocator, B.extents);
|
||||
RowVectorBatch<float> c_batch =
|
||||
AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
|
||||
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
|
||||
RowVectorBatch<float> c_slow_batch =
|
||||
AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
|
||||
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
|
||||
HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
|
||||
for (size_t m = 0; m < A.extents.rows; ++m) {
|
||||
DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
|
||||
|
|
@ -224,7 +172,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
|
|||
const IndexRange all_rows_c(0, A.Extents().rows);
|
||||
const IndexRange all_cols_c(0, C.Cols());
|
||||
|
||||
NestedPools& pools = env.parallel.Pools();
|
||||
NestedPools& pools = env.ctx.pools;
|
||||
hwy::ThreadPool& all_packages = pools.AllPackages();
|
||||
const IndexRangePartition get_row_c =
|
||||
StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);
|
||||
|
|
@ -232,7 +180,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
|
|||
get_row_c, all_packages,
|
||||
[&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
|
||||
hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
|
||||
const size_t multiple = Allocator::QuantumBytes() / sizeof(TB);
|
||||
const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
|
||||
const IndexRangePartition get_col_c =
|
||||
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
|
||||
ParallelizeOneRange(
|
||||
|
|
@ -262,7 +210,8 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
|
|||
template <typename TA, typename TB = TA, typename TC = float>
|
||||
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
||||
MatMulEnv& env, int line) {
|
||||
hwy::ThreadPool& pool = env.parallel.Pools().Pool();
|
||||
const Allocator2& allocator = env.ctx.allocator;
|
||||
hwy::ThreadPool& pool = env.ctx.pools.Pool();
|
||||
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
|
||||
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
|
||||
TypeName<TC>());
|
||||
|
|
@ -274,24 +223,22 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
|
|||
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
|
||||
const Extents2D C_extents(rows_ac, cols_bc);
|
||||
|
||||
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
|
||||
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
|
||||
RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
|
||||
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
|
||||
HWY_ASSERT(a && b_trans);
|
||||
MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool));
|
||||
MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool));
|
||||
RowVectorBatch<TC> c_slow_batch =
|
||||
AllocateAlignedRows<TC>(allocator, C_extents);
|
||||
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
|
||||
|
||||
std::unique_ptr<MatStorageT<float>> add_storage;
|
||||
if (add) {
|
||||
add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
|
||||
HWY_ASSERT(add_storage);
|
||||
add_storage->set_scale(1.0f);
|
||||
}
|
||||
MatStorageT<float> add_storage =
|
||||
add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
|
||||
: MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
|
||||
add_storage.SetScale(1.0f);
|
||||
|
||||
const auto A = ConstMatFromWeights(*a);
|
||||
const auto B = ConstMatFromWeights(*b_trans);
|
||||
const float* add_row = add ? add_storage->data_scale1() : nullptr;
|
||||
const RowPtr<TC> C_slow = RowPtrFromBatch(c_slow_batch);
|
||||
const RowPtr<TC> C = RowPtrFromBatch(c_batch);
|
||||
const auto A = ConstMatFromWeights(a);
|
||||
const auto B = ConstMatFromWeights(b_trans);
|
||||
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
|
||||
const RowPtr<TC> C_slow = RowPtrFromBatch(allocator, c_slow_batch);
|
||||
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
|
||||
|
||||
MatMulSlow(A, B, add_row, env, C_slow);
|
||||
// A few reps to get coverage of the various autotuned code paths.
|
||||
|
|
@ -312,22 +259,24 @@ void TestTiny() {
|
|||
if (HWY_TARGET != first_target) return;
|
||||
|
||||
for (size_t max_packages : {1, 2}) {
|
||||
const BoundedTopology topology(BoundedSlice(0, max_packages));
|
||||
Allocator::Init(topology, /*enable_bind=*/true);
|
||||
const size_t max_threads = 0; // no limit
|
||||
NestedPools pools(topology, max_threads, Tristate::kDefault);
|
||||
ThreadingContext2::ThreadHostileInvalidate();
|
||||
ThreadingArgs threading_args;
|
||||
threading_args.bind = Tristate::kTrue;
|
||||
threading_args.max_packages = max_packages;
|
||||
ThreadingContext2::SetArgs(threading_args);
|
||||
MatMulEnv env(ThreadingContext2::Get());
|
||||
NestedPools& pools = env.ctx.pools;
|
||||
|
||||
#if GEMMA_DISABLE_TOPOLOGY
|
||||
if (max_packages == 2) break; // we only have one package
|
||||
#else
|
||||
// If less than the limit, we have already tested all num_packages.
|
||||
if (topology.FullTopology().packages.size() < max_packages) break;
|
||||
if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
|
||||
#endif
|
||||
fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
|
||||
topology.TopologyString(), pools.PinString());
|
||||
env.ctx.topology.TopologyString(), pools.PinString());
|
||||
|
||||
Tristate use_spinning = Tristate::kDefault;
|
||||
pools.MaybeStartSpinning(use_spinning);
|
||||
MatMulEnv env(topology, pools);
|
||||
pools.MaybeStartSpinning(threading_args.spin);
|
||||
|
||||
for (size_t M = 1; M <= 12; ++M) {
|
||||
for (size_t K = 1; K <= 64; K *= 2) {
|
||||
|
|
@ -336,7 +285,7 @@ void TestTiny() {
|
|||
}
|
||||
}
|
||||
}
|
||||
pools.MaybeStopSpinning(use_spinning);
|
||||
pools.MaybeStopSpinning(threading_args.spin);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -347,12 +296,13 @@ void TestAllMatMul() {
    return;
  }

  const BoundedTopology topology;
  Allocator::Init(topology, /*enable_bind=*/true);
  NestedPools pools(topology);
  Tristate use_spinning = Tristate::kDefault;
  pools.MaybeStartSpinning(use_spinning);
  MatMulEnv env(topology, pools);
  ThreadingContext2::ThreadHostileInvalidate();
  ThreadingArgs threading_args;
  threading_args.bind = Tristate::kTrue;
  ThreadingContext2::SetArgs(threading_args);
  MatMulEnv env(ThreadingContext2::Get());
  NestedPools& pools = env.ctx.pools;
  pools.MaybeStartSpinning(threading_args.spin);

  // Sizes seen in gemma_test 2B. Too slow for CI, enable on-demand.
  TestMatMul<F32>(1, 2048, 512, /*add=*/false, env, __LINE__);

@ -417,6 +367,8 @@
  TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env, __LINE__);

  pools.MaybeStopSpinning(threading_args.spin);
}
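For reference, a minimal sketch (not part of this diff) of the setup pattern the tests now use: the per-test BoundedTopology/Allocator::Init/NestedPools boilerplate is replaced by the process-wide ThreadingContext2 configured via ThreadingArgs. The tests above additionally call ThreadHostileInvalidate() to reset the singleton between configurations; a normal frontend would configure it once.

// Sketch only; the names are taken from the calls above (ThreadingArgs,
// SetArgs, ThreadingContext2::Get, MatMulEnv(ctx), env.ctx.pools).
#include "ops/matmul.h"
#include "util/threading_context.h"

void RunWithSharedThreadingContext() {
  gcpp::ThreadingArgs threading_args;           // replaces AppArgs
  threading_args.bind = gcpp::Tristate::kTrue;  // opt into NUMA binding
  gcpp::ThreadingContext2::SetArgs(threading_args);

  gcpp::MatMulEnv env(gcpp::ThreadingContext2::Get());
  gcpp::NestedPools& pools = env.ctx.pools;     // pools now live in the context
  pools.MaybeStartSpinning(threading_args.spin);
  // ... run MatMul work using env ...
  pools.MaybeStopSpinning(threading_args.spin);
}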
|
||||
|
||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||
|
|
|
|||
|
|
@ -50,8 +50,7 @@ template <class ArrayT, typename VT>
|
|||
HWY_INLINE float Dot(const ArrayT& w, size_t w_ofs, const VT* vec_aligned,
|
||||
size_t num) {
|
||||
const hn::ScalableTag<VT> d;
|
||||
return w.scale() * Dot(d, MakeConstSpan(w.data(), w.NumElements()), w_ofs,
|
||||
vec_aligned, num);
|
||||
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
|
||||
}
|
||||
|
||||
// Simple version without tiling nor threading, but two offsets/outputs and
|
||||
|
|
|
|||
|
|
@ -27,12 +27,13 @@
|
|||
#include <type_traits> // std::enable_if_t
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/basics.h" // TokenAndProb
|
||||
#include "util/mat.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/sort/order.h"
|
||||
#include "hwy/contrib/sort/vqsort.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
#include "hwy/detect_targets.h"
|
||||
#include "hwy/profiler.h"
|
||||
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
|
||||
|
|
@ -807,12 +808,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
|
|||
// Each output row is the average of a 4x4 block of input rows
|
||||
template <typename T>
|
||||
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
|
||||
Extents2D extents = input.Extents();
|
||||
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
||||
const Extents2D extents = input.Extents();
|
||||
// Input validation
|
||||
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows
|
||||
// Create output with 256 rows and same number of columns
|
||||
const size_t out_rows = 256; // 16 * 16 = 256 output rows
|
||||
RowVectorBatch<T> result(Extents2D{out_rows, extents.cols});
|
||||
RowVectorBatch<T> result(allocator, Extents2D(out_rows, extents.cols));
|
||||
const size_t input_dim = 64; // Input is 64×64
|
||||
const size_t output_dim = 16; // Output is 16×16
|
||||
for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {
|
||||
|
|
|
|||
|
|
@ -21,14 +21,16 @@
|
|||
#include <cmath>
|
||||
|
||||
#include "util/allocator.h"
|
||||
#include "util/mat.h"
|
||||
#include "hwy/base.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
|
||||
size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) {
|
||||
const Allocator2& allocator, size_t qkv_dim, bool half_rope,
|
||||
double base_frequency = 10000.0) {
|
||||
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
|
||||
RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
|
||||
RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2));
|
||||
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
|
||||
const double freq_exponents =
|
||||
static_cast<double>(2 * dim) / static_cast<double>(rope_dim);
|
||||
|
|
|
|||
|
|
@ -31,14 +31,12 @@
|
|||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include "compression/compress.h" // BF16
|
||||
#include "gemma/common.h"
|
||||
#include "gemma/configs.h"
|
||||
#include "util/allocator.h"
|
||||
#include "util/app.h"
|
||||
#include "util/basics.h" // BF16
|
||||
#include "util/mat.h" // RowVectorBatch
|
||||
#include "util/test_util.h"
|
||||
#include "util/threading.h"
|
||||
#include "hwy/base.h"
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/tests/hwy_gtest.h"
|
||||
|
||||
// clang-format off
|
||||
|
|
@ -388,13 +386,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
|
|||
}
|
||||
|
||||
void TestRopeAndMulBy() {
|
||||
AppArgs app;
|
||||
BoundedTopology topology = CreateTopology(app);
|
||||
NestedPools pools = CreatePools(topology, app);
|
||||
const Allocator2& allocator = ThreadingContext2::Get().allocator;
|
||||
|
||||
ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
|
||||
int dim_qkv = config.layer_configs[0].qkv_dim;
|
||||
RowVectorBatch<float> x(Extents2D(1, dim_qkv));
|
||||
RowVectorBatch<float> x(allocator, Extents2D(1, dim_qkv));
|
||||
|
||||
std::mt19937 gen;
|
||||
gen.seed(0x12345678);
|
||||
|
|
@ -412,8 +408,8 @@ void TestRopeAndMulBy() {
|
|||
std::vector<float> qactual(dim_qkv);
|
||||
std::vector<float> kexpected(dim_qkv);
|
||||
std::vector<float> kactual(dim_qkv);
|
||||
RowVectorBatch<float> inv_timescale = gcpp::CreateInvTimescale(
|
||||
config.layer_configs[0].qkv_dim,
|
||||
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
|
||||
allocator, config.layer_configs[0].qkv_dim,
|
||||
config.layer_configs[0].post_qk == PostQKType::HalfRope);
|
||||
// Assert VectorizedRope computation is same as regular rope at different pos.
|
||||
for (int pos = 1; pos < 500; pos++) {
|
||||
|
|
|
|||
|
|
@ -49,8 +49,8 @@ class PaliGemmaTest : public ::testing::Test {
|
|||
};
|
||||
|
||||
void PaliGemmaTest::InitVit(const std::string& path) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
Gemma& model = *(s_env->GetModel());
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
Gemma& model = *(s_env->GetGemma());
|
||||
image_tokens_ =
|
||||
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
|
||||
model.GetModelConfig().model_dim));
|
||||
|
|
@ -64,7 +64,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
|
|||
}
|
||||
|
||||
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
|
||||
Gemma& model = *(s_env->GetModel());
|
||||
Gemma& model = *(s_env->GetGemma());
|
||||
s_env->MutableGen().seed(0x12345678);
|
||||
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
|
||||
.gen = &s_env->MutableGen(),
|
||||
|
|
@ -92,7 +92,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
|
|||
}
|
||||
|
||||
void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
std::string path = "paligemma/testdata/image.ppm";
|
||||
InitVit(path);
|
||||
for (size_t i = 0; i < num_questions; ++i) {
|
||||
|
|
@ -104,7 +104,7 @@ void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
|
|||
}
|
||||
|
||||
TEST_F(PaliGemmaTest, General) {
|
||||
ASSERT_NE(s_env->GetModel(), nullptr);
|
||||
ASSERT_NE(s_env->GetGemma(), nullptr);
|
||||
static const char* kQA_3B_mix_224[][2] = {
|
||||
{"describe this image",
|
||||
"A large building with two towers stands tall on the water's edge."},
|
||||
|
|
@ -124,7 +124,7 @@ TEST_F(PaliGemmaTest, General) {
|
|||
};
|
||||
const char* (*qa)[2];
|
||||
size_t num;
|
||||
switch (s_env->GetModel()->Info().model) {
|
||||
switch (s_env->GetGemma()->Info().model) {
|
||||
case Model::PALIGEMMA_224:
|
||||
qa = kQA_3B_mix_224;
|
||||
num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);
|
||||
|
|
@ -135,7 +135,7 @@ TEST_F(PaliGemmaTest, General) {
|
|||
break;
|
||||
default:
|
||||
FAIL() << "Unsupported model: "
|
||||
<< s_env->GetModel()->GetModelConfig().model_name;
|
||||
<< s_env->GetGemma()->GetModelConfig().model_name;
|
||||
break;
|
||||
}
|
||||
TestQuestions(qa, num);
|
||||
|
|
|
|||
|
|
@ -21,12 +21,12 @@ pybind_extension(
|
|||
name = "gemma",
|
||||
srcs = ["gemma_py.cc"],
|
||||
deps = [
|
||||
"//:app",
|
||||
"//:allocator",
|
||||
"//:benchmark_helper",
|
||||
"//:gemma_args",
|
||||
"//:gemma_lib",
|
||||
"//compression:sfp",
|
||||
"@highway//:hwy",
|
||||
"@highway//:thread_pool",
|
||||
],
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -32,9 +32,9 @@
|
|||
#include "compression/shared.h"
|
||||
#include "evals/benchmark_helper.h"
|
||||
#include "gemma/gemma.h"
|
||||
#include "util/app.h"
|
||||
#include "gemma/gemma_args.h"
|
||||
#include "util/allocator.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||
|
||||
namespace py = pybind11;
|
||||
|
||||
|
|
@ -48,8 +48,9 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
|
|||
class GemmaModel {
|
||||
public:
|
||||
GemmaModel(const gcpp::LoaderArgs& loader,
|
||||
const gcpp::InferenceArgs& inference, const gcpp::AppArgs& app)
|
||||
: gemma_(loader, inference, app), last_prob_(0.0f) {}
|
||||
const gcpp::InferenceArgs& inference,
|
||||
const gcpp::ThreadingArgs& threading)
|
||||
: gemma_(threading, loader, inference), last_prob_(0.0f) {}
|
||||
|
||||
// Generates a single example, given a prompt and a callback to stream the
|
||||
// generated tokens.
|
||||
|
|
@ -168,7 +169,8 @@ class GemmaModel {
|
|||
// Generate* will use this image. Throws an error for other models.
|
||||
void SetImage(const py::array_t<float, py::array::c_style |
|
||||
py::array::forcecast>& image) {
|
||||
gcpp::Gemma& model = *(gemma_.GetModel());
|
||||
const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
|
||||
gcpp::Gemma& model = *(gemma_.GetGemma());
|
||||
if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) {
|
||||
throw std::invalid_argument("Not a PaliGemma model.");
|
||||
}
|
||||
|
|
@ -183,8 +185,8 @@ class GemmaModel {
|
|||
c_image.Set(height, width, ptr);
|
||||
const size_t image_size = model.GetModelConfig().vit_config.image_size;
|
||||
c_image.Resize(image_size, image_size);
|
||||
image_tokens_ = gcpp::ImageTokens(gcpp::Extents2D(
|
||||
model.GetModelConfig().vit_config.seq_len,
|
||||
image_tokens_ = gcpp::ImageTokens(
|
||||
allocator, gcpp::Extents2D(model.GetModelConfig().vit_config.seq_len,
|
||||
model.GetModelConfig().model_dim));
|
||||
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
|
||||
.verbosity = 0};
|
||||
|
|
@ -199,7 +201,7 @@ class GemmaModel {
|
|||
if (image_tokens_.Cols() == 0) {
|
||||
throw std::invalid_argument("No image set.");
|
||||
}
|
||||
gcpp::Gemma& model = *(gemma_.GetModel());
|
||||
gcpp::Gemma& model = *(gemma_.GetGemma());
|
||||
gemma_.MutableGen().seed(seed);
|
||||
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
|
||||
config.max_generated_tokens = max_generated_tokens;
|
||||
|
|
@ -247,7 +249,7 @@ class GemmaModel {
|
|||
return gemma_.StringFromTokens(token_ids);
|
||||
}
|
||||
|
||||
bool ModelIsLoaded() const { return gemma_.GetModel() != nullptr; }
|
||||
bool ModelIsLoaded() const { return gemma_.GetGemma() != nullptr; }
|
||||
|
||||
private:
|
||||
gcpp::GemmaEnv gemma_;
|
||||
|
|
@ -267,7 +269,7 @@ PYBIND11_MODULE(gemma, mod) {
|
|||
loader.weight_type_str = weight_type;
|
||||
gcpp::InferenceArgs inference;
|
||||
inference.max_generated_tokens = 512;
|
||||
gcpp::AppArgs app;
|
||||
gcpp::ThreadingArgs app;
|
||||
app.max_threads = max_threads;
|
||||
auto gemma =
|
||||
std::make_unique<GemmaModel>(loader, inference, app);
|
||||
|
|
|
|||
|
|
@ -130,233 +130,6 @@ size_t DetectTotalMiB(size_t page_bytes) {
|
|||
|
||||
} // namespace
|
||||
|
||||
static size_t line_bytes_;
|
||||
static size_t vector_bytes_;
|
||||
static size_t step_bytes_;
|
||||
static size_t quantum_bytes_;
|
||||
static size_t quantum_steps_;
|
||||
static size_t l1_bytes_;
|
||||
static size_t l2_bytes_;
|
||||
static size_t l3_bytes_;
|
||||
static bool should_bind_ = false;
|
||||
|
||||
size_t Allocator::LineBytes() { return line_bytes_; }
|
||||
size_t Allocator::VectorBytes() { return vector_bytes_; }
|
||||
size_t Allocator::StepBytes() { return step_bytes_; }
|
||||
size_t Allocator::QuantumBytes() { return quantum_bytes_; }
|
||||
size_t Allocator::QuantumSteps() { return quantum_steps_; }
|
||||
size_t Allocator::L1Bytes() { return l1_bytes_; }
|
||||
size_t Allocator::L2Bytes() { return l2_bytes_; }
|
||||
size_t Allocator::L3Bytes() { return l3_bytes_; }
|
||||
bool Allocator::ShouldBind() { return should_bind_; }
|
||||
|
||||
void Allocator::Init(const BoundedTopology& topology, bool enable_bind) {
|
||||
line_bytes_ = DetectLineBytes();
|
||||
vector_bytes_ = hwy::VectorBytes();
|
||||
step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
|
||||
quantum_bytes_ = step_bytes_; // may overwrite below
|
||||
|
||||
const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
|
||||
if (const hwy::Cache* caches = hwy::DataCaches()) {
|
||||
l1_bytes_ = caches[1].size_kib << 10;
|
||||
l2_bytes_ = caches[2].size_kib << 10;
|
||||
l3_bytes_ = (caches[3].size_kib << 10) * caches[3].cores_sharing;
|
||||
} else { // Unknown, make reasonable assumptions.
|
||||
l1_bytes_ = 32 << 10;
|
||||
l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10;
|
||||
}
|
||||
if (l3_bytes_ == 0) {
|
||||
l3_bytes_ = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) << 10;
|
||||
}
|
||||
|
||||
// Prerequisites for binding:
|
||||
// - supported by the OS (currently Linux only),
|
||||
// - the page size is known and 'reasonably small', preferably less than
|
||||
// a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
|
||||
// - we successfully detected topology and there are multiple nodes;
|
||||
// - there are multiple packages, because we shard by package_idx.
|
||||
if constexpr (GEMMA_BIND) {
|
||||
const size_t page_bytes = DetectPageSize();
|
||||
if ((page_bytes != 0 && page_bytes <= 16 * 1024) &&
|
||||
topology.NumNodes() > 1 && topology.NumPackages() > 1) {
|
||||
if (enable_bind) {
|
||||
// Ensure pages meet the alignment requirements of `AllocBytes`.
|
||||
HWY_ASSERT(page_bytes >= quantum_bytes_);
|
||||
quantum_bytes_ = page_bytes;
|
||||
// Ensure MaxQuantumBytes() is an upper bound.
|
||||
HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
|
||||
quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
|
||||
should_bind_ = true;
|
||||
} else {
|
||||
HWY_WARN(
|
||||
"Multiple sockets but binding disabled. This reduces speed; "
|
||||
"set or remove enable_bind to avoid this warning.");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0);
|
||||
quantum_steps_ = quantum_bytes_ / step_bytes_;
|
||||
}
|
||||
|
||||
Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) {
|
||||
// If we are not binding, the Highway allocator is cheaper than `mmap`, and
|
||||
// defends against 2K aliasing.
|
||||
if (!should_bind_) {
|
||||
// Perf warning if Highway's alignment is less than we want.
|
||||
if (HWY_ALIGNMENT < QuantumBytes()) {
|
||||
HWY_WARN(
|
||||
"HWY_ALIGNMENT %d < QuantumBytes %zu: either vector or cache lines "
|
||||
"are huge, enable GEMMA_BIND to avoid this warning.",
|
||||
HWY_ALIGNMENT, QuantumBytes());
|
||||
}
|
||||
auto p = hwy::AllocateAligned<uint8_t>(bytes);
|
||||
// The `hwy::AlignedFreeUniquePtr` deleter is unfortunately specific to the
|
||||
// alignment scheme in aligned_allocator.cc and does not work for
|
||||
// already-aligned pointers as returned by `mmap`, hence we wrap the Highway
|
||||
// pointer in our own deleter.
|
||||
auto call_free = [](void* ptr, size_t /*bytes*/) {
|
||||
hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
|
||||
};
|
||||
return PtrAndDeleter{p.release(), DeleterFree(call_free, bytes)};
|
||||
}
|
||||
|
||||
// Binding, or large vector/cache line size: use platform-specific allocator.
|
||||
|
||||
#if HWY_OS_LINUX && !defined(__ANDROID_API__)
|
||||
// `move_pages` is documented to require an anonymous/private mapping or
|
||||
// `MAP_SHARED`. A normal allocation might not suffice, so we use `mmap`.
|
||||
// `Init` verified that the page size is a multiple of `QuantumBytes()`.
|
||||
const int prot = PROT_READ | PROT_WRITE;
|
||||
const int flags = MAP_ANONYMOUS | MAP_PRIVATE;
|
||||
const int fd = -1;
|
||||
// Encourage transparent hugepages by rounding up to a multiple of 2 MiB.
|
||||
bytes = hwy::RoundUpTo(bytes, 2ull << 20);
|
||||
void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
|
||||
if (p == MAP_FAILED) p = nullptr;
|
||||
const auto call_munmap = [](void* ptr, size_t bytes) {
|
||||
const int ret = munmap(ptr, bytes);
|
||||
HWY_ASSERT(ret == 0);
|
||||
};
|
||||
return PtrAndDeleter{p, DeleterFree(call_munmap, bytes)};
|
||||
#elif HWY_OS_WIN
|
||||
const auto call_free = [](void* ptr, size_t) { _aligned_free(ptr); };
|
||||
const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
|
||||
return PtrAndDeleter{_aligned_malloc(bytes, alignment),
|
||||
DeleterFree(call_free, bytes)};
|
||||
#else
|
||||
return PtrAndDeleter{nullptr, DeleterFree(nullptr, 0)};
|
||||
#endif
|
||||
}
|
||||
|
||||
#if GEMMA_BIND && HWY_OS_LINUX
|
||||
|
||||
using Ret = long; // NOLINT(runtime/int)
|
||||
using UL = unsigned long; // NOLINT(runtime/int)
|
||||
static constexpr size_t ULBits = sizeof(UL) * 8;
|
||||
|
||||
// Calling via syscall avoids a dependency on libnuma.
|
||||
struct SyscallWrappers {
|
||||
static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes,
|
||||
unsigned flags) {
|
||||
MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL));
|
||||
return syscall(__NR_mbind, ptr, bytes, mode, max_nodes, max_nodes, flags);
|
||||
};
|
||||
|
||||
static Ret move_pages(int pid, UL count, void** pages, const int* nodes,
|
||||
int* status, int flags) {
|
||||
MaybeCheckInitialized(pages, count * sizeof(void*));
|
||||
MaybeCheckInitialized(nodes, count * sizeof(int));
|
||||
MaybeCheckInitialized(status, count * sizeof(int));
|
||||
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
|
||||
}
|
||||
|
||||
static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr,
|
||||
unsigned flags) {
|
||||
return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags);
|
||||
}
|
||||
};
|
||||
|
||||
// Returns the number of pages that are currently busy (hence not yet moved),
|
||||
// and warns if there are any other reasons for not moving a page. Note that
|
||||
// `move_pages` can return 0 regardless of whether all pages were moved.
|
||||
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
|
||||
const int* status) {
|
||||
size_t num_busy = 0;
|
||||
for (size_t i = 0; i < num_pages; ++i) {
|
||||
if (status[i] == -EBUSY) {
|
||||
++num_busy;
|
||||
} else if (status[i] != static_cast<int>(node)) {
|
||||
static std::atomic_flag first = ATOMIC_FLAG_INIT;
|
||||
if (!first.test_and_set()) {
|
||||
HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).",
|
||||
status[i], i, pages[i], node, errno);
|
||||
}
|
||||
}
|
||||
}
|
||||
return num_busy;
|
||||
}
|
||||
|
||||
bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) {
|
||||
HWY_DASSERT(should_bind_);
|
||||
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
|
||||
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
// Ensure the requested `node` is allowed.
|
||||
UL nodes[kMaxNodes / 64] = {0};
|
||||
const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED
|
||||
HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes,
|
||||
nullptr, flags) == 0);
|
||||
HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64)));
|
||||
}
|
||||
|
||||
// Avoid mbind because it does not report why it failed, which is most likely
|
||||
// because pages are busy, in which case we want to know which.
|
||||
|
||||
// `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set.
|
||||
const unsigned flags = 2; // MPOL_MF_MOVE
|
||||
HWY_ASSERT(bytes % quantum_bytes_ == 0);
|
||||
const size_t num_pages = bytes / quantum_bytes_;
|
||||
std::vector<void*> pages;
|
||||
pages.reserve(num_pages);
|
||||
for (size_t i = 0; i < num_pages; ++i) {
|
||||
pages.push_back(static_cast<uint8_t*>(ptr) + i * quantum_bytes_);
|
||||
// Ensure the page is faulted in to prevent `move_pages` from failing,
|
||||
// because freshly allocated pages may be mapped to a shared 'zero page'.
|
||||
hwy::ZeroBytes(pages.back(), 8);
|
||||
}
|
||||
std::vector<int> nodes(num_pages, node);
|
||||
std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));
|
||||
|
||||
Ret ret = SyscallWrappers::move_pages(
|
||||
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
|
||||
if (ret < 0) {
|
||||
HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr,
|
||||
bytes, node, errno, status[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t num_busy =
|
||||
CountBusyPages(num_pages, node, pages.data(), status.data());
|
||||
if (HWY_UNLIKELY(num_busy != 0)) {
|
||||
// Trying again is usually enough to succeed.
|
||||
hwy::NanoSleep(5000);
|
||||
(void)SyscallWrappers::move_pages(
|
||||
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
|
||||
const size_t still_busy =
|
||||
CountBusyPages(num_pages, node, pages.data(), status.data());
|
||||
if (HWY_UNLIKELY(still_busy != 0)) {
|
||||
HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.",
|
||||
still_busy, num_busy);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#else
|
||||
bool Allocator::BindMemory(void*, size_t, size_t) { return false; }
|
||||
#endif // GEMMA_BIND && HWY_OS_LINUX
|
||||
|
||||
Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
|
||||
line_bytes_ = DetectLineBytes();
|
||||
vector_bytes_ = hwy::VectorBytes();
|
||||
|
|
@ -428,7 +201,7 @@ size_t Allocator2::FreeMiB() const {
|
|||
#endif
|
||||
}
|
||||
|
||||
Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
|
||||
AlignedPtr2<uint8_t[]> Allocator2::AllocBytes(size_t bytes) const {
|
||||
// If we are not binding, the Highway allocator is cheaper than `mmap`, and
|
||||
// defends against 2K aliasing.
|
||||
if (!should_bind_) {
|
||||
|
|
@ -444,9 +217,10 @@ Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
|
|||
// alignment scheme in aligned_allocator.cc and does not work for
|
||||
// already-aligned pointers as returned by `mmap`, hence we wrap the Highway
|
||||
// pointer in our own deleter.
|
||||
return PtrAndDeleter{p.release(), DeleterFunc2([](void* ptr) {
|
||||
hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
|
||||
})};
|
||||
return AlignedPtr2<uint8_t[]>(p.release(), DeleterFunc2([](void* ptr) {
|
||||
hwy::FreeAlignedBytes(ptr, nullptr,
|
||||
nullptr);
|
||||
}));
|
||||
}
|
||||
|
||||
// Binding, or large vector/cache line size: use platform-specific allocator.
|
||||
|
|
@ -460,20 +234,126 @@ Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
|
|||
const int fd = -1;
|
||||
void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
|
||||
if (p == MAP_FAILED) p = nullptr;
|
||||
return PtrAndDeleter{p, DeleterFunc2([bytes](void* ptr) {
|
||||
return AlignedPtr2<uint8_t[]>(static_cast<uint8_t*>(p),
|
||||
DeleterFunc2([bytes](void* ptr) {
|
||||
HWY_ASSERT(munmap(ptr, bytes) == 0);
|
||||
})};
|
||||
}));
|
||||
#elif HWY_OS_WIN
|
||||
const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
|
||||
return PtrAndDeleter{_aligned_malloc(bytes, alignment),
|
||||
DeleterFunc2([](void* ptr) { _aligned_free(ptr); })};
|
||||
return AlignedPtr2<uint8_t[]>(
|
||||
static_cast<uint8_t*>(_aligned_malloc(bytes, alignment)),
|
||||
DeleterFunc2([](void* ptr) { _aligned_free(ptr); }));
#else
  return PtrAndDeleter{nullptr, DeleterFunc2()};
  return AlignedPtr2<uint8_t[]>(nullptr, DeleterFunc2());
#endif
}
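Since AllocBytes now returns a plain unique_ptr-style handle, callers can allocate directly without the former PtrAndDeleter indirection. A hedged usage sketch, assuming AlignedPtr2 behaves like std::unique_ptr with a stored deleter, as the declarations in util/allocator.h in this diff suggest:

// Sketch only; Allocator2::AllocBytes and Alloc<T> are as declared in this diff.
#include "util/allocator.h"
#include "util/threading_context.h"

void ExampleDirectAllocation() {
  const gcpp::Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;

  // Raw bytes, aligned to QuantumBytes(); freed by the deleter carried in the
  // returned pointer.
  gcpp::AlignedPtr2<uint8_t[]> bytes = allocator.AllocBytes(size_t{1} << 20);

  // Typed allocation without constructors, via the Alloc<T> wrapper.
  gcpp::AlignedPtr2<float[]> floats = allocator.Alloc<float>(4096);
  if (bytes && floats) floats[0] = 1.0f;
}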
|
||||
|
||||
bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
|
||||
return Allocator::BindMemory(ptr, bytes, node);
|
||||
#if GEMMA_BIND && HWY_OS_LINUX
|
||||
|
||||
using Ret = long; // NOLINT(runtime/int)
|
||||
using UL = unsigned long; // NOLINT(runtime/int)
|
||||
static constexpr size_t ULBits = sizeof(UL) * 8;
|
||||
|
||||
// Calling via syscall avoids a dependency on libnuma.
|
||||
struct SyscallWrappers {
|
||||
static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes,
|
||||
unsigned flags) {
|
||||
MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL));
|
||||
return syscall(__NR_mbind, ptr, bytes, mode, max_nodes, max_nodes, flags);
|
||||
};
|
||||
|
||||
static Ret move_pages(int pid, UL count, void** pages, const int* nodes,
|
||||
int* status, int flags) {
|
||||
MaybeCheckInitialized(pages, count * sizeof(void*));
|
||||
MaybeCheckInitialized(nodes, count * sizeof(int));
|
||||
MaybeCheckInitialized(status, count * sizeof(int));
|
||||
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
|
||||
}
|
||||
|
||||
static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr,
|
||||
unsigned flags) {
|
||||
return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags);
|
||||
}
|
||||
};
|
||||
|
||||
// Returns the number of pages that are currently busy (hence not yet moved),
|
||||
// and warns if there are any other reasons for not moving a page. Note that
|
||||
// `move_pages` can return 0 regardless of whether all pages were moved.
|
||||
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
|
||||
const int* status) {
|
||||
size_t num_busy = 0;
|
||||
for (size_t i = 0; i < num_pages; ++i) {
|
||||
if (status[i] == -EBUSY) {
|
||||
++num_busy;
|
||||
} else if (status[i] != static_cast<int>(node)) {
|
||||
static std::atomic_flag first = ATOMIC_FLAG_INIT;
|
||||
if (!first.test_and_set()) {
|
||||
HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).",
|
||||
status[i], i, pages[i], node, errno);
|
||||
}
|
||||
}
|
||||
}
|
||||
return num_busy;
|
||||
}
|
||||
|
||||
bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
|
||||
HWY_DASSERT(should_bind_);
|
||||
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
|
||||
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
// Ensure the requested `node` is allowed.
|
||||
UL nodes[kMaxNodes / 64] = {0};
|
||||
const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED
|
||||
HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes,
|
||||
nullptr, flags) == 0);
|
||||
HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64)));
|
||||
}
|
||||
|
||||
// Avoid mbind because it does not report why it failed, which is most likely
|
||||
// because pages are busy, in which case we want to know which.
|
||||
|
||||
// `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set.
|
||||
const unsigned flags = 2; // MPOL_MF_MOVE
|
||||
HWY_ASSERT(bytes % quantum_bytes_ == 0);
|
||||
const size_t num_pages = bytes / quantum_bytes_;
|
||||
std::vector<void*> pages;
|
||||
pages.reserve(num_pages);
|
||||
for (size_t i = 0; i < num_pages; ++i) {
|
||||
pages.push_back(static_cast<uint8_t*>(ptr) + i * quantum_bytes_);
|
||||
// Ensure the page is faulted in to prevent `move_pages` from failing,
|
||||
// because freshly allocated pages may be mapped to a shared 'zero page'.
|
||||
hwy::ZeroBytes(pages.back(), 8);
|
||||
}
|
||||
std::vector<int> nodes(num_pages, node);
|
||||
std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));
|
||||
|
||||
Ret ret = SyscallWrappers::move_pages(
|
||||
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
|
||||
if (ret < 0) {
|
||||
HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr,
|
||||
bytes, node, errno, status[0]);
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t num_busy =
|
||||
CountBusyPages(num_pages, node, pages.data(), status.data());
|
||||
if (HWY_UNLIKELY(num_busy != 0)) {
|
||||
// Trying again is usually enough to succeed.
|
||||
hwy::NanoSleep(5000);
|
||||
(void)SyscallWrappers::move_pages(
|
||||
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
|
||||
const size_t still_busy =
|
||||
CountBusyPages(num_pages, node, pages.data(), status.data());
|
||||
if (HWY_UNLIKELY(still_busy != 0)) {
|
||||
HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.",
|
||||
still_busy, num_busy);
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
#else
|
||||
bool Allocator2::BindMemory(void*, size_t, size_t) const { return false; }
|
||||
#endif // GEMMA_BIND && HWY_OS_LINUX
|
||||
|
||||
} // namespace gcpp
|
||||
|
|
|
|||
325  util/allocator.h
|
|
@ -30,307 +30,8 @@
|
|||
#include "hwy/base.h"
|
||||
// IWYU pragma: end_exports
|
||||
|
||||
#include "hwy/aligned_allocator.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The
|
||||
// `bytes` argument is required for the latter.
|
||||
using FreeFunc = void (*)(void* mem, size_t bytes);
|
||||
|
||||
// Custom deleter for std::unique_ptr that calls `FreeFunc`. T is POD.
|
||||
class DeleterFree {
|
||||
public:
|
||||
// `MatStorageT` requires this to be default-constructible.
|
||||
DeleterFree() : free_func_(nullptr), bytes_(0) {}
|
||||
DeleterFree(FreeFunc free_func, size_t bytes)
|
||||
: free_func_(free_func), bytes_(bytes) {}
|
||||
|
||||
template <typename T>
|
||||
void operator()(T* p) const {
|
||||
free_func_(p, bytes_);
|
||||
}
|
||||
|
||||
private:
|
||||
FreeFunc free_func_;
|
||||
size_t bytes_;
|
||||
};
|
||||
|
||||
// Wrapper that also calls the destructor for non-POD T.
|
||||
class DeleterDtor {
|
||||
public:
|
||||
DeleterDtor() {}
|
||||
DeleterDtor(size_t num, DeleterFree free) : num_(num), free_(free) {}
|
||||
|
||||
template <typename T>
|
||||
void operator()(T* p) const {
|
||||
for (size_t i = 0; i < num_; ++i) {
|
||||
p[i].~T();
|
||||
}
|
||||
free_(p);
|
||||
}
|
||||
|
||||
private:
|
||||
size_t num_; // not the same as free_.bytes_ / sizeof(T)!
|
||||
DeleterFree free_;
|
||||
};
|
||||
|
||||
// Unique (move-only) pointer to an aligned array of POD T.
|
||||
template <typename T>
|
||||
using AlignedPtr = std::unique_ptr<T[], DeleterFree>;
|
||||
// Unique (move-only) pointer to an aligned array of non-POD T.
|
||||
template <typename T>
|
||||
using AlignedClassPtr = std::unique_ptr<T[], DeleterDtor>;
|
||||
|
||||
// Both allocation, binding, and row accessors depend on the sizes of memory
|
||||
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
|
||||
// use `Monostate` (static members).
|
||||
class Allocator {
|
||||
public:
|
||||
// Must be called at least once before any other function. Not thread-safe,
|
||||
// hence only call this from the main thread.
|
||||
// TODO: remove enable_bind once Gemma tensors support binding.
|
||||
static void Init(const BoundedTopology& topology, bool enable_bind = false);
|
||||
|
||||
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
|
||||
// ranges such that there will be no false sharing.
|
||||
static size_t LineBytes();
|
||||
// Bytes per full vector. Used to compute loop steps.
|
||||
static size_t VectorBytes();
|
||||
// Work granularity that avoids false sharing and partial vectors.
|
||||
static size_t StepBytes(); // = HWY_MAX(LineBytes(), VectorBytes())
|
||||
// Granularity like `StepBytes()`, but when NUMA may be involved.
|
||||
static size_t QuantumBytes();
|
||||
// Upper bound on `QuantumBytes()`, for stack allocations.
|
||||
static constexpr size_t MaxQuantumBytes() { return 4096; }
|
||||
static size_t QuantumSteps(); // = QuantumBytes() / StepBytes()
|
||||
|
||||
// L1 and L2 are typically per core.
|
||||
static size_t L1Bytes();
|
||||
static size_t L2Bytes();
|
||||
// Clusters often share an L3. We return the total size per package.
|
||||
static size_t L3Bytes();
|
||||
|
||||
// Returns pointer aligned to `QuantumBytes()`.
|
||||
template <typename T>
|
||||
static AlignedPtr<T> Alloc(size_t num) {
|
||||
constexpr size_t kSize = sizeof(T);
|
||||
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
|
||||
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
|
||||
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
|
||||
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
|
||||
// Fail if the `bytes = num * kSize` computation overflowed.
|
||||
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
|
||||
if (check != num) return AlignedPtr<T>();
|
||||
|
||||
PtrAndDeleter pd = AllocBytes(bytes);
|
||||
return AlignedPtr<T>(static_cast<T*>(pd.p), pd.deleter);
|
||||
}
|
||||
|
||||
// Same as Alloc, but calls constructor(s) with `args`.
|
||||
template <typename T, class... Args>
|
||||
static AlignedClassPtr<T> AllocClasses(size_t num, Args&&... args) {
|
||||
constexpr size_t kSize = sizeof(T);
|
||||
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
|
||||
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
|
||||
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
|
||||
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
|
||||
// Fail if the `bytes = num * kSize` computation overflowed.
|
||||
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
|
||||
if (check != num) return AlignedClassPtr<T>();
|
||||
|
||||
PtrAndDeleter pd = AllocBytes(bytes);
|
||||
T* p = static_cast<T*>(pd.p);
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
new (p + i) T(std::forward<Args>(args)...);
|
||||
}
|
||||
return AlignedClassPtr<T>(p, DeleterDtor(num, pd.deleter));
|
||||
}
|
||||
|
||||
// Returns whether `BindMemory` can/should be called, i.e. we have page-level
|
||||
// control over memory placement and multiple packages and NUMA nodes.
|
||||
static bool ShouldBind();
|
||||
|
||||
// Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
|
||||
// typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
|
||||
// Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
|
||||
// `p` and `bytes` must be multiples of `QuantumBytes()`.
|
||||
static bool BindMemory(void* p, size_t bytes, size_t node);
|
||||
|
||||
private:
|
||||
// Type-erased so this can be implemented in allocator.cc.
|
||||
struct PtrAndDeleter {
|
||||
void* p;
|
||||
DeleterFree deleter;
|
||||
};
|
||||
static PtrAndDeleter AllocBytes(size_t bytes);
|
||||
};

// Value of `stride` to pass to `RowVectorBatch` to enable the "cyclic offsets"
// optimization. If `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is
// typically 4KiB. To avoid remote accesses, we would thus pad each row to that,
// which results in 4K aliasing and/or cache conflict misses. `RowPtr` is able
// to prevent that by pulling rows forward by a cyclic offset, which is still a
// multiple of the cache line size. This requires an additional
// `Allocator::QuantumBytes()` of padding after also rounding up to that.
template <typename T>
constexpr size_t StrideForCyclicOffsets(size_t cols) {
  const size_t quantum = Allocator::MaxQuantumBytes() / sizeof(T);
  return hwy::RoundUpTo(cols, quantum) + quantum;
}
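A quick worked example of the arithmetic above (illustrative only, using the 4 KiB MaxQuantumBytes() from this header and 4-byte float):

// Worked example (not part of the diff): quantum = 4096 / sizeof(float) = 1024.
//   cols = 2048 -> RoundUpTo(2048, 1024) + 1024 = 3072 elements (12 KiB/row)
//   cols = 2100 -> RoundUpTo(2100, 1024) + 1024 = 4096 elements (16 KiB/row)
// The extra quantum of padding is what lets RowPtr pull rows forward by a
// cyclic offset without reading outside the allocation.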
|
||||
|
||||
// Owns dynamically-allocated aligned memory for a batch of row vectors.
|
||||
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
|
||||
// the memory.
|
||||
template <typename T>
|
||||
class RowVectorBatch {
|
||||
public:
|
||||
// Default ctor for Activations ctor.
|
||||
RowVectorBatch() = default;
|
||||
// Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
|
||||
// we default to tightly packed rows (`stride = cols`).
|
||||
// WARNING: not all call sites support `stride` != cols.
|
||||
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
|
||||
RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
|
||||
if (stride == 0) {
|
||||
stride_ = extents_.cols;
|
||||
} else {
|
||||
HWY_ASSERT(stride >= extents_.cols);
|
||||
stride_ = stride;
|
||||
}
|
||||
// Allow binding the entire matrix.
|
||||
const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
|
||||
Allocator::QuantumBytes() / sizeof(T));
|
||||
mem_ = Allocator::Alloc<T>(padded);
|
||||
}
|
||||
|
||||
// Move-only
|
||||
RowVectorBatch(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
|
||||
RowVectorBatch(RowVectorBatch&&) noexcept = default;
|
||||
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
|
||||
|
||||
size_t BatchSize() const { return extents_.rows; }
|
||||
size_t Cols() const { return extents_.cols; }
|
||||
size_t Stride() const { return stride_; }
|
||||
Extents2D Extents() const { return extents_; }
|
||||
|
||||
// Returns the given row vector of length `Cols()`.
|
||||
T* Batch(size_t batch_idx) {
|
||||
HWY_DASSERT(batch_idx < BatchSize());
|
||||
return mem_.get() + batch_idx * stride_;
|
||||
}
|
||||
const T* Batch(size_t batch_idx) const {
|
||||
HWY_DASSERT(batch_idx < BatchSize());
|
||||
return mem_.get() + batch_idx * stride_;
|
||||
}
|
||||
|
||||
// For MatMul or other operations that process the entire batch at once.
|
||||
// TODO: remove once we only use Mat.
|
||||
T* All() { return mem_.get(); }
|
||||
const T* Const() const { return mem_.get(); }
|
||||
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
|
||||
|
||||
private:
|
||||
AlignedPtr<T> mem_;
|
||||
Extents2D extents_;
|
||||
size_t stride_;
|
||||
};

// Returns `num` rounded up to an odd number of cache lines. This is used to
// compute strides. An odd number of cache lines prevents 2K aliasing and is
// coprime with the cache associativity, which reduces conflict misses.
template <typename T>
static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
  HWY_DASSERT(line_bytes >= 32);
  HWY_DASSERT(line_bytes % sizeof(T) == 0);
  const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes);
  const size_t padded_num = (lines | 1) * line_bytes / sizeof(T);
  HWY_DASSERT(padded_num >= num);
  return padded_num;
}

template <typename T>
RowVectorBatch<T> AllocateAlignedRows(Extents2D extents) {
  return RowVectorBatch<T>(extents, StrideForCyclicOffsets<T>(extents.cols));
}
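Call sites elsewhere in this diff now pass the allocator explicitly instead of reaching it through monostate. A minimal usage sketch, assuming the AllocateAlignedRows(allocator, extents) overload those call sites use and that RowVectorBatch has moved to util/mat.h as the new includes indicate:

// Sketch of the new call pattern seen in matmul_test and the ops tests.
#include "util/mat.h"
#include "util/threading_context.h"

void ExampleBatchAllocation() {
  const gcpp::Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
  const gcpp::Extents2D extents(/*rows=*/4, /*cols=*/2048);

  // Rows are padded via StrideForCyclicOffsets, so Stride() >= Cols().
  gcpp::RowVectorBatch<float> batch =
      gcpp::AllocateAlignedRows<float>(allocator, extents);
  float* row0 = batch.Batch(0);  // pointer to the first row, length Cols()
  row0[0] = 1.0f;
}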
|
||||
|
||||
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
|
||||
// it is always float and does not support compressed T, but does support an
|
||||
// arbitrary stride >= cols.
|
||||
#pragma pack(push, 1) // power of two size
|
||||
template <typename T>
|
||||
class RowPtr {
|
||||
public:
|
||||
RowPtr() = default; // for `MMPtrs`.
|
||||
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
|
||||
: row0_(row0),
|
||||
stride_(stride),
|
||||
step_(static_cast<uint32_t>(Allocator::StepBytes())),
|
||||
cols_(static_cast<uint32_t>(cols)),
|
||||
row_mask_(Allocator::QuantumSteps() - 1) {
|
||||
HWY_DASSERT(stride >= cols);
|
||||
HWY_DASSERT(row_mask_ != ~size_t{0});
|
||||
if constexpr (HWY_IS_DEBUG_BUILD) {
|
||||
if (stride < StrideForCyclicOffsets<T>(cols)) {
|
||||
static bool once;
|
||||
if (!once) {
|
||||
once = true;
|
||||
HWY_WARN(
|
||||
"Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
|
||||
"T=%zu; this forces us to disable cyclic offsets.",
|
||||
stride, cols, sizeof(T));
|
||||
}
|
||||
row_mask_ = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
|
||||
|
||||
T* HWY_RESTRICT Row(size_t r) const {
|
||||
// How much of the previous row's padding to consume.
|
||||
const size_t pad_bytes = (r & row_mask_) * step_;
|
||||
HWY_DASSERT(pad_bytes < Allocator::QuantumBytes());
|
||||
return row0_ + stride_ * r - pad_bytes;
|
||||
}
|
||||
size_t Cols() const { return cols_; }
|
||||
|
||||
size_t Stride() const { return stride_; }
|
||||
void SetStride(size_t stride) {
|
||||
HWY_DASSERT(stride >= Cols());
|
||||
stride_ = stride;
|
||||
// The caller might not have padded enough, so disable the padding in Row().
|
||||
// Rows will now be exactly `stride` elements apart. This is used when
|
||||
// writing to the KV cache via MatMul.
|
||||
row_mask_ = 0;
|
||||
}
|
||||
|
||||
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
|
||||
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
|
||||
HWY_DASSERT(c < cols_);
|
||||
HWY_DASSERT(cols <= cols_ - c);
|
||||
return RowPtr<T>(Row(r) + c, cols, stride_);
|
||||
}
|
||||
|
||||
private:
|
||||
T* HWY_RESTRICT row0_;
|
||||
size_t stride_;
|
||||
uint32_t step_; // Copy from Allocator::LineBytes() to improve locality.
|
||||
uint32_t cols_;
|
||||
size_t row_mask_;
|
||||
};
|
||||
#pragma pack(pop)

using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;

// For C argument to MatMul.
template <typename T>
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
  return RowPtr<T>(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride());
}
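A short usage sketch of the C-argument wrapper (illustrative only): the allocator-taking RowPtrFromBatch overload matches the new call sites earlier in this diff, and Row() applies the cyclic-offset adjustment described above.

// Sketch: wrap a RowVectorBatch as the C argument of MatMul and address rows.
#include "util/mat.h"
#include "util/threading_context.h"

void ExampleRowPtr(gcpp::RowVectorBatch<float>& c_batch) {
  const gcpp::Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
  const gcpp::RowPtr<float> C = gcpp::RowPtrFromBatch(allocator, c_batch);

  float* row3 = C.Row(3);  // start of row 3, stride and cyclic offset applied
  gcpp::RowPtr<float> tile = C.View(/*r=*/0, /*c=*/128, /*cols=*/256);
  row3[0] = 0.0f;
  tile.Row(0)[0] = 0.0f;
}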
|
||||
|
||||
// Custom deleter for types without a dtor, but where the deallocation requires
|
||||
// state, e.g. a lambda with *by-value* capture.
|
||||
class DeleterFunc2 {
|
||||
|
|
@ -420,15 +121,22 @@ class Allocator2 {
|
|||
size_t TotalMiB() const { return total_mib_; }
|
||||
size_t FreeMiB() const;
|
||||
|
||||
// Returns pointer aligned to `QuantumBytes()`.
|
||||
// Returns byte pointer aligned to `QuantumBytes()`, without calling
|
||||
// constructors nor destructors on deletion. Type-erased so this can be
|
||||
// implemented in `allocator.cc` and called by `MatOwner`.
|
||||
AlignedPtr2<uint8_t[]> AllocBytes(size_t bytes) const;
|
||||
|
||||
// Returns pointer aligned to `QuantumBytes()`, without calling constructors
|
||||
// nor destructors on deletion.
|
||||
template <typename T>
|
||||
AlignedPtr2<T[]> Alloc(size_t num) const {
|
||||
const size_t bytes = num * sizeof(T);
|
||||
// Fail if the `bytes = num * sizeof(T)` computation overflowed.
|
||||
HWY_ASSERT(bytes / sizeof(T) == num);
|
||||
|
||||
PtrAndDeleter pd = AllocBytes(bytes);
|
||||
return AlignedPtr2<T[]>(static_cast<T*>(pd.p), pd.deleter);
|
||||
AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
|
||||
return AlignedPtr2<T[]>(HWY_RCAST_ALIGNED(T*, p8.release()),
|
||||
p8.get_deleter());
|
||||
}
|
||||
|
||||
// Same as Alloc, but calls constructor(s) with `args` and the deleter will
|
||||
|
|
@ -439,12 +147,12 @@ class Allocator2 {
|
|||
// Fail if the `bytes = num * sizeof(T)` computation overflowed.
|
||||
HWY_ASSERT(bytes / sizeof(T) == num);
|
||||
|
||||
PtrAndDeleter pd = AllocBytes(bytes);
|
||||
T* p = static_cast<T*>(pd.p);
|
||||
AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
|
||||
T* p = HWY_RCAST_ALIGNED(T*, p8.release());
|
||||
for (size_t i = 0; i < num; ++i) {
|
||||
new (p + i) T(std::forward<Args>(args)...);
|
||||
}
|
||||
return AlignedClassPtr2<T>(p, DeleterDtor2(num, pd.deleter));
|
||||
return AlignedClassPtr2<T>(p, DeleterDtor2(num, p8.get_deleter()));
|
||||
}
|
||||
|
||||
// Returns whether `BindMemory` can/should be called, i.e. we have page-level
|
||||
|
|
@ -458,13 +166,6 @@ class Allocator2 {
|
|||
bool BindMemory(void* p, size_t bytes, size_t node) const;
|
||||
|
||||
private:
|
||||
// Type-erased so this can be implemented in allocator.cc.
|
||||
struct PtrAndDeleter {
|
||||
void* p;
|
||||
DeleterFunc2 deleter;
|
||||
};
|
||||
PtrAndDeleter AllocBytes(size_t bytes) const;
|
||||
|
||||
size_t line_bytes_;
|
||||
size_t vector_bytes_;
|
||||
size_t step_bytes_;
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@
|
|||
#include <algorithm> // std::transform
|
||||
#include <string>
|
||||
|
||||
#include "compression/io.h"
|
||||
#include "compression/io.h" // Path
|
||||
#include "util/basics.h" // Tristate
|
||||
#include "hwy/base.h" // HWY_ABORT
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,100 @@
|
|||
// Copyright 2024 Google LLC
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// https://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include "util/mat.h"
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdint.h>
|
||||
|
||||
#include "util/threading_context.h"
|
||||
#include "hwy/base.h"
|
||||
#include "hwy/per_target.h" // VectorBytes
|
||||
#include "hwy/profiler.h"
|
||||
|
||||
namespace gcpp {
|
||||
|
||||
void CopyMat(const MatPtr& from, MatPtr& to) {
|
||||
PROFILER_FUNC;
|
||||
HWY_ASSERT(to.Rows() == from.Rows() && to.Cols() == from.Cols());
|
||||
HWY_ASSERT(to.GetType() == from.GetType());
|
||||
if (to.IsPacked() && from.IsPacked()) {
|
||||
HWY_ASSERT(to.PackedBytes() == from.PackedBytes());
|
||||
hwy::CopyBytes(from.Packed(), to.Packed(), to.PackedBytes());
|
||||
return;
|
||||
}
|
||||
const size_t row_bytes = to.Cols() * to.ElementBytes();
|
||||
for (size_t r = 0; r < to.Rows(); ++r) {
|
||||
const uint8_t* from_row = from.RowT<uint8_t>(r);
|
||||
uint8_t* to_row = to.RowT<uint8_t>(r);
|
||||
hwy::CopyBytes(from_row, to_row, row_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
void ZeroInit(MatPtr& mat) {
|
||||
PROFILER_FUNC;
|
||||
HWY_ASSERT_M(mat.HasPtr(), mat.Name());
|
||||
if (mat.IsPacked()) {
|
||||
hwy::ZeroBytes(mat.Packed(), mat.PackedBytes());
|
||||
return;
|
||||
}
|
||||
const size_t row_bytes = mat.Cols() * mat.ElementBytes();
|
||||
for (size_t r = 0; r < mat.Rows(); ++r) {
|
||||
hwy::ZeroBytes(mat.RowT<uint8_t>(r), row_bytes);
|
||||
}
|
||||
}
|
||||
|
||||
// Returns `num` rounded up to an odd number of cache lines. This would also
|
||||
// prevent 4K aliasing and is coprime with the cache associativity, which
|
||||
// might reduce conflict misses, but we instead use `StrideForCyclicOffsets`.
|
||||
static size_t RoundUpToOddLines(size_t num, size_t line_bytes,
|
||||
size_t element_bytes) {
|
||||
HWY_DASSERT(line_bytes >= 32);
|
||||
HWY_DASSERT(line_bytes % element_bytes == 0);
|
||||
const size_t lines = hwy::DivCeil(num * element_bytes, line_bytes);
|
||||
const size_t padded_num = (lines | 1) * line_bytes / element_bytes;
|
||||
HWY_DASSERT(padded_num >= num);
|
||||
return padded_num;
|
||||
}

static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
                     MatPadding padding) {
  switch (padding) {
    case MatPadding::kPacked:
    default:
      return mat.Cols();
    case MatPadding::kOdd:
      return RoundUpToOddLines(mat.Cols(), allocator.LineBytes(),
                               mat.ElementBytes());
    case MatPadding::kCyclic:
      return StrideForCyclicOffsets(
          mat.Cols(), allocator.QuantumBytes() / mat.ElementBytes());
  }
}

void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
  const Allocator2& allocator = ThreadingContext2::Get().allocator;
  const size_t stride = Stride(allocator, mat, padding);
  const size_t num = mat.Rows() * stride;
  // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
  // might not be enough, hence add extra. `MatT` is at least one byte, which
  // is half of BF16, hence adding `VectorBytes` *elements* is enough.
  const size_t bytes = (num + hwy::VectorBytes()) * mat.ElementBytes();
  // Allow binding the entire matrix.
  const size_t padded_bytes =
      hwy::RoundUpTo(bytes, allocator.QuantumBytes() / mat.ElementBytes());
  storage_ = allocator.AllocBytes(padded_bytes);
  mat.SetPtr(storage_.get(), stride);
}

}  // namespace gcpp
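A short usage sketch of the new allocation path (illustrative; the tensor name, extents, and Type::kF32 are placeholders, since only Type::kNUQ is spelled out in this diff):

// Sketch of pairing MatPtr metadata with MatOwner-provided storage.
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"

void ExampleAllocateMat(gcpp::MatOwner& owner) {
  gcpp::MatPtr mat("mlp_w1", gcpp::Type::kF32,
                   gcpp::Extents2D(/*rows=*/64, /*cols=*/2048));
  owner.AllocateFor(mat, gcpp::MatPadding::kCyclic);

  HWY_ASSERT(mat.HasPtr());
  HWY_ASSERT(mat.Stride() >= mat.Cols());  // kCyclic pads each row
  float* row0 = mat.RowT<float>(0);        // row-based, stride-aware access
  row0[0] = 1.0f;
}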
|
||||
|
|
@ -0,0 +1,532 @@
|
|||
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Tensor metadata and in-memory representation.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_

#include <stddef.h>
#include <stdint.h>

#include <random>
#include <string>
#include <utility>  // std::forward, used by `CallUpcasted` below.

// IWYU pragma: begin_exports
#include "compression/fields.h"
#include "compression/shared.h"  // Type
#include "gemma/tensor_index.h"
#include "util/allocator.h"
#include "util/basics.h"  // Extents2D
// IWYU pragma: end_exports
#include "hwy/base.h"

namespace gcpp {

// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
// or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
// to store heterogeneous tensor references in a vector.
//
// Copyable, (de)serializable via `fields.h` for `model_store.h`.
class MatPtr : public IFields {
 public:
  MatPtr() = default;
  // `name`: see `SetName`. Note that `stride` is initially `cols` and only
  // differs after deserializing, or calling `SetPtr`.
  MatPtr(const char* name, Type type, Extents2D extents)
      : rows_(static_cast<uint32_t>(extents.rows)),
        cols_(static_cast<uint32_t>(extents.cols)) {
    SetName(name);
    SetType(type);
    SetPtr(nullptr, cols_);
  }

  // Copying allowed because the metadata is small.
  MatPtr(const MatPtr& other) = default;
  MatPtr& operator=(const MatPtr& other) = default;

  virtual ~MatPtr() = default;

  // Only for use by ctor, `AllocateFor` and 'loading' memory-mapped tensors.
  void SetPtr(void* ptr, size_t stride) {
    HWY_ASSERT(stride >= Cols());
    ptr_ = ptr;
    stride_ = static_cast<uint32_t>(stride);

    // NUQ streams must not be padded because that would change the position of
    // the group tables.
    if (type_ == Type::kNUQ) HWY_ASSERT(IsPacked());
  }

  bool HasPtr() const { return ptr_ != nullptr; }

  bool IsPacked() const { return stride_ == cols_; }

  const void* Packed() const {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    return ptr_;
  }
  void* Packed() {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    return ptr_;
  }

  // Returns size in bytes for purposes of copying/initializing or I/O. Must
  // only be called if `IsPacked`.
  size_t PackedBytes() const {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    // num_elements_ already includes the NUQ tables.
    return num_elements_ * element_bytes_;
  }

  // Works for any kind of padding.
  template <typename T>
  T* MutableRowT(size_t row) const {
    HWY_DASSERT(row < rows_);
    return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
  }
  template <typename T>
  T* RowT(size_t row) {
    HWY_DASSERT(row < rows_);
    return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
  }
  template <typename T>
  const T* RowT(size_t row) const {
    HWY_DASSERT(row < rows_);
    return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
  }

  Type GetType() const { return type_; }
  void SetType(Type type) {
    type_ = type;
    element_bytes_ = static_cast<uint32_t>(hwy::DivCeil(TypeBits(type), 8));
    num_elements_ = static_cast<uint32_t>(ComputeNumElements(type, Extents()));
  }

  bool IsEmpty() const { return rows_ == 0 || cols_ == 0; }
  size_t Rows() const { return rows_; }
  size_t Cols() const { return cols_; }
  Extents2D Extents() const { return Extents2D(rows_, cols_); }

  // Offset by which to advance pointers to the next row.
  size_t Stride() const { return stride_; }

  // For use by `BlobStore`, `CopyMat` and `ZeroInit`.
  size_t ElementBytes() const { return element_bytes_; }

  // Decoded elements should be multiplied by this to restore their original
  // range. This is required because `SfpStream` can only encode a limited range
  // of magnitudes.
  float Scale() const { return scale_; }
  void SetScale(float scale) { scale_ = scale; }

  // Name is a terse identifier. `MakeKey` in `blob_store.cc` requires that it
  // be <= 16 bytes including prefixes/suffixes. The initial name set by the
  // ctor is for the tensor, but `ForEachTensor` in `weights.h` adds a per-layer
  // suffix, and when loading, we call `SetName` with that.
  const char* Name() const override { return name_.c_str(); }
  void SetName(const char* name) {
    name_ = name;
    HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name);
  }

  void VisitFields(IFieldsVisitor& visitor) override {
    // Order determines the order of serialization and must not change.
    visitor(name_);
    visitor(type_);
    visitor(element_bytes_);
    visitor(num_elements_);
    visitor(rows_);
    visitor(cols_);
    visitor(scale_);
    visitor(stride_);
  }

 protected:
  // For initializing `num_elements_`: "elements" are how many objects we
  // actually store in order to represent rows * cols values. For NUQ, this is
  // greater because it includes additional per-group tables. This is the only
  // place where we compute this fixup. Note that elements are independent of
  // padding, which is anyway not supported for NUQ because `compress-inl.h`
  // assumes a contiguous stream for its group indexing.
  static size_t ComputeNumElements(Type type, Extents2D extents) {
    const size_t num_elements = extents.Area();
    if (type == Type::kNUQ) {
      // `CompressedArrayElements` is a wrapper function that has the same
      // effect, but that requires a template argument, not `type`.
      return NuqStream::PackedEnd(num_elements);
    }
    return num_elements;
  }

  std::string name_;  // See `SetName`.
  Type type_;

  // Most members are u32 because that is the preferred type of fields.h.

  // Bytes per element. This is fully determined by `type_`, but stored here
  // for convenience and backward compatibility.
  uint32_t element_bytes_ = 0;
  // Number of elements to store (including NUQ tables but not padding).
  // This is a function of `type_` and `Extents()` and stored for compatibility.
  uint32_t num_elements_ = 0;
  uint32_t rows_ = 0;
  uint32_t cols_ = 0;
  float scale_ = 1.0f;  // multiplier for each value, for MatMul.

  // Non-owning pointer, must not be freed. The underlying memory must outlive
  // this object.
  void* ptr_ = nullptr;  // not serialized

  // Offset by which to advance pointers to the next row, >= `cols_`.
  uint32_t stride_;
};

// Non-type erased version of `MatPtr`. Use this when operating on the values.
template <typename MatT>
class MatPtrT : public MatPtr {
 public:
  // Runtime-specified shape.
  MatPtrT(const char* name, Extents2D extents)
      : MatPtr(name, TypeEnum<MatT>(), extents) {}
  // Take shape from `TensorInfo` to avoid duplicating it in the caller.
  MatPtrT(const char* name, const TensorInfo* tensor)
      : MatPtrT<MatT>(name, ExtentsFromInfo(tensor)) {}
  // Find `TensorInfo` by name in `TensorIndex`.
  MatPtrT(const char* name, const TensorIndex& tensor_index)
      : MatPtrT<MatT>(name, tensor_index.FindName(name)) {}

  // Copying allowed because the metadata is small.
  MatPtrT(const MatPtr& other) : MatPtr(other) {}
  MatPtrT& operator=(const MatPtr& other) {
    MatPtr::operator=(other);
    return *this;
  }
  MatPtrT(const MatPtrT& other) = default;
  MatPtrT& operator=(const MatPtrT& other) = default;

  // Returns the entire tensor for use by `backprop/*`. Verifies layout is
  // `kPacked`. Preferably call `Row` instead, which works for either layout.
  MatT* Packed() {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    return HWY_RCAST_ALIGNED(MatT*, ptr_);
  }
  const MatT* Packed() const {
    HWY_DASSERT_M(IsPacked(), name_.c_str());
    return HWY_RCAST_ALIGNED(const MatT*, ptr_);
  }
  // As `Packed()`, plus checks the scale is 1.0 because callers will ignore it.
  // This is typically used for `MatMul` bias vectors and norm weights.
  const MatT* PackedScale1() const {
    HWY_DASSERT(Scale() == 1.0f);
    return Packed();
  }

  const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
  MatT* Row(size_t row) { return this->RowT<MatT>(row); }

  // For `compress-inl.h` functions, which assume contiguous streams and thus
  // require packed layout.
  PackedSpan<const MatT> Span() const {
    HWY_ASSERT(IsPacked());
    return MakeConstSpan(Row(0), num_elements_);
  }
  PackedSpan<MatT> Span() {
    HWY_ASSERT(IsPacked());
    return MakeSpan(Row(0), num_elements_);
  }
};

// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`.
template <class Func, typename... Args>
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
                            Args&&... args) {
  HWY_ASSERT(base != nullptr);
  if (type == Type::kF32) {
    return func(dynamic_cast<MatPtrT<float>*>(base),
                std::forward<Args>(args)...);
  } else if (type == Type::kBF16) {
    return func(dynamic_cast<MatPtrT<BF16>*>(base),
                std::forward<Args>(args)...);
  } else if (type == Type::kSFP) {
    return func(dynamic_cast<MatPtrT<SfpStream>*>(base),
                std::forward<Args>(args)...);
  } else if (type == Type::kNUQ) {
    return func(dynamic_cast<MatPtrT<NuqStream>*>(base),
                std::forward<Args>(args)...);
  } else {
    HWY_ABORT("Type %d unknown.", static_cast<int>(type));
  }
}
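
// Usage sketch (illustrative only; `mat` is a hypothetical pointer into a
// heterogeneous tensor list and the lambda body is an assumption). The generic
// lambda is instantiated once per supported element type, so each branch sees
// the statically typed `MatPtrT<T>*`; all branches must return the same type.
//   MatPtr* mat = ...;
//   CallUpcasted(mat->GetType(), mat, [](auto* typed) {
//     printf("%s: %zu x %zu\n", typed->Name(), typed->Rows(), typed->Cols());
//   });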

void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);

template <typename T>
void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
  std::normal_distribution<T> dist(0.0, stddev);
  for (size_t r = 0; r < x.Rows(); ++r) {
    T* row = x.Row(r);
    for (size_t c = 0; c < x.Cols(); ++c) {
      row[c] = dist(gen);
    }
  }
}

// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
// `Allocator2::ShouldBind()`, `Allocator2::QuantumBytes()` is typically 4KiB.
// To avoid remote accesses, we would thus pad each row to that, which results
// in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that
// by pulling rows forward by a cyclic offset, which is still a multiple of the
// cache line size. This requires an additional `Allocator2::QuantumBytes()` of
// padding after also rounding up to that, which considerably increases size for
// tall and skinny tensors.
static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
  return hwy::RoundUpTo(cols, quantum) + quantum;
}
// Constexpr version (upper bound) for allocating storage in MatMul.
template <typename T>
constexpr size_t MaxStrideForCyclicOffsets(size_t cols) {
  constexpr size_t quantum = Allocator2::MaxQuantum<T>();
  return hwy::RoundUpTo(cols, quantum) + quantum;
}
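
// Worked example (a sketch; the element type, quantum and column count are
// assumed values): for float with a 4 KiB quantum, `quantum` = 4096 /
// sizeof(float) = 1024 elements, so StrideForCyclicOffsets(256, 1024) =
// RoundUpTo(256, 1024) + 1024 = 2048 elements. Each 256-column row thus has
// 1792 elements of padding that `RowPtr` can consume via cyclic offsets.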

// Our tensors are always row-major. This enum indicates how much (if any)
// padding comes after each row.
enum class MatPadding {
  // None, stride == cols. `compress-inl.h` requires this layout because its
  // interface assumes a contiguous 1D array, without awareness of rows. Note
  // that tensors which were written via `compress-inl.h` (i.e. most in
  // `BlobStore`) are not padded, which also extends to memory-mapped tensors.
  // However, `BlobStore` is able to insert padding via row-wise I/O when
  // reading from disk via `Mode::kRead`.
  //
  // `backprop/*` also requires this layout because it indexes directly into
  // the storage instead of calling `Row()`.
  kPacked,
  // Enough to round up to an odd number of cache lines, which can reduce
  // cache conflict misses or 4K aliasing.
  kOdd,
  // Enough to enable the "cyclic offsets" optimization for `MatMul`.
  kCyclic,
};

// Type-erased, allows storing `AlignedPtr2<T[]>` for various T in the same
// vector.
class MatOwner {
 public:
  MatOwner() = default;
  // Allow move for `MatStorageT`.
  MatOwner(MatOwner&&) = default;
  MatOwner& operator=(MatOwner&&) = default;

  // Allocates the type/extents indicated by `mat` and sets its pointer.
  void AllocateFor(MatPtr& mat, MatPadding padding);

 private:
  AlignedPtr2<uint8_t[]> storage_;
};
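
// Example usage (a sketch; the tensor name and extents are illustrative):
//   MatPtrT<float> att("att", Extents2D(64, 768));
//   MatOwner owner;                               // keeps the allocation alive
//   owner.AllocateFor(att, MatPadding::kCyclic);  // att.Stride() >= att.Cols()
//   ZeroInit(att);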

// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
// tests to allocate and access tensors of a known type. By contrast, the
// heterogeneous model weights are owned by vectors of `MatOwner`.
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
 public:
  MatStorageT(const char* name, Extents2D extents, MatPadding padding)
      : MatPtrT<MatT>(name, extents) {
    owner_.AllocateFor(*this, padding);
  }
  ~MatStorageT() = default;

  // Allow move for backprop/activations.
  MatStorageT(MatStorageT&&) = default;
  MatStorageT& operator=(MatStorageT&&) = default;

 private:
  MatOwner owner_;
};

// Helper factory function for use by `backprop/` to avoid specifying the
// `MatPadding` argument everywhere.
template <typename T>
MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
  return MatStorageT<T>(name, Extents2D(rows, cols), MatPadding::kPacked);
}
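
// Example (a sketch; the name, shape and stored value are assumptions):
//   MatStorageT<float> grad = MakePacked<float>("grad", 8, 16);
//   ZeroInit(grad);
//   grad.Row(0)[0] = 1.0f;  // packed layout, so rows are exactly 16 apart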

// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
// seekable (non-NUQ) T. This has less metadata, but supports cyclic offsets.
#pragma pack(push, 1)  // power of two size
template <typename T>
class RowPtr {
 public:
  RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols,
         size_t stride)
      : row0_(row0),
        stride_(stride),
        row_mask_(
            static_cast<uint32_t>(allocator.QuantumStepMask() & 0xFFFFFFFFu)),
        cols_(static_cast<uint32_t>(cols)),
        step_bytes_(static_cast<uint32_t>(allocator.StepBytes())),
        quantum_bytes_(allocator.QuantumBytes()) {
    HWY_DASSERT(stride >= cols);
    HWY_DASSERT(row_mask_ != ~uint32_t{0});
    if (stride < StrideForCyclicOffsets(cols, quantum_bytes_ / sizeof(T))) {
      row_mask_ = 0;
      if constexpr (HWY_IS_DEBUG_BUILD) {
        static bool once;
        if (stride != cols && !once) {
          once = true;
          HWY_WARN(
              "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
              "T=%zu; this forces us to disable cyclic offsets.",
              stride, cols, sizeof(T));
        }
      }
    }
  }

  RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols)
      : RowPtr(allocator, row0, cols, cols) {}

  T* HWY_RESTRICT Row(size_t r) const {
    // How much of the previous row's padding to consume.
    const size_t pad_bytes = (r & row_mask_) * step_bytes_;
    HWY_DASSERT(pad_bytes < static_cast<size_t>(quantum_bytes_));
    return row0_ + stride_ * r - pad_bytes;
  }
  size_t Cols() const { return static_cast<size_t>(cols_); }

  size_t Stride() const { return stride_; }
  void SetStride(size_t stride) {
    HWY_DASSERT(stride >= Cols());
    stride_ = stride;
    // The caller might not have padded enough, so disable the padding in Row().
    // Rows will now be exactly `stride` elements apart. This is used when
    // writing to the KV cache via MatMul.
    row_mask_ = 0;
  }

  // Returns 2D subrange whose top-left is `r, c` and width is `cols`.
  RowPtr<T> View(size_t r, size_t c, size_t cols) const {
    HWY_DASSERT(c < Cols());
    HWY_DASSERT(cols <= Cols() - c);
    return RowPtr<T>(Row(r) + c, cols, stride_, row_mask_, step_bytes_,
                     quantum_bytes_);
  }

 private:
  // For `View()`.
  RowPtr(T* new_row0, size_t new_cols, size_t stride, uint32_t row_mask,
         uint32_t step_bytes, uint32_t quantum_bytes)
      : row0_(new_row0),
        stride_(stride),
        row_mask_(row_mask),
        cols_(new_cols),
        step_bytes_(step_bytes),
        quantum_bytes_(quantum_bytes) {}

  T* HWY_RESTRICT row0_;
  size_t stride_;
  uint32_t row_mask_;
  uint32_t cols_;
  uint32_t step_bytes_;
  uint32_t quantum_bytes_;
};
#pragma pack(pop)

using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;

// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory. Unlike `MatPtr`, this lacks metadata.
// TODO: replace with `MatStorageT`.
template <typename T>
class RowVectorBatch {
 public:
  // Default ctor for Activations ctor.
  RowVectorBatch() = default;
  // Main ctor, called from Activations::Allocate. If `stride` is 0 (the
  // default), rows are tightly packed (`stride = cols`).
  // WARNING: not all call sites support `stride` != cols.
  // TODO: once they do, remove stride and behave like AllocateAlignedRows here.
  RowVectorBatch(const Allocator2& allocator, Extents2D extents,
                 size_t stride = 0)
      : extents_(extents) {
    if (stride == 0) {
      stride_ = extents_.cols;
    } else {
      HWY_ASSERT(stride >= extents_.cols);
      stride_ = stride;
    }
    // Allow binding the entire matrix.
    const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
                                         allocator.QuantumBytes() / sizeof(T));
    mem_ = allocator.Alloc<T>(padded);
  }

  // Move-only
  RowVectorBatch(RowVectorBatch&) noexcept = delete;
  RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
  RowVectorBatch(RowVectorBatch&&) noexcept = default;
  RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;

  size_t BatchSize() const { return extents_.rows; }
  size_t Cols() const { return extents_.cols; }
  size_t Stride() const { return stride_; }
  Extents2D Extents() const { return extents_; }

  // Returns the given row vector of length `Cols()`.
  T* Batch(size_t batch_idx) {
    HWY_DASSERT(batch_idx < BatchSize());
    return mem_.get() + batch_idx * stride_;
  }
  const T* Batch(size_t batch_idx) const {
    HWY_DASSERT(batch_idx < BatchSize());
    return mem_.get() + batch_idx * stride_;
  }

  // For MatMul or other operations that process the entire batch at once.
  // TODO: remove once we only use Mat.
  T* All() { return mem_.get(); }
  const T* Const() const { return mem_.get(); }
  size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }

 private:
  AlignedPtr2<T[]> mem_;
  Extents2D extents_;
  size_t stride_;
};

template <typename T>
RowPtr<T> RowPtrFromBatch(const Allocator2& allocator,
                          RowVectorBatch<T>& row_vectors) {
  return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
                   row_vectors.Stride());
}

template <typename T>
RowVectorBatch<T> AllocateAlignedRows(const Allocator2& allocator,
                                      Extents2D extents) {
  return RowVectorBatch<T>(
      allocator, extents,
      StrideForCyclicOffsets(extents.cols,
                             allocator.QuantumBytes() / sizeof(T)));
}
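
// Example (a sketch; the extents, element type and use of the global context
// are assumptions): allocate a padded batch and view it as rows for MatMul.
//   const Allocator2& allocator = ThreadingContext2::Get().allocator;
//   RowVectorBatch<float> batch =
//       AllocateAlignedRows<float>(allocator, Extents2D(4, 512));
//   RowPtrF rows = RowPtrFromBatch(allocator, batch);
//   rows.Row(0)[0] = 1.0f;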

}  // namespace gcpp
#endif  // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
@@ -13,12 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "util/threading.h"
#include "util/threading.h"  // NOT threading_context..
// to ensure there is no deadlock.

#include <stdio.h>

#include <algorithm>  // std::sort
#include <atomic>
#include <memory>
#include <optional>
#include <vector>

@@ -69,7 +71,7 @@ class Pinning {
const int bytes_written =
snprintf(buf, sizeof(buf), "P%zu X%02zu C%03d", pkg_idx, cluster_idx,
static_cast<int>(task));
HWY_ASSERT(bytes_written < sizeof(buf));
HWY_ASSERT(bytes_written < static_cast<int>(sizeof(buf)));
hwy::SetThreadName(buf, 0);  // does not support varargs

if (HWY_LIKELY(want_pin_)) {

@@ -107,16 +109,16 @@ static Pinning& GetPinning() {
return pinning;
}

static PoolPtr MakePool(size_t num_workers,
static PoolPtr MakePool(const Allocator2& allocator, size_t num_workers,
std::optional<size_t> node = std::nullopt) {
// `ThreadPool` expects the number of threads to create, which is one less
// than the number of workers, but avoid underflow if zero.
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
PoolPtr ptr = Allocator::AllocClasses<hwy::ThreadPool>(1, num_threads);
PoolPtr ptr = allocator.AllocClasses<hwy::ThreadPool>(1, num_threads);
const size_t bytes =
hwy::RoundUpTo(sizeof(hwy::ThreadPool), Allocator::QuantumBytes());
if (node.has_value() && Allocator::ShouldBind()) {
Allocator::BindMemory(ptr.get(), bytes, node.value());
hwy::RoundUpTo(sizeof(hwy::ThreadPool), allocator.QuantumBytes());
if (node.has_value() && allocator.ShouldBind()) {
allocator.BindMemory(ptr.get(), bytes, node.value());
}
return ptr;
}

@@ -133,21 +135,21 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) {
return max;
}

NestedPools::NestedPools(const BoundedTopology& topology, size_t max_threads,
NestedPools::NestedPools(const BoundedTopology& topology,
const Allocator2& allocator, size_t max_threads,
Tristate pin) {
GetPinning().SetPolicy(pin);
packages_.resize(topology.NumPackages());
all_packages_ = MakePool(packages_.size());
all_packages_ = MakePool(allocator, packages_.size());
const size_t max_workers_per_package =
DivideMaxAcross(max_threads, packages_.size());
// Each worker in all_packages_, including the main thread, will be the
// calling thread of an all_clusters[0].Run, and hence pinned to one of the
// calling thread of an all_clusters->Run, and hence pinned to one of the
// `cluster.lps` if `pin`.
all_packages_[0].Run(
0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) {
all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) {
HWY_ASSERT(pkg_idx == thread);  // each thread has one task
packages_[pkg_idx] =
Package(topology, pkg_idx, max_workers_per_package);
Package(topology, allocator, pkg_idx, max_workers_per_package);
});

all_pinned_ = GetPinning().AllPinned(&pin_string_);

@@ -172,28 +174,29 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
}

NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx,
NestedPools::Package::Package(const BoundedTopology& topology,
const Allocator2& allocator, size_t pkg_idx,
size_t max_workers_per_package) {
// Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(pkg_idx));
const size_t max_workers_per_cluster =
DivideMaxAcross(max_workers_per_package, clusters_.size());

all_clusters_ =
MakePool(clusters_.size(), topology.GetCluster(pkg_idx, 0).Node());
all_clusters_ = MakePool(allocator, clusters_.size(),
topology.GetCluster(pkg_idx, 0).Node());
// Parallel so we also pin the calling worker in `all_clusters` to
// `cluster.lps`.
all_clusters_[0].Run(
0, all_clusters_[0].NumWorkers(), [&](size_t cluster_idx, size_t thread) {
all_clusters_->Run(
0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread);  // each thread has one task
const BoundedTopology::Cluster& cluster =
topology.GetCluster(pkg_idx, cluster_idx);
clusters_[cluster_idx] =
MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster),
clusters_[cluster_idx] = MakePool(
allocator, CapIfNonZero(cluster.Size(), max_workers_per_cluster),
cluster.Node());
// Pin workers AND the calling thread from `all_clusters`.
GetPinning().MaybePin(pkg_idx, cluster_idx, cluster,
clusters_[cluster_idx][0]);
*clusters_[cluster_idx]);
});
}

@@ -23,6 +23,7 @@

// IWYU pragma: begin_exports
#include "util/allocator.h"
#include "util/args.h"
#include "util/basics.h"  // Tristate
#include "util/topology.h"
#include "hwy/base.h"  // HWY_ASSERT

@@ -37,7 +38,7 @@ namespace gcpp {

// Page-aligned on NUMA systems so we can bind to a NUMA node. This also allows
// moving because it is a typedef to `std::unique_ptr`.
using PoolPtr = AlignedClassPtr<hwy::ThreadPool>;
using PoolPtr = AlignedClassPtr2<hwy::ThreadPool>;

// Creates a hierarchy of thread pools according to `BoundedTopology`: one with
// a thread per enabled package; for each of those, one with a thread per

@@ -73,10 +74,8 @@ class NestedPools {
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments
// only impose upper bounds on the number of detected packages and clusters
// rather than defining the actual number of threads.
//
// Caller must have called `Allocator::Init` before this.
NestedPools(const BoundedTopology& topology, size_t max_threads = 0,
Tristate pin = Tristate::kDefault);
NestedPools(const BoundedTopology& topology, const Allocator2& allocator,
size_t max_threads = 0, Tristate pin = Tristate::kDefault);

bool AllPinned() const { return all_pinned_; }

@@ -103,7 +102,7 @@ class NestedPools {
}

size_t NumPackages() const { return packages_.size(); }
hwy::ThreadPool& AllPackages() { return all_packages_[0]; }
hwy::ThreadPool& AllPackages() { return *all_packages_; }
hwy::ThreadPool& AllClusters(size_t pkg_idx) {
HWY_DASSERT(pkg_idx < NumPackages());
return packages_[pkg_idx].AllClusters();

@@ -149,36 +148,36 @@ class NestedPools {
class Package {
public:
Package() = default;  // for vector
Package(const BoundedTopology& topology, size_t pkg_idx,
size_t max_workers_per_package);
Package(const BoundedTopology& topology, const Allocator2& allocator,
size_t pkg_idx, size_t max_workers_per_package);

size_t NumClusters() const { return clusters_.size(); }
size_t MaxWorkersPerCluster() const {
size_t max_workers_per_cluster = 0;
for (const PoolPtr& cluster : clusters_) {
max_workers_per_cluster =
HWY_MAX(max_workers_per_cluster, cluster[0].NumWorkers());
HWY_MAX(max_workers_per_cluster, cluster->NumWorkers());
}
return max_workers_per_cluster;
}
size_t TotalWorkers() const {
size_t total_workers = 0;
for (const PoolPtr& cluster : clusters_) {
total_workers += cluster[0].NumWorkers();
total_workers += cluster->NumWorkers();
}
return total_workers;
}

hwy::ThreadPool& AllClusters() { return all_clusters_[0]; }
hwy::ThreadPool& AllClusters() { return *all_clusters_; }
hwy::ThreadPool& Cluster(size_t cluster_idx) {
HWY_DASSERT(cluster_idx < clusters_.size());
return clusters_[cluster_idx][0];
return *clusters_[cluster_idx];
}

void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_clusters_[0].SetWaitMode(wait_mode);
all_clusters_->SetWaitMode(wait_mode);
for (PoolPtr& cluster : clusters_) {
cluster[0].SetWaitMode(wait_mode);
cluster->SetWaitMode(wait_mode);
}
}

@@ -188,7 +187,7 @@ class NestedPools {
};  // Package

void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_packages_[0].SetWaitMode(wait_mode);
all_packages_->SetWaitMode(wait_mode);
for (Package& package : packages_) {
package.SetWaitMode(wait_mode);
}

@@ -33,6 +33,13 @@ static std::mutex s_ctx_mutex;
s_ctx_mutex.unlock();
}

/*static*/ bool ThreadingContext2::IsInitialized() {
s_ctx_mutex.lock();
const bool initialized = !!s_ctx;
s_ctx_mutex.unlock();
return initialized;
}

/*static*/ ThreadingContext2& ThreadingContext2::Get() {
// We do not bother with double-checked locking because it requires an
// atomic pointer, but we prefer to use unique_ptr for simplicity. Also,

@@ -98,6 +98,10 @@ class ThreadingContext2 {
// is expected to be called early in the program, before threading starts.
static void SetArgs(const ThreadingArgs& args);

// Returns whether `Get()` has already been called, typically used to avoid
// calling `SetArgs` after that, because it would assert.
static bool IsInitialized();

// Returns a reference to the singleton after initializing it if necessary.
// When initializing, uses the args passed to `SetArgs`, or defaults.
//

@@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#include "util/threading.h"

#include <stddef.h>
#include <stdio.h>

@@ -22,9 +20,9 @@

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "hwy/aligned_allocator.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h"  // Span
#include "hwy/auto_tune.h"
#include "hwy/base.h"  // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"

@@ -385,9 +383,7 @@ TEST(ThreadingTest, BenchJoin) {
}
};

BoundedTopology topology;
Allocator::Init(topology, true);
NestedPools pools(topology);
NestedPools& pools = ThreadingContext2::Get().pools;
// Use last package because the main thread has been pinned to it.
const size_t pkg_idx = pools.NumPackages() - 1;