diff --git a/BUILD.bazel b/BUILD.bazel index f4aed0d..a5f01e7 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -19,7 +19,10 @@ license( # Dual-licensed Apache 2 and 3-clause BSD. licenses(["notice"]) -exports_files(["LICENSE"]) +exports_files([ + "LICENSE", + ".github/workflows/build.yml", +]) cc_library( name = "basics", @@ -29,6 +32,16 @@ cc_library( ], ) +cc_library( + name = "args", + hdrs = ["util/args.h"], + deps = [ + ":basics", + "//compression:io", # Path + "@highway//:hwy", + ], +) + # Split from :threading to break a circular dependency with :allocator. cc_library( name = "topology", @@ -59,6 +72,7 @@ cc_library( hdrs = ["util/threading.h"], deps = [ ":allocator", + ":args", ":basics", ":topology", # Placeholder for container detection, do not remove @@ -68,14 +82,26 @@ cc_library( ], ) +cc_library( + name = "threading_context", + srcs = ["util/threading_context.cc"], + hdrs = ["util/threading_context.h"], + deps = [ + ":allocator", + ":args", + ":basics", + ":threading", + ":topology", + ], +) + cc_test( name = "threading_test", srcs = ["util/threading_test.cc"], deps = [ - ":allocator", ":basics", - ":threading", - "@googletest//:gtest_main", + ":threading_context", + "@googletest//:gtest_main", # buildcleaner: keep "@highway//:auto_tune", "@highway//:hwy", "@highway//:hwy_test_util", @@ -97,6 +123,65 @@ cc_library( ], ) +cc_library( + name = "common", + srcs = [ + "gemma/common.cc", + "gemma/configs.cc", + "gemma/tensor_index.cc", + ], + hdrs = [ + "gemma/common.h", + "gemma/configs.h", + "gemma/tensor_index.h", + ], + deps = [ + ":basics", + "//compression:fields", + "//compression:sfp", + "@highway//:hwy", # base.h + ], +) + +cc_test( + name = "configs_test", + srcs = ["gemma/configs_test.cc"], + deps = [ + ":common", + "@googletest//:gtest_main", # buildcleaner: keep + "@highway//:hwy", + ], +) + +cc_test( + name = "tensor_index_test", + srcs = ["gemma/tensor_index_test.cc"], + deps = [ + ":basics", + ":common", + ":weights", + "@googletest//:gtest_main", # buildcleaner: keep + "//compression:compress", + "@highway//:hwy", + ], +) + +cc_library( + name = "mat", + srcs = ["util/mat.cc"], + hdrs = ["util/mat.h"], + deps = [ + ":allocator", + ":basics", + ":common", + ":threading_context", + "//compression:fields", + "//compression:sfp", + "@highway//:hwy", + "@highway//:profiler", + ], +) + # For building all tests in one command, so we can test several. test_suite( name = "ops_tests", @@ -123,8 +208,9 @@ cc_library( deps = [ ":allocator", ":basics", + ":mat", ":threading", - ":topology", + ":threading_context", "//compression:compress", "@highway//:algo", "@highway//:bit_set", @@ -148,10 +234,9 @@ cc_test( tags = ["ops_tests"], deps = [ ":allocator", - ":app", ":ops", ":test_util", - ":threading", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", "//compression:test_util", @@ -174,13 +259,13 @@ cc_test( tags = ["ops_tests"], deps = [ ":allocator", - ":app", + ":basics", ":common", + ":mat", ":ops", ":test_util", - ":threading", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep - "//compression:compress", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", #buildcleaner: keep @@ -196,6 +281,7 @@ cc_test( # for test_suite. tags = ["ops_tests"], deps = [ + ":mat", ":ops", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", @@ -214,12 +300,13 @@ cc_test( # for test_suite. 
tags = ["ops_tests"], deps = [ - ":allocator", ":basics", + ":mat", ":ops", - ":threading", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", + "//compression:test_util", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:thread_pool", @@ -238,12 +325,12 @@ cc_test( "ops_tests", # for test_suite. ], deps = [ - ":allocator", ":basics", ":ops", - ":threading", + ":threading_context", "@googletest//:gtest_main", # buildcleaner: keep "//compression:compress", + "//compression:test_util", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", @@ -252,55 +339,13 @@ cc_test( ], ) -cc_library( - name = "common", - srcs = [ - "gemma/common.cc", - "gemma/configs.cc", - "gemma/tensor_index.cc", - ], - hdrs = [ - "gemma/common.h", - "gemma/configs.h", - "gemma/tensor_index.h", - ], - deps = [ - ":basics", - "//compression:fields", - "//compression:sfp", - "@highway//:hwy", # base.h - ], -) - -cc_test( - name = "configs_test", - srcs = ["gemma/configs_test.cc"], - deps = [ - ":common", - "@googletest//:gtest_main", - "@highway//:hwy", - ], -) - -cc_test( - name = "tensor_index_test", - srcs = ["gemma/tensor_index_test.cc"], - deps = [ - ":basics", - ":common", - ":weights", - "@googletest//:gtest_main", - "//compression:compress", - "@highway//:hwy", - ], -) - cc_library( name = "weights", srcs = ["gemma/weights.cc"], hdrs = ["gemma/weights.h"], deps = [ ":common", + ":mat", "//compression:blob_store", "//compression:compress", "//compression:io", @@ -361,16 +406,17 @@ cc_library( ":basics", ":common", ":ops", + ":mat", ":tokenizer", ":kv_cache", ":weights", ":threading", - "//compression:compress", + ":threading_context", + # Placeholder for internal dep, do not remove., "//compression:io", "//compression:sfp", "//paligemma:image", "@highway//:hwy", - "@highway//:bit_set", "@highway//:nanobenchmark", # timer "@highway//:profiler", "@highway//:thread_pool", @@ -390,25 +436,14 @@ cc_library( ) cc_library( - name = "args", - hdrs = ["util/args.h"], - deps = [ - ":basics", - "//compression:io", - "@highway//:hwy", - ], -) - -cc_library( - name = "app", - hdrs = ["util/app.h"], + name = "gemma_args", + hdrs = ["gemma/gemma_args.h"], deps = [ ":args", ":basics", ":common", ":gemma_lib", ":ops", - ":threading", "//compression:io", "//compression:sfp", "@highway//:hwy", @@ -420,20 +455,15 @@ cc_library( srcs = ["evals/benchmark_helper.cc"], hdrs = ["evals/benchmark_helper.h"], deps = [ - ":app", - ":args", - ":common", ":cross_entropy", + ":gemma_args", ":gemma_lib", - ":kv_cache", ":ops", - ":threading", - # Placeholder for internal dep, do not remove., + ":threading_context", "@google_benchmark//:benchmark", "//compression:compress", "@highway//:hwy", "@highway//:nanobenchmark", - "@highway//:topology", ], ) @@ -451,7 +481,7 @@ cc_test( ":benchmark_helper", ":common", ":gemma_lib", - "@googletest//:gtest_main", + "@googletest//:gtest_main", # buildcleaner: keep "@highway//:hwy", "@highway//:hwy_test_util", ], @@ -470,8 +500,7 @@ cc_test( ":benchmark_helper", ":common", ":gemma_lib", - ":tokenizer", - "@googletest//:gtest_main", + "@googletest//:gtest_main", # buildcleaner: keep "@highway//:hwy", "@highway//:hwy_test_util", ], @@ -481,14 +510,13 @@ cc_binary( name = "gemma", srcs = ["gemma/run.cc"], deps = [ - ":app", ":args", ":benchmark_helper", ":common", + ":gemma_args", ":gemma_lib", ":ops", - ":threading", - # Placeholder for internal dep, do not remove., + ":threading_context", "//compression:sfp", "//paligemma:image", 
"@highway//:hwy", @@ -594,10 +622,10 @@ cc_library( deps = [ ":allocator", ":common", + ":mat", ":ops", ":prompt", ":weights", - "//compression:compress", "@highway//:dot", "@highway//:hwy", # base.h "@highway//:thread_pool", @@ -614,9 +642,9 @@ cc_library( ], deps = [ ":common", + ":mat", ":prompt", ":weights", - "//compression:compress", "@highway//:hwy", ], ) @@ -631,11 +659,11 @@ cc_test( deps = [ ":backprop_scalar", ":common", + ":mat", ":prompt", ":sampler", ":weights", - "@googletest//:gtest_main", - "//compression:compress", + "@googletest//:gtest_main", # buildcleaner: keep "@highway//:thread_pool", ], ) @@ -652,17 +680,16 @@ cc_test( "mem": "28g", }, deps = [ - ":allocator", ":backprop", ":backprop_scalar", ":common", + ":mat", ":ops", ":prompt", ":sampler", - ":threading", + ":threading_context", ":weights", - "@googletest//:gtest_main", - "//compression:compress", + "@googletest//:gtest_main", # buildcleaner: keep "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:thread_pool", @@ -676,6 +703,7 @@ cc_library( deps = [ ":allocator", ":common", + ":mat", ":weights", "//compression:compress", "@highway//:hwy", @@ -685,9 +713,7 @@ cc_library( cc_test( name = "optimize_test", - srcs = [ - "backprop/optimize_test.cc", - ], + srcs = ["backprop/optimize_test.cc"], exec_properties = { # Avoid linker OOMs when building with sanitizer instrumentation. "mem": "28g", @@ -704,7 +730,7 @@ cc_test( ":sampler", ":threading", ":weights", - "@googletest//:gtest_main", + "@googletest//:gtest_main", # buildcleaner: keep "//compression:sfp", "@highway//:thread_pool", ], diff --git a/CMakeLists.txt b/CMakeLists.txt index 1737c2d..b572835 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -75,6 +75,7 @@ set(SOURCES gemma/common.h gemma/configs.cc gemma/configs.h + gemma/gemma_args.h gemma/gemma-inl.h gemma/gemma.cc gemma/gemma.h @@ -102,15 +103,17 @@ set(SOURCES paligemma/image.h util/allocator.cc util/allocator.h - util/app.h - util/args.h util/basics.h + util/mat.cc + util/mat.h util/test_util.h util/threading.cc util/threading.h + util/threading_context.cc + util/threading_context.h util/topology.cc util/topology.h - ) +) if(NOT CMAKE_BUILD_TYPE) set(CMAKE_BUILD_TYPE "Release") @@ -197,8 +200,5 @@ endif() # GEMMA_ENABLE_TESTS ## Tools -add_executable(compress_weights compression/compress_weights.cc) -target_link_libraries(compress_weights libgemma hwy hwy_contrib) - add_executable(migrate_weights compression/migrate_weights.cc) target_link_libraries(migrate_weights libgemma hwy hwy_contrib) diff --git a/backprop/activations.h b/backprop/activations.h index c616759..d0446cd 100644 --- a/backprop/activations.h +++ b/backprop/activations.h @@ -20,24 +20,30 @@ #include -#include "compression/compress.h" // MatStorageT -#include "gemma/configs.h" // ModelConfig +#include "gemma/configs.h" // ModelConfig +#include "util/mat.h" // MatStorageT namespace gcpp { template struct ForwardLayer { ForwardLayer(const LayerConfig& config, size_t seq_len) - : input("input", seq_len, config.model_dim), - pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim), - qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim), - att("att", seq_len * config.heads, seq_len), - att_out("att_out", seq_len * config.heads, config.qkv_dim), - att_post1("att_post1", seq_len, config.model_dim), - attention_out("attention_out", seq_len, config.model_dim), - bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim), - ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2), - 
ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim), + : input(MakePacked("input", seq_len, config.model_dim)), + pre_att_rms_out( + MakePacked("pre_att_rms_out", seq_len, config.model_dim)), + qkv(MakePacked("qkv", seq_len * (config.heads + 2), config.qkv_dim)), + att(MakePacked("att", seq_len * config.heads, seq_len)), + att_out( + MakePacked("att_out", seq_len * config.heads, config.qkv_dim)), + att_post1(MakePacked("att_post1", seq_len, config.model_dim)), + attention_out( + MakePacked("attention_out", seq_len, config.model_dim)), + bf_pre_ffw_rms_out( + MakePacked("bf_preFF_rms_out", seq_len, config.model_dim)), + ffw_hidden( + MakePacked("ffw_hidden", seq_len, config.ff_hidden_dim * 2)), + ffw_hidden_gated( + MakePacked("ffw_hidden_gated", seq_len, config.ff_hidden_dim)), layer_config(config) {} MatStorageT input; @@ -56,12 +62,12 @@ struct ForwardLayer { template struct ForwardPass { ForwardPass(const ModelConfig& config) - : final_layer_output("final_layer_output", config.seq_len, - config.model_dim), - final_norm_output("final_norm_output", config.seq_len, - config.model_dim), - logits("logits", config.seq_len, config.vocab_size), - probs("probs", config.seq_len, config.vocab_size), + : final_layer_output( + MakePacked("fin_layer_out", config.seq_len, config.model_dim)), + final_norm_output( + MakePacked("fin_norm_out", config.seq_len, config.model_dim)), + logits(MakePacked("logits", config.seq_len, config.vocab_size)), + probs(MakePacked("probs", config.seq_len, config.vocab_size)), weights_config(config) { for (const auto& layer_config : config.layer_configs) { layers.emplace_back(layer_config, config.seq_len); diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 2a0f330..9716d87 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -128,7 +128,7 @@ static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward, HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); }); } -static HWY_NOINLINE void RMSNormVJP( +static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormVJP( const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x, const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens, float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x, @@ -153,10 +153,9 @@ static HWY_NOINLINE void RMSNormVJP( } } -static HWY_NOINLINE void InputEmbeddingVJP( - const float* weights, const std::vector& prompt, - const float scaling, const float* HWY_RESTRICT v, - float* HWY_RESTRICT grad, size_t model_dim) { +static HWY_NOINLINE HWY_MAYBE_UNUSED void InputEmbeddingVJP( + const float* weights, const std::vector& prompt, const float scaling, + const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) { HWY_ASSERT(!prompt.empty()); for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { int token = prompt[pos]; @@ -182,17 +181,18 @@ void LayerVJP(const LayerWeightsPtrs& weights, static_cast(1.0 / sqrt(static_cast(qkv_dim))); HWY_ASSERT(num_tokens <= seq_len); - MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(), + MatMulVJP(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), next_layer_grad, ff_hidden_dim, model_dim, num_tokens, - grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool); + grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t hidden_offset = pos * ff_hidden_dim * 2; - const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset; + const float* HWY_RESTRICT f_out = + forward.ffw_hidden.Packed() 
+ hidden_offset; const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim; const float* HWY_RESTRICT b_out_gated = - backward.ffw_hidden_gated.data() + pos * ff_hidden_dim; - float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset; + backward.ffw_hidden_gated.Packed() + pos * ff_hidden_dim; + float* HWY_RESTRICT b_out = backward.ffw_hidden.Packed() + hidden_offset; float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim; namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; @@ -206,38 +206,39 @@ void LayerVJP(const LayerWeightsPtrs& weights, } } - MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(), - backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2, - num_tokens, grad.gating_einsum_w.data(), - backward.bf_pre_ffw_rms_out.data(), pool); - RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(), - backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens, - grad.pre_ffw_norm_scale.data(), backward.attention_out.data(), - pool); + MatMulVJP(weights.gating_einsum_w.Packed(), + forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(), + model_dim, ff_hidden_dim * 2, num_tokens, + grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(), + pool); + RMSNormVJP( + weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(), + backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens, + grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { AddFrom(next_layer_grad + pos * model_dim, - backward.attention_out.data() + pos * model_dim, model_dim); + backward.attention_out.Packed() + pos * model_dim, model_dim); } - backward.qkv.ZeroInit(); + ZeroInit(backward.qkv); - MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(), - backward.attention_out.data(), heads, qkv_dim, model_dim, - num_tokens, grad.attn_vec_einsum_w.data(), - backward.att_out.data(), pool); + MultiHeadMatMulVJP( + weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(), + backward.attention_out.Packed(), heads, qkv_dim, model_dim, num_tokens, + grad.attn_vec_einsum_w.Packed(), backward.att_out.Packed(), pool); for (size_t head = 0; head < heads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t aoffset = head * seq_len + pos * heads * seq_len; - const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset; + const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset; const float* HWY_RESTRICT b_att_out = - backward.att_out.data() + (pos * heads + head) * qkv_dim; - float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset; + backward.att_out.Packed() + (pos * heads + head) * qkv_dim; + float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim; - const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs; - float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs; + const float* HWY_RESTRICT f_v2 = forward.qkv.Packed() + v2offs; + float* HWY_RESTRICT b_v2 = backward.qkv.Packed() + v2offs; b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim); MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim); } @@ -247,8 +248,8 @@ void LayerVJP(const LayerWeightsPtrs& weights, for (size_t head = 0; head < heads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t aoffset = head * seq_len + pos * heads * seq_len; - const float* HWY_RESTRICT f_head_att = forward.att.data() 
+ aoffset; - float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset; + const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset; + float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset; SoftmaxVJP(f_head_att, b_head_att, pos + 1); } } @@ -257,13 +258,13 @@ void LayerVJP(const LayerWeightsPtrs& weights, for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim; const size_t aoffs = head * seq_len + pos * heads * seq_len; - const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs; - const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs; - float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs; + const float* HWY_RESTRICT f_q = forward.qkv.Packed() + qoffs; + const float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffs; + float* HWY_RESTRICT b_q = backward.qkv.Packed() + qoffs; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim; - const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs; - float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs; + const float* HWY_RESTRICT f_k2 = forward.qkv.Packed() + k2offs; + float* HWY_RESTRICT b_k2 = backward.qkv.Packed() + k2offs; MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim); MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim); } @@ -272,28 +273,30 @@ void LayerVJP(const LayerWeightsPtrs& weights, for (int pos = 0; pos < static_cast(num_tokens); ++pos) { float* HWY_RESTRICT b_kv = - backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim; + backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim; Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos); } for (size_t head = 0; head < heads; ++head) { for (size_t pos = 0; pos < num_tokens; ++pos) { float* HWY_RESTRICT b_q = - backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim; + backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim; MulByConst(query_scale, b_q, qkv_dim); Rope(b_q, qkv_dim, inv_timescale.Const(), -pos); } } - MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(), - backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens, - grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool); - RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(), - backward.pre_att_rms_out.data(), model_dim, num_tokens, - grad.pre_attention_norm_scale.data(), backward.input.data(), pool); + MatMulVJP(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(), + backward.qkv.Packed(), model_dim, (heads + 2) * qkv_dim, num_tokens, + grad.qkv_einsum_w.Packed(), backward.pre_att_rms_out.Packed(), + pool); + RMSNormVJP(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(), + backward.pre_att_rms_out.Packed(), model_dim, num_tokens, + grad.pre_attention_norm_scale.Packed(), backward.input.Packed(), + pool); for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(backward.attention_out.data() + pos * model_dim, - backward.input.data() + pos * model_dim, model_dim); + AddFrom(backward.attention_out.Packed() + pos * model_dim, + backward.input.Packed() + pos * model_dim, model_dim); } } @@ -353,47 +356,48 @@ void CrossEntropyLossBackwardPassInl(const Prompt& prompt, HWY_DASSERT(prompt.context_size < prompt.tokens.size()); const size_t num_tokens = prompt.tokens.size() - 1; - CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt, + CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt, kVocabSize); for (size_t pos = 0; pos < 
num_tokens; ++pos) { - SoftmaxVJP(forward.probs.data() + pos * kVocabSize, - backward.logits.data() + pos * kVocabSize, - kVocabSize); + SoftmaxVJP(forward.probs.Packed() + pos * kVocabSize, + backward.logits.Packed() + pos * kVocabSize, kVocabSize); } if (config.final_cap > 0.0f) { for (size_t pos = 0; pos < num_tokens; ++pos) { - SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize, - backward.logits.data() + pos * kVocabSize, kVocabSize); + SoftcapVJP(config.final_cap, forward.logits.Packed() + pos * kVocabSize, + backward.logits.Packed() + pos * kVocabSize, kVocabSize); } } - MatMulVJP(weights.embedder_input_embedding.data(), - forward.final_norm_output.data(), backward.logits.data(), model_dim, - kVocabSize, num_tokens, grad.embedder_input_embedding.data(), - backward.final_norm_output.data(), pool); + MatMulVJP(weights.embedder_input_embedding.Packed(), + forward.final_norm_output.Packed(), backward.logits.Packed(), + model_dim, kVocabSize, num_tokens, + grad.embedder_input_embedding.Packed(), + backward.final_norm_output.Packed(), pool); - RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(), - backward.final_norm_output.data(), model_dim, num_tokens, - grad.final_norm_scale.data(), backward.final_layer_output.data(), - pool); + RMSNormVJP(weights.final_norm_scale.Packed(), + forward.final_layer_output.Packed(), + backward.final_norm_output.Packed(), model_dim, num_tokens, + grad.final_norm_scale.Packed(), + backward.final_layer_output.Packed(), pool); for (int layer = static_cast(kLayers) - 1; layer >= 0; --layer) { auto layer_config = config.layer_configs[layer]; // TODO(szabadka) Implement Griffin layer vjp. HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma); float* next_layer_grad = layer + 1 < kLayers - ? backward.layers[layer + 1].input.data() - : backward.final_layer_output.data(); + ? 
backward.layers[layer + 1].input.Packed() + : backward.final_layer_output.Packed(); LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, num_tokens, *grad.GetLayer(layer), backward.layers[layer], inv_timescale, pool); } - InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens, - kEmbScaling, backward.layers[0].input.data(), - grad.embedder_input_embedding.data(), model_dim); + InputEmbeddingVJP(weights.embedder_input_embedding.Packed(), prompt.tokens, + kEmbScaling, backward.layers[0].input.Packed(), + grad.embedder_input_embedding.Packed(), model_dim); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/backprop/backward.cc b/backprop/backward.cc index 868b391..d89da45 100644 --- a/backprop/backward.cc +++ b/backprop/backward.cc @@ -17,9 +17,8 @@ #include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/common.h" #include "gemma/weights.h" -#include "util/allocator.h" +#include "util/mat.h" #include "hwy/contrib/thread_pool/thread_pool.h" // Compiles this file for multiple architectures via "foreach_target.h", to diff --git a/backprop/backward.h b/backprop/backward.h index d8e50c7..5a08f5c 100644 --- a/backprop/backward.h +++ b/backprop/backward.h @@ -19,7 +19,7 @@ #include "backprop/activations.h" #include "backprop/prompt.h" #include "gemma/weights.h" -#include "util/allocator.h" +#include "util/mat.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { diff --git a/backprop/backward_scalar.h b/backprop/backward_scalar.h index b0a37b3..20b43ed 100644 --- a/backprop/backward_scalar.h +++ b/backprop/backward_scalar.h @@ -211,62 +211,65 @@ void LayerVJP(const LayerWeightsPtrs& weights, const size_t kFFHiddenDim = layer_config.ff_hidden_dim; const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim)); - MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy, - grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim, - kFFHiddenDim, num_tokens); + MatMulVJPT(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), dy, + grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), + model_dim, kFFHiddenDim, num_tokens); - GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(), - backward.ffw_hidden.data(), kFFHiddenDim, num_tokens); + GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(), + backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens); - MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(), - backward.ffw_hidden.data(), grad.gating_einsum_w.data(), - backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim, + MatMulVJPT(weights.gating_einsum_w.Packed(), + forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(), + grad.gating_einsum_w.Packed(), + backward.bf_pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim, num_tokens); - RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(), - backward.bf_pre_ffw_rms_out.data(), - grad.pre_ffw_norm_scale.data(), backward.attention_out.data(), - model_dim, num_tokens); + RMSNormVJPT( + weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(), + backward.bf_pre_ffw_rms_out.Packed(), grad.pre_ffw_norm_scale.Packed(), + backward.attention_out.Packed(), model_dim, num_tokens); - AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim); + AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim); - MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(), - 
backward.attention_out.data(), - grad.attn_vec_einsum_w.data(), backward.att_out.data(), - kHeads, model_dim, qkv_dim, num_tokens); + MultiHeadMatMulVJPT( + weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(), + backward.attention_out.Packed(), grad.attn_vec_einsum_w.Packed(), + backward.att_out.Packed(), kHeads, model_dim, qkv_dim, num_tokens); - MixByAttentionVJP(forward.qkv.data(), forward.att.data(), - backward.att_out.data(), backward.qkv.data(), - backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len); - - MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads, + MixByAttentionVJP(forward.qkv.Packed(), forward.att.Packed(), + backward.att_out.Packed(), backward.qkv.Packed(), + backward.att.Packed(), num_tokens, kHeads, qkv_dim, seq_len); - MaskedAttentionVJP(forward.qkv.data(), backward.att.data(), - backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len); + MaskedSoftmaxVJPT(forward.att.Packed(), backward.att.Packed(), num_tokens, + kHeads, seq_len); + + MaskedAttentionVJP(forward.qkv.Packed(), backward.att.Packed(), + backward.qkv.Packed(), num_tokens, kHeads, qkv_dim, + seq_len); for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim; + T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim; MulByConstT(kQueryScale, qkv, kHeads * qkv_dim); } for (int pos = 0; pos < num_tokens; ++pos) { - T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim; + T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim; for (size_t h = 0; h <= kHeads; ++h) { Rope(qkv + h * qkv_dim, qkv_dim, -pos); } } - MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(), - backward.qkv.data(), grad.qkv_einsum_w.data(), - backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim, - num_tokens); - RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(), - backward.pre_att_rms_out.data(), - grad.pre_attention_norm_scale.data(), backward.input.data(), + MatMulVJPT(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(), + backward.qkv.Packed(), grad.qkv_einsum_w.Packed(), + backward.pre_att_rms_out.Packed(), (kHeads + 2) * qkv_dim, + model_dim, num_tokens); + RMSNormVJPT(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(), + backward.pre_att_rms_out.Packed(), + grad.pre_attention_norm_scale.Packed(), backward.input.Packed(), model_dim, num_tokens); - AddFromT(backward.attention_out.data(), backward.input.data(), + AddFromT(backward.attention_out.Packed(), backward.input.Packed(), num_tokens * model_dim); } @@ -307,41 +310,42 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt, const std::vector tokens = prompt.tokens; const size_t num_tokens = tokens.empty() ? 
0 : tokens.size() - 1; - CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt, + CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt, vocab_size); - SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size, + SoftmaxVJPT(forward.probs.Packed(), backward.logits.Packed(), vocab_size, num_tokens); if (config.final_cap > 0.0f) { for (size_t i = 0; i < num_tokens; ++i) { - SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size, - backward.logits.data() + i * vocab_size, vocab_size); + SoftcapVJPT(config.final_cap, forward.logits.Packed() + i * vocab_size, + backward.logits.Packed() + i * vocab_size, vocab_size); } } - MatMulVJPT( - weights.embedder_input_embedding.data(), forward.final_norm_output.data(), - backward.logits.data(), grad.embedder_input_embedding.data(), - backward.final_norm_output.data(), vocab_size, model_dim, num_tokens); + MatMulVJPT(weights.embedder_input_embedding.Packed(), + forward.final_norm_output.Packed(), backward.logits.Packed(), + grad.embedder_input_embedding.Packed(), + backward.final_norm_output.Packed(), vocab_size, model_dim, + num_tokens); - RMSNormVJPT(weights.final_norm_scale.data(), - forward.final_layer_output.data(), - backward.final_norm_output.data(), grad.final_norm_scale.data(), - backward.final_layer_output.data(), model_dim, num_tokens); + RMSNormVJPT( + weights.final_norm_scale.Packed(), forward.final_layer_output.Packed(), + backward.final_norm_output.Packed(), grad.final_norm_scale.Packed(), + backward.final_layer_output.Packed(), model_dim, num_tokens); for (int layer = static_cast(layers) - 1; layer >= 0; --layer) { T* next_layer_grad = layer + 1 < layers - ? backward.layers[layer + 1].input.data() - : backward.final_layer_output.data(); + ? 
backward.layers[layer + 1].input.Packed() + : backward.final_layer_output.Packed(); LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad, *grad.GetLayer(layer), backward.layers[layer], num_tokens); } const T kEmbScaling = EmbeddingScaling(model_dim); - InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens, - kEmbScaling, backward.layers[0].input.data(), - grad.embedder_input_embedding.data(), model_dim); + InputEmbeddingVJPT(weights.embedder_input_embedding.Packed(), tokens, + kEmbScaling, backward.layers[0].input.Packed(), + grad.embedder_input_embedding.Packed(), model_dim); } } // namespace gcpp diff --git a/backprop/backward_scalar_test.cc b/backprop/backward_scalar_test.cc index e40f3ed..45d4d18 100644 --- a/backprop/backward_scalar_test.cc +++ b/backprop/backward_scalar_test.cc @@ -31,9 +31,9 @@ #include "backprop/prompt.h" #include "backprop/sampler.h" #include "backprop/test_util.h" -#include "compression/compress.h" #include "gemma/configs.h" #include "gemma/weights.h" +#include "util/mat.h" namespace gcpp { @@ -44,14 +44,14 @@ TEST(BackPropTest, MatMulVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT weights("weights", kRows, kCols); - MatStorageT x("x", kTokens, kCols); - MatStorageT grad("grad", kRows, kCols); - MatStorageT dx("dx", kTokens, kCols); - MatStorageT c_weights("c_weights", kRows, kCols); - MatStorageT c_x("c_x", kTokens, kCols); - MatStorageT c_y("c_y", kTokens, kRows); - MatStorageT dy("dy", kTokens, kRows); + auto weights = MakePacked("weights", kRows, kCols); + auto x = MakePacked("x", kTokens, kCols); + auto grad = MakePacked("grad", kRows, kCols); + auto dx = MakePacked("dx", kTokens, kCols); + auto c_weights = MakePacked("c_weights", kRows, kCols); + auto c_x = MakePacked("c_x", kTokens, kCols); + auto c_y = MakePacked("c_y", kTokens, kRows); + auto dy = MakePacked("dy", kTokens, kRows); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0 * (1 << iter), gen); @@ -60,12 +60,13 @@ TEST(BackPropTest, MatMulVJP) { Complexify(weights, c_weights); Complexify(x, c_x); auto func = [&]() { - MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens); - return DotT(dy.data(), c_y.data(), kTokens * kRows); + MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols, + kTokens); + return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows); }; - grad.ZeroInit(); - MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), - kRows, kCols, kTokens); + ZeroInit(grad); + MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(), + dx.Packed(), kRows, kCols, kTokens); TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__); TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__); } @@ -79,14 +80,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT weights("weights", kRows, kCols * kHeads); - MatStorageT x("x", kTokens, kCols * kHeads); - MatStorageT grad("grad", kRows, kCols * kHeads); - MatStorageT dx("dx", kTokens, kCols * kHeads); - MatStorageT c_weights("c_weights", kRows, kCols * kHeads); - MatStorageT c_x("c_x", kTokens, kCols * kHeads); - MatStorageT c_y("c_y", kTokens, kRows); - MatStorageT dy("dy", kTokens, kRows); + auto weights = MakePacked("weights", kRows, kCols * kHeads); + auto x = MakePacked("x", kTokens, kCols * kHeads); + auto grad = MakePacked("grad", kRows, kCols * kHeads); + auto dx = MakePacked("dx", kTokens, kCols * kHeads); + auto c_weights = 
MakePacked("c_weights", kRows, kCols * kHeads); + auto c_x = MakePacked("c_x", kTokens, kCols * kHeads); + auto c_y = MakePacked("c_y", kTokens, kRows); + auto dy = MakePacked("dy", kTokens, kRows); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0 * (1 << iter), gen); @@ -95,13 +96,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) { Complexify(weights, c_weights); Complexify(x, c_x); auto func = [&]() { - MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows, - kCols, kTokens); - return DotT(dy.data(), c_y.data(), kTokens * kRows); + MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads, + kRows, kCols, kTokens); + return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows); }; - grad.ZeroInit(); - MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), - dx.data(), kHeads, kRows, kCols, kTokens); + ZeroInit(grad); + MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), + grad.Packed(), dx.Packed(), kHeads, kRows, kCols, + kTokens); TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__); TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__); } @@ -113,14 +115,14 @@ TEST(BackPropTest, RMSNormVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT weights("weights", N, 1); - MatStorageT grad("grad", N, 1); - MatStorageT x("x", K, N); - MatStorageT dx("dx", K, N); - MatStorageT dy("dy", K, N); - MatStorageT c_weights("c_weights", N, 1); - MatStorageT c_x("c_x", K, N); - MatStorageT c_y("c_y", K, N); + auto weights = MakePacked("weights", N, 1); + auto grad = MakePacked("grad", N, 1); + auto x = MakePacked("x", K, N); + auto dx = MakePacked("dx", K, N); + auto dy = MakePacked("dy", K, N); + auto c_weights = MakePacked("c_weights", N, 1); + auto c_x = MakePacked("c_x", K, N); + auto c_y = MakePacked("c_y", K, N); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0 * (1 << iter), gen); @@ -129,12 +131,12 @@ TEST(BackPropTest, RMSNormVJP) { Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { - RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); - return DotT(dy.data(), c_y.data(), K * N); + RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K); + return DotT(dy.Packed(), c_y.Packed(), K * N); }; - grad.ZeroInit(); - RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(), - N, K); + ZeroInit(grad); + RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(), + dx.Packed(), N, K); TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__); } @@ -145,24 +147,24 @@ TEST(BackPropTest, SoftmaxVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT x("x", N, 1); - MatStorageT dx("dx", N, 1); - MatStorageT dy("dy", N, 1); - MatStorageT c_x("c_x", N, 1); - MatStorageT c_y("c_y", N, 1); + auto x = MakePacked("x", N, 1); + auto dx = MakePacked("dx", N, 1); + auto dy = MakePacked("dy", N, 1); + auto c_x = MakePacked("c_x", N, 1); + auto c_y = MakePacked("c_y", N, 1); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0 * (1 << iter), gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { - memcpy(c_y.data(), c_x.data(), c_x.SizeBytes()); - Softmax(c_y.data(), N); - return DotT(dy.data(), c_y.data(), N); + CopyMat(c_x, c_y); + Softmax(c_y.Packed(), N); + return DotT(dy.Packed(), c_y.Packed(), N); }; - Softmax(x.data(), N); - memcpy(dx.data(), dy.data(), dx.SizeBytes()); - SoftmaxVJPT(x.data(), dx.data(), N); + Softmax(x.Packed(), N); + CopyMat(dy, dx); 
+ SoftmaxVJPT(x.Packed(), dx.Packed(), N); TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); } } @@ -175,26 +177,25 @@ TEST(BackPropTest, MaskedSoftmaxVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT x("x", N, 1); - MatStorageT dy("dy", N, 1); - MatStorageT dx("dx", N, 1); - MatStorageT c_x("c_x", N, 1); - MatStorageT c_y("c_y", N, 1); - dx.ZeroInit(); + auto x = MakePacked("x", N, 1); + auto dy = MakePacked("dy", N, 1); + auto dx = MakePacked("dx", N, 1); + auto c_x = MakePacked("c_x", N, 1); + auto c_y = MakePacked("c_y", N, 1); + ZeroInit(dx); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0 * (1 << iter), gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { - memcpy(c_y.data(), c_x.data(), - kTokens * kHeads * kSeqLen * sizeof(c_x.At(0))); - MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen); - return DotT(dy.data(), c_y.data(), N); + CopyMat(c_x, c_y); + MaskedSoftmax(c_y.Packed(), kTokens, kHeads, kSeqLen); + return DotT(dy.Packed(), c_y.Packed(), N); }; - MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen); - memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0))); - MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen); + MaskedSoftmax(x.Packed(), kTokens, kHeads, kSeqLen); + CopyMat(dy, dx); + MaskedSoftmaxVJPT(x.Packed(), dx.Packed(), kTokens, kHeads, kSeqLen); TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__); } } @@ -204,11 +205,11 @@ TEST(BackPropTest, SoftcapVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT x("x", N, 1); - MatStorageT dx("dx", N, 1); - MatStorageT dy("dy", N, 1); - MatStorageT c_x("c_x", N, 1); - MatStorageT c_y("c_y", N, 1); + auto x = MakePacked("x", N, 1); + auto dx = MakePacked("dx", N, 1); + auto dy = MakePacked("dy", N, 1); + auto c_x = MakePacked("c_x", N, 1); + auto c_y = MakePacked("c_y", N, 1); constexpr float kCap = 30.0f; for (int iter = 0; iter < 10; ++iter) { @@ -216,13 +217,13 @@ TEST(BackPropTest, SoftcapVJP) { Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { - memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0))); - Softcap(kCap, c_y.data(), N); - return DotT(dy.data(), c_y.data(), N); + CopyMat(c_x, c_y); + Softcap(kCap, c_y.Packed(), N); + return DotT(dy.Packed(), c_y.Packed(), N); }; - Softcap(kCap, x.data(), N); - memcpy(dx.data(), dy.data(), dx.SizeBytes()); - SoftcapVJPT(kCap, x.data(), dx.data(), N); + Softcap(kCap, x.Packed(), N); + CopyMat(dy, dx); + SoftcapVJPT(kCap, x.Packed(), dx.Packed(), N); TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__); } } @@ -233,9 +234,9 @@ TEST(BackPropTest, CrossEntropyLossGrad) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT x("x", K, V); - MatStorageT dx("dx", K, V); - MatStorageT c_x("c_x", K, V); + auto x = MakePacked("x", K, V); + auto dx = MakePacked("dx", K, V); + auto c_x = MakePacked("c_x", K, V); Prompt prompt; prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 }; @@ -243,13 +244,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) { for (int iter = 0; iter < 10; ++iter) { prompt.context_size = 1 + (iter % 6); RandInit(x, 1.0 * (1 << iter), gen); - Softcap(kCap, x.data(), V * K); - Softmax(x.data(), V, K); - CrossEntropyLossGrad(x.data(), dx.data(), prompt, V); + Softcap(kCap, x.Packed(), V * K); + Softmax(x.Packed(), V, K); + CrossEntropyLossGrad(x.Packed(), dx.Packed(), prompt, V); Complexify(x, c_x); - auto func = [&]() { - return CrossEntropyLoss(c_x.data(), prompt, V); - }; + auto func = [&]() { return 
CrossEntropyLoss(c_x.Packed(), prompt, V); }; TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__); } } @@ -260,21 +259,21 @@ TEST(BackPropTest, GatedGeluVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT x("x", K, 2 * N); - MatStorageT dx("dx", K, 2 * N); - MatStorageT dy("dy", K, N); - MatStorageT c_x("c_x", K, 2 * N); - MatStorageT c_y("c_y", K, N); + auto x = MakePacked("x", K, 2 * N); + auto dx = MakePacked("dx", K, 2 * N); + auto dy = MakePacked("dy", K, N); + auto c_x = MakePacked("c_x", K, 2 * N); + auto c_y = MakePacked("c_y", K, N); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0, gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { - GatedGelu(c_x.data(), c_y.data(), N, K); - return DotT(dy.data(), c_y.data(), N * K); + GatedGelu(c_x.Packed(), c_y.Packed(), N, K); + return DotT(dy.Packed(), c_y.Packed(), N * K); }; - GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K); + GatedGeluVJP(x.Packed(), dy.Packed(), dx.Packed(), N, K); TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__); } } @@ -289,25 +288,25 @@ TEST(BackPropTest, MaskedAttentionVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT x("x", kQKVSize, 1); - MatStorageT dx("dx", kQKVSize, 1); - MatStorageT dy("dy", kOutSize, 1); - MatStorageT c_x("c_x", kQKVSize, 1); - MatStorageT c_y("c_y", kOutSize, 1); - dx.ZeroInit(); - c_y.ZeroInit(); + auto x = MakePacked("x", kQKVSize, 1); + auto dx = MakePacked("dx", kQKVSize, 1); + auto dy = MakePacked("dy", kOutSize, 1); + auto c_x = MakePacked("c_x", kQKVSize, 1); + auto c_y = MakePacked("c_y", kOutSize, 1); + ZeroInit(dx); + ZeroInit(c_y); for (int iter = 0; iter < 10; ++iter) { RandInit(x, 1.0, gen); Complexify(x, c_x); RandInit(dy, 1.0, gen); auto func = [&]() { - MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim, + MaskedAttention(c_x.Packed(), c_y.Packed(), kTokens, kHeads, kQKVDim, kSeqLen); - return DotT(dy.data(), c_y.data(), kOutSize); + return DotT(dy.Packed(), c_y.Packed(), kOutSize); }; - MaskedAttentionVJP(x.data(), dy.data(), dx.data(), - kTokens, kHeads, kQKVDim, kSeqLen); + MaskedAttentionVJP(x.Packed(), dy.Packed(), dx.Packed(), kTokens, kHeads, + kQKVDim, kSeqLen); TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__); } } @@ -323,17 +322,17 @@ TEST(BackPropTest, MixByAttentionVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT qkv("qkv", kQKVSize, 1); - MatStorageT dqkv("dqkv", kQKVSize, 1); - MatStorageT attn("attn", kAttnSize, 1); - MatStorageT dattn("dattn", kAttnSize, 1); - MatStorageT dy("dy", kOutSize, 1); - MatStorageT c_qkv("c_qkv", kQKVSize, 1); - MatStorageT c_attn("c_attn", kAttnSize, 1); - MatStorageT c_y("c_y", kOutSize, 1); - dqkv.ZeroInit(); - dattn.ZeroInit(); - c_y.ZeroInit(); + auto qkv = MakePacked("qkv", kQKVSize, 1); + auto dqkv = MakePacked("dqkv", kQKVSize, 1); + auto attn = MakePacked("attn", kAttnSize, 1); + auto dattn = MakePacked("dattn", kAttnSize, 1); + auto dy = MakePacked("dy", kOutSize, 1); + auto c_qkv = MakePacked("c_qkv", kQKVSize, 1); + auto c_attn = MakePacked("c_attn", kAttnSize, 1); + auto c_y = MakePacked("c_y", kOutSize, 1); + ZeroInit(dqkv); + ZeroInit(dattn); + ZeroInit(c_y); for (int iter = 0; iter < 10; ++iter) { RandInit(qkv, 1.0, gen); @@ -342,12 +341,12 @@ TEST(BackPropTest, MixByAttentionVJP) { Complexify(attn, c_attn); RandInit(dy, 1.0, gen); auto func = [&]() { - MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(), - kTokens, kHeads, kQKVDim, kSeqLen); - return 
DotT(dy.data(), c_y.data(), kOutSize); + MixByAttention(c_qkv.Packed(), c_attn.Packed(), c_y.Packed(), kTokens, + kHeads, kQKVDim, kSeqLen); + return DotT(dy.Packed(), c_y.Packed(), kOutSize); }; - MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(), - dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen); + MixByAttentionVJP(qkv.Packed(), attn.Packed(), dy.Packed(), dqkv.Packed(), + dattn.Packed(), kTokens, kHeads, kQKVDim, kSeqLen); TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__); TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__); } @@ -360,11 +359,11 @@ TEST(BackPropTest, InputEmbeddingVJP) { std::mt19937 gen(42); using T = double; using TC = std::complex; - MatStorageT weights("weights", kVocabSize, kModelDim); - MatStorageT grad("grad", kVocabSize, kModelDim); - MatStorageT dy("dy", kSeqLen, kModelDim); - MatStorageT c_weights("c_weights", kVocabSize, kModelDim); - MatStorageT c_y("c_y", kSeqLen, kModelDim); + auto weights = MakePacked("weights", kVocabSize, kModelDim); + auto grad = MakePacked("grad", kVocabSize, kModelDim); + auto dy = MakePacked("dy", kSeqLen, kModelDim); + auto c_weights = MakePacked("c_weights", kVocabSize, kModelDim); + auto c_y = MakePacked("c_y", kSeqLen, kModelDim); std::vector tokens = { 0, 1, 2, 3, 0, 1, 2 }; size_t num_tokens = tokens.size() - 1; @@ -373,12 +372,13 @@ TEST(BackPropTest, InputEmbeddingVJP) { RandInit(dy, 1.0, gen); Complexify(weights, c_weights); auto func = [&]() { - InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim); - return DotT(dy.data(), c_y.data(), num_tokens * kModelDim); + InputEmbedding(c_weights.Packed(), tokens, TC(3.0), c_y.Packed(), + kModelDim); + return DotT(dy.Packed(), c_y.Packed(), num_tokens * kModelDim); }; - grad.ZeroInit(); - InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(), - kModelDim); + ZeroInit(grad); + InputEmbeddingVJPT(weights.Packed(), tokens, 3.0, dy.Packed(), + grad.Packed(), kModelDim); TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__); } } @@ -410,8 +410,7 @@ TEST(BackPropTest, LayerVJP) { using T = double; using TC = std::complex; ModelConfig config = TestConfig(); - TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1, - /*reshape_att=*/false); + const TensorIndex tensor_index = TensorIndexLLM(config, size_t{0}); const size_t kOutputSize = config.seq_len * config.model_dim; LayerWeightsPtrs weights(config.layer_configs[0], tensor_index); LayerWeightsPtrs grad(config.layer_configs[0], tensor_index); @@ -419,15 +418,15 @@ TEST(BackPropTest, LayerVJP) { ForwardLayer backward(config.layer_configs[0], config.seq_len); LayerWeightsPtrs c_weights(config.layer_configs[0], tensor_index); ForwardLayer c_forward(config.layer_configs[0], config.seq_len); - MatStorageT y("y", kOutputSize, 1); - MatStorageT dy("dy", kOutputSize, 1); - MatStorageT c_y("c_y", kOutputSize, 1); + auto y = MakePacked("y", kOutputSize, 1); + auto dy = MakePacked("dy", kOutputSize, 1); + auto c_y = MakePacked("c_y", kOutputSize, 1); const size_t num_tokens = 3; - std::vector layer_storage; + std::vector layer_storage; weights.Allocate(layer_storage); grad.Allocate(layer_storage); c_weights.Allocate(layer_storage); - backward.input.ZeroInit(); + ZeroInit(backward.input); for (size_t iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0, gen); @@ -436,12 +435,12 @@ TEST(BackPropTest, LayerVJP) { Complexify(weights, c_weights); Complexify(forward.input, c_forward.input); auto func = [&]() { - ApplyLayer(c_weights, c_forward, num_tokens, 
c_y.data()); - return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim); + ApplyLayer(c_weights, c_forward, num_tokens, c_y.Packed()); + return DotT(dy.Packed(), c_y.Packed(), num_tokens * config.model_dim); }; grad.ZeroInit(/*layer_idx=*/0); - ApplyLayer(weights, forward, num_tokens, y.data()); - LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens); + ApplyLayer(weights, forward, num_tokens, y.Packed()); + LayerVJP(weights, forward, dy.Packed(), grad, backward, num_tokens); TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11, __LINE__); TestGradient(grad, c_weights, func, 1e-11); diff --git a/backprop/backward_test.cc b/backprop/backward_test.cc index f1c97b2..865f481 100644 --- a/backprop/backward_test.cc +++ b/backprop/backward_test.cc @@ -33,8 +33,10 @@ #include "backprop/test_util.h" #include "gemma/configs.h" #include "ops/ops.h" -#include "util/threading.h" +#include "util/mat.h" +#include "util/threading_context.h" #include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" // clang-format off #undef HWY_TARGET_INCLUDE @@ -46,33 +48,45 @@ // After highway.h #include "backprop/backward-inl.h" #include "backprop/forward-inl.h" -#include "compression/compress.h" #include "ops/ops-inl.h" -#include "util/allocator.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { +hwy::ThreadPool& ThreadHostileGetPool() { + // Assume this is only called at the top level, i.e. not in a thread. Then we + // can safely call `SetArgs` only once, because it would assert otherwise. + // This is preferable to calling `ThreadHostileInvalidate`, because we would + // repeat the topology initialization for every test. + if (!ThreadingContext2::IsInitialized()) { + gcpp::ThreadingArgs threading_args; + threading_args.max_packages = 1; + threading_args.max_clusters = 8; + threading_args.pin = Tristate::kFalse; + ThreadingContext2::SetArgs(threading_args); + } + return ThreadingContext2::Get().pools.Pool(); +} + void TestMatMulVJP() { static const size_t kRows = 8; static const size_t kCols = 64; static const size_t kTokens = 5; - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8)); - Allocator::Init(topology); - gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); + + hwy::ThreadPool& pool = ThreadHostileGetPool(); std::mt19937 gen(42); - MatStorageT weights("weights", kRows, kCols); - MatStorageT x("x", kTokens, kCols); - MatStorageT dy("dy", kTokens, kRows); - MatStorageT grad("grad", kRows, kCols); - MatStorageT dx("dx", kTokens, kCols); - MatStorageT grad_scalar("grad_scalar", kRows, kCols); - MatStorageT dx_scalar("dx_scalar", kTokens, kCols); + auto weights = MakePacked("weights", kRows, kCols); + auto x = MakePacked("x", kTokens, kCols); + auto dy = MakePacked("dy", kTokens, kRows); + auto grad = MakePacked("grad", kRows, kCols); + auto dx = MakePacked("dx", kTokens, kCols); + auto grad_scalar = MakePacked("grad_scalar", kRows, kCols); + auto dx_scalar = MakePacked("dx_scalar", kTokens, kCols); using TC = std::complex; - MatStorageT c_weights("c_weights", kRows, kCols); - MatStorageT c_x("c_x", kTokens, kCols); - MatStorageT c_y("c_y", kTokens, kRows); + auto c_weights = MakePacked("c_weights", kRows, kCols); + auto c_x = MakePacked("c_x", kTokens, kCols); + auto c_y = MakePacked("c_y", kTokens, kRows); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0f * (1 << iter), gen); @@ -81,19 +95,20 @@ void TestMatMulVJP() { Complexify(weights, c_weights); Complexify(x, c_x); auto func = [&]() { - 
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens); - return DotT(dy.data(), c_y.data(), kTokens * kRows); + MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols, + kTokens); + return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows); }; - grad.ZeroInit(); - MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens, - grad.data(), dx.data(), pools.Pool()); + ZeroInit(grad); + MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens, + grad.Packed(), dx.Packed(), pool); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - grad_scalar.ZeroInit(); - MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), - dx_scalar.data(), kRows, kCols, kTokens); + ZeroInit(grad_scalar); + MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(), + dx_scalar.Packed(), kRows, kCols, kTokens); TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__); TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__); } @@ -104,21 +119,19 @@ void TestMultiHeadMatMulVJP() { static const size_t kCols = 16; static const size_t kHeads = 4; static const size_t kTokens = 3; - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8)); - Allocator::Init(topology); - gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); + hwy::ThreadPool& pool = ThreadHostileGetPool(); std::mt19937 gen(42); - MatStorageT weights("weights", kRows, kCols * kHeads); - MatStorageT x("x", kTokens, kCols * kHeads); - MatStorageT grad("grad", kRows, kCols * kHeads); - MatStorageT dx("dx", kTokens, kCols * kHeads); - MatStorageT dy("dy", kTokens, kRows); - MatStorageT grad_scalar("grad_scalar", kRows, kCols * kHeads); - MatStorageT dx_scalar("dx_scalar", kTokens, kCols * kHeads); + auto weights = MakePacked("weights", kRows, kCols * kHeads); + auto x = MakePacked("x", kTokens, kCols * kHeads); + auto grad = MakePacked("grad", kRows, kCols * kHeads); + auto dx = MakePacked("dx", kTokens, kCols * kHeads); + auto dy = MakePacked("dy", kTokens, kRows); + auto grad_scalar = MakePacked("grad_scalar", kRows, kCols * kHeads); + auto dx_scalar = MakePacked("dx_scalar", kTokens, kCols * kHeads); using TC = std::complex; - MatStorageT c_weights("c_weights", kRows, kCols * kHeads); - MatStorageT c_x("c_x", kTokens, kCols * kHeads); - MatStorageT c_y("c_y", kTokens, kRows); + auto c_weights = MakePacked("c_weights", kRows, kCols * kHeads); + auto c_x = MakePacked("c_x", kTokens, kCols * kHeads); + auto c_y = MakePacked("c_y", kTokens, kRows); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0f * (1 << iter), gen); @@ -127,20 +140,21 @@ void TestMultiHeadMatMulVJP() { Complexify(weights, c_weights); Complexify(x, c_x); auto func = [&]() { - MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows, - kCols, kTokens); - return DotT(dy.data(), c_y.data(), kTokens * kRows); + MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads, + kRows, kCols, kTokens); + return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows); }; - grad.ZeroInit(); - MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols, - kRows, kTokens, grad.data(), dx.data(), pools.Pool()); + ZeroInit(grad); + MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols, + kRows, kTokens, grad.Packed(), dx.Packed(), pool); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - grad_scalar.ZeroInit(); - 
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), - dx_scalar.data(), kHeads, kRows, kCols, kTokens); + ZeroInit(grad_scalar); + MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), + grad_scalar.Packed(), dx_scalar.Packed(), kHeads, kRows, + kCols, kTokens); TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__); TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__); } @@ -149,21 +163,19 @@ void TestMultiHeadMatMulVJP() { void TestRMSNormVJP() { static const size_t K = 2; static const size_t N = 64; - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8)); - Allocator::Init(topology); - gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); + hwy::ThreadPool& pool = ThreadHostileGetPool(); std::mt19937 gen(42); - MatStorageT weights("weights", N, 1); - MatStorageT x("x", K, N); - MatStorageT grad("grad", N, 1); - MatStorageT dx("dx", K, N); - MatStorageT dy("dy", K, N); - MatStorageT grad_scalar("grad_scalar", N, 1); - MatStorageT dx_scalar("dx_scalar", K, N); + auto weights = MakePacked("weights", N, 1); + auto x = MakePacked("x", K, N); + auto grad = MakePacked("grad", N, 1); + auto dx = MakePacked("dx", K, N); + auto dy = MakePacked("dy", K, N); + auto grad_scalar = MakePacked("grad_scalar", N, 1); + auto dx_scalar = MakePacked("dx_scalar", K, N); using TC = std::complex; - MatStorageT c_weights("c_weights", N, 1); - MatStorageT c_x("c_x", K, N); - MatStorageT c_y("c_y", K, N); + auto c_weights = MakePacked("c_weights", N, 1); + auto c_x = MakePacked("c_x", K, N); + auto c_y = MakePacked("c_y", K, N); for (int iter = 0; iter < 10; ++iter) { RandInit(weights, 1.0f * (1 << iter), gen); @@ -172,19 +184,19 @@ void TestRMSNormVJP() { Complexify(weights, c_weights); Complexify(x, c_x); auto func = [&]() { - RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K); - return DotT(dy.data(), c_y.data(), K * N); + RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K); + return DotT(dy.Packed(), c_y.Packed(), K * N); }; - grad.ZeroInit(); - RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(), - dx.data(), pools.Pool()); + ZeroInit(grad); + RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(), + dx.Packed(), pool); TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__); TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__); - grad_scalar.ZeroInit(); - RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(), - dx_scalar.data(), N, K); + ZeroInit(grad_scalar); + RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(), + dx_scalar.Packed(), N, K); TestNear(dx, dx_scalar, 0, 2e-5, __LINE__); TestNear(grad, grad_scalar, 0, 2e-5, __LINE__); } @@ -215,9 +227,7 @@ static ModelConfig TestConfig() { void TestEndToEnd() { std::mt19937 gen(42); - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1)); - Allocator::Init(topology); - gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); + hwy::ThreadPool& pool = ThreadHostileGetPool(); ModelConfig config = TestConfig(); WeightsWrapper weights(config); WeightsWrapper grad(config); @@ -232,7 +242,7 @@ void TestEndToEnd() { std::vector batch = training_task.SampleBatch(3, gen); RowVectorBatch inv_timescale = CreateInvTimescale( - config.layer_configs[0].qkv_dim, + ThreadingContext2::Get().allocator, config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk == PostQKType::HalfRope); for (const Prompt& prompt : batch) { ReverseSequenceSampler::LogPrompt(prompt); @@ -242,13 +252,13 @@ void TestEndToEnd() { 
float loss1 = CrossEntropyLossForwardPass( prompt.tokens, prompt.context_size, weights.get(), forward1, - inv_timescale, pools.Pool()); + inv_timescale, pool); EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5); grad.ZeroInit(); CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(), - backward, inv_timescale, pools.Pool()); + backward, inv_timescale, pool); Complexify(weights.get(), c_weights.get()); auto func = [&]() { diff --git a/backprop/common_scalar.h b/backprop/common_scalar.h index c61086d..9794636 100644 --- a/backprop/common_scalar.h +++ b/backprop/common_scalar.h @@ -20,7 +20,7 @@ #include -#include "compression/compress.h" // MatStorageT +#include "util/mat.h" namespace gcpp { @@ -60,7 +60,9 @@ void MulByConstAndAddT(T c, const T* x, T* out, size_t N) { template void MulByConstAndAddT(T c, const MatPtrT& x, MatPtrT& out) { - MulByConstAndAddT(c, x.data(), out.data(), x.NumElements()); + for (size_t r = 0; r < x.Rows(); ++r) { + MulByConstAndAddT(c, x.Row(r), out.Row(r), x.Cols()); + } } template diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index ca969c4..75de9a2 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -28,6 +28,7 @@ #include "gemma/configs.h" #include "gemma/weights.h" #include "util/allocator.h" +#include "util/mat.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -50,16 +51,17 @@ HWY_BEFORE_NAMESPACE(); namespace gcpp { namespace HWY_NAMESPACE { -template -void InputEmbedding(const ArrayT& weights, const std::vector& prompt, +template +void InputEmbedding(const MatPtrT& weights, const std::vector& prompt, const float scaling, float* HWY_RESTRICT output, size_t model_dim, size_t vocab_size) { const hn::ScalableTag df; HWY_ASSERT(!prompt.empty()); for (size_t pos = 0; pos < prompt.size() - 1; ++pos) { int token = prompt[pos]; - DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size), - token * model_dim, output + pos * model_dim, + const auto span = weights.Span(); + HWY_ASSERT(span.num == model_dim * vocab_size); + DecompressAndZeroPad(df, span, token * model_dim, output + pos * model_dim, model_dim); MulByConst(scaling, output + pos * model_dim, model_dim); } @@ -109,27 +111,27 @@ void ApplyForwardLayer(const LayerWeightsPtrs& weights, static_cast(1.0 / sqrt(static_cast(kQKVDim))); HWY_ASSERT(num_tokens <= kSeqLen); - ApplyRMSNorm(weights.pre_attention_norm_scale.data(), - activations.input.data(), model_dim, num_tokens, - activations.pre_att_rms_out.data(), pool); + ApplyRMSNorm(weights.pre_attention_norm_scale.Packed(), + activations.input.Packed(), model_dim, num_tokens, + activations.pre_att_rms_out.Packed(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim, - activations.pre_att_rms_out.data() + pos * model_dim, - activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool); + activations.pre_att_rms_out.Packed() + pos * model_dim, + activations.qkv.Packed() + pos * (kHeads + 2) * kQKVDim, pool); } const size_t num_tasks = kHeads * num_tokens; for (size_t pos = 0; pos < num_tokens; ++pos) { float* HWY_RESTRICT k = - activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim; + activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim; Rope(k, kQKVDim, inv_timescale.Const(), pos); } pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; const size_t pos = task / kHeads; float* HWY_RESTRICT q = - activations.qkv.data() + (pos 
* (kHeads + 2) + head) * kQKVDim; + activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim; Rope(q, kQKVDim, inv_timescale.Const(), pos); MulByConst(query_scale, q, kQKVDim); }); @@ -138,12 +140,12 @@ void ApplyForwardLayer(const LayerWeightsPtrs& weights, const size_t head = task % kHeads; const size_t pos = task / kHeads; const float* HWY_RESTRICT q = - activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim; + activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim; float* HWY_RESTRICT head_att = - activations.att.data() + (pos * kHeads + head) * kSeqLen; + activations.att.Packed() + (pos * kHeads + head) * kSeqLen; for (size_t pos2 = 0; pos2 <= pos; ++pos2) { const float* HWY_RESTRICT k2 = - activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; + activations.qkv.Packed() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim; const float score = Dot(q, k2, kQKVDim); head_att[pos2] = score; } @@ -153,7 +155,7 @@ void ApplyForwardLayer(const LayerWeightsPtrs& weights, const size_t head = task % kHeads; const size_t pos = task / kHeads; float* HWY_RESTRICT head_att = - activations.att.data() + (pos * kHeads + head) * kSeqLen; + activations.att.Packed() + (pos * kHeads + head) * kSeqLen; Softmax(head_att, pos + 1); }); @@ -161,51 +163,51 @@ void ApplyForwardLayer(const LayerWeightsPtrs& weights, const size_t head = task % kHeads; const size_t pos = task / kHeads; const float* HWY_RESTRICT head_att = - activations.att.data() + (pos * kHeads + head) * kSeqLen; + activations.att.Packed() + (pos * kHeads + head) * kSeqLen; float* HWY_RESTRICT att_out = - activations.att_out.data() + (pos * kHeads + head) * kQKVDim; + activations.att_out.Packed() + (pos * kHeads + head) * kQKVDim; hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out)); for (size_t pos2 = 0; pos2 <= pos; ++pos2) { - float* HWY_RESTRICT v2 = - activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; + float* HWY_RESTRICT v2 = activations.qkv.Packed() + + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim; MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim); } }); - activations.attention_out.ZeroInit(); + ZeroInit(activations.attention_out); for (size_t pos = 0; pos < num_tokens; ++pos) { for (size_t head = 0; head < kHeads; ++head) { - MatVec( - weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim, - kQKVDim, - activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim, - activations.att_post1.data() + pos * model_dim, pool); - AddFrom(activations.att_post1.data() + pos * model_dim, - activations.attention_out.data() + pos * model_dim, model_dim); + MatVec(weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim, + kQKVDim, + activations.att_out.Packed() + pos * kHeads * kQKVDim + + head * kQKVDim, + activations.att_post1.Packed() + pos * model_dim, pool); + AddFrom(activations.att_post1.Packed() + pos * model_dim, + activations.attention_out.Packed() + pos * model_dim, model_dim); } } for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(activations.input.data() + pos * model_dim, - activations.attention_out.data() + pos * model_dim, model_dim); + AddFrom(activations.input.Packed() + pos * model_dim, + activations.attention_out.Packed() + pos * model_dim, model_dim); } - ApplyRMSNorm(weights.pre_ffw_norm_scale.data(), - activations.attention_out.data(), model_dim, num_tokens, - activations.bf_pre_ffw_rms_out.data(), pool); + ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(), + activations.attention_out.Packed(), model_dim, num_tokens, + 
activations.bf_pre_ffw_rms_out.Packed(), pool); const size_t kFFHiddenDim = config.ff_hidden_dim; for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim, - activations.bf_pre_ffw_rms_out.data() + pos * model_dim, - activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); + activations.bf_pre_ffw_rms_out.Packed() + pos * model_dim, + activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { const size_t hidden_offset = pos * kFFHiddenDim * 2; const float* HWY_RESTRICT out = - activations.ffw_hidden.data() + hidden_offset; + activations.ffw_hidden.Packed() + hidden_offset; const float* HWY_RESTRICT out_mul = out + kFFHiddenDim; float* HWY_RESTRICT out_gated = - activations.ffw_hidden_gated.data() + pos * kFFHiddenDim; + activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim; namespace hn = hwy::HWY_NAMESPACE; using DF = hn::ScalableTag; DF df; @@ -217,11 +219,11 @@ void ApplyForwardLayer(const LayerWeightsPtrs& weights, } for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim, - activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, + activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim, output + pos * model_dim, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { - AddFrom(activations.attention_out.data() + pos * model_dim, + AddFrom(activations.attention_out.Packed() + pos * model_dim, output + pos * model_dim, model_dim); } } @@ -247,44 +249,43 @@ float CrossEntropyLossForwardPass(const std::vector& prompt, const size_t num_tokens = prompt.size() - 1; InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling, - forward.layers[0].input.data(), model_dim, vocab_size); + forward.layers[0].input.Packed(), model_dim, vocab_size); for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) { auto type = config.layer_configs[layer].type; // TODO(szabadka) Implement Griffin layer. HWY_ASSERT(type == LayerAttentionType::kGemma); float* HWY_RESTRICT output = layer + 1 < layers - ? forward.layers[layer + 1].input.data() - : forward.final_layer_output.data(); + ? 
forward.layers[layer + 1].input.Packed() + : forward.final_layer_output.Packed(); ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens, output, inv_timescale, pool); } - ApplyRMSNorm(weights.final_norm_scale.data(), - forward.final_layer_output.data(), model_dim, num_tokens, - forward.final_norm_output.data(), pool); + ApplyRMSNorm(weights.final_norm_scale.Packed(), + forward.final_layer_output.Packed(), model_dim, num_tokens, + forward.final_norm_output.Packed(), pool); for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim, - forward.final_norm_output.data() + pos * model_dim, - forward.logits.data() + pos * vocab_size, pool); + forward.final_norm_output.Packed() + pos * model_dim, + forward.logits.Packed() + pos * vocab_size, pool); } if (config.final_cap > 0.0f) { for (size_t pos = 0; pos < num_tokens; ++pos) { - LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size, - vocab_size); + LogitsSoftCap(config.final_cap, + forward.logits.Packed() + pos * vocab_size, vocab_size); } } - hwy::CopyBytes(forward.logits.data(), forward.probs.data(), - num_tokens * vocab_size * sizeof(forward.logits.At(0))); + CopyMat(forward.logits, forward.probs); for (size_t pos = 0; pos < num_tokens; ++pos) { - Softmax(forward.probs.data() + pos * vocab_size, vocab_size); + Softmax(forward.probs.Packed() + pos * vocab_size, vocab_size); } - return CrossEntropyLoss(forward.probs.data(), prompt, context_size, + return CrossEntropyLoss(forward.probs.Packed(), prompt, context_size, vocab_size, pool); } diff --git a/backprop/forward.cc b/backprop/forward.cc index 0c6cc5c..8f85e81 100644 --- a/backprop/forward.cc +++ b/backprop/forward.cc @@ -17,9 +17,7 @@ #include "backprop/activations.h" #include "backprop/prompt.h" -#include "gemma/common.h" -#include "gemma/configs.h" -#include "util/allocator.h" +#include "util/mat.h" #include "hwy/contrib/thread_pool/thread_pool.h" // Compiles this file for multiple architectures via "foreach_target.h", to diff --git a/backprop/forward_scalar.h b/backprop/forward_scalar.h index 617d0c3..d81ae30 100644 --- a/backprop/forward_scalar.h +++ b/backprop/forward_scalar.h @@ -180,54 +180,59 @@ void ApplyLayer(const LayerWeightsPtrs& weights, const size_t ff_hidden_dim = layer_config.ff_hidden_dim; static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim)); - RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(), - activations.pre_att_rms_out.data(), model_dim, num_tokens); + RMSNormT(weights.pre_attention_norm_scale.Packed(), + activations.input.Packed(), activations.pre_att_rms_out.Packed(), + model_dim, num_tokens); - MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(), - activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens); + MatMulT(weights.qkv_einsum_w.Packed(), activations.pre_att_rms_out.Packed(), + activations.qkv.Packed(), (heads + 2) * qkv_dim, model_dim, + num_tokens); for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim; + T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim; for (size_t h = 0; h <= heads; ++h) { Rope(qkv + h * qkv_dim, qkv_dim, pos); } } for (size_t pos = 0; pos < num_tokens; ++pos) { - T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim; + T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim; MulByConstT(query_scale, qkv, heads * qkv_dim); } - MaskedAttention(activations.qkv.data(), activations.att.data(), 
num_tokens, - heads, qkv_dim, seq_len); + MaskedAttention(activations.qkv.Packed(), activations.att.Packed(), + num_tokens, heads, qkv_dim, seq_len); - MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len); + MaskedSoftmax(activations.att.Packed(), num_tokens, heads, seq_len); - MixByAttention(activations.qkv.data(), activations.att.data(), - activations.att_out.data(), num_tokens, heads, qkv_dim, + MixByAttention(activations.qkv.Packed(), activations.att.Packed(), + activations.att_out.Packed(), num_tokens, heads, qkv_dim, seq_len); - MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(), - activations.attention_out.data(), heads, model_dim, qkv_dim, + MultiHeadMatMul(weights.attn_vec_einsum_w.Packed(), + activations.att_out.Packed(), + activations.attention_out.Packed(), heads, model_dim, qkv_dim, num_tokens); - AddFromT(activations.input.data(), activations.attention_out.data(), + AddFromT(activations.input.Packed(), activations.attention_out.Packed(), num_tokens * model_dim); - RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(), - activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens); + RMSNormT(weights.pre_ffw_norm_scale.Packed(), + activations.attention_out.Packed(), + activations.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens); - MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(), - activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim, + MatMulT(weights.gating_einsum_w.Packed(), + activations.bf_pre_ffw_rms_out.Packed(), + activations.ffw_hidden.Packed(), ff_hidden_dim * 2, model_dim, num_tokens); - GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(), - ff_hidden_dim, num_tokens); + GatedGelu(activations.ffw_hidden.Packed(), + activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens); - MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output, - model_dim, ff_hidden_dim, num_tokens); + MatMulT(weights.linear_w.Packed(), activations.ffw_hidden_gated.Packed(), + output, model_dim, ff_hidden_dim, num_tokens); - AddFromT(activations.attention_out.data(), output, num_tokens * model_dim); + AddFromT(activations.attention_out.Packed(), output, num_tokens * model_dim); } template @@ -258,35 +263,35 @@ T CrossEntropyLossForwardPass(const Prompt& prompt, const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1; const T kEmbScaling = EmbeddingScaling(model_dim); - InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling, - forward.layers[0].input.data(), model_dim); + InputEmbedding(weights.embedder_input_embedding.Packed(), tokens, kEmbScaling, + forward.layers[0].input.Packed(), model_dim); for (size_t layer = 0; layer < layers; ++layer) { - T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data() - : forward.final_layer_output.data(); + T* output = layer + 1 < layers ? 
forward.layers[layer + 1].input.Packed() + : forward.final_layer_output.Packed(); ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens, output); } - RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(), - forward.final_norm_output.data(), model_dim, num_tokens); + RMSNormT(weights.final_norm_scale.Packed(), + forward.final_layer_output.Packed(), + forward.final_norm_output.Packed(), model_dim, num_tokens); - MatMulT(weights.embedder_input_embedding.data(), - forward.final_norm_output.data(), forward.logits.data(), vocab_size, - model_dim, num_tokens); + MatMulT(weights.embedder_input_embedding.Packed(), + forward.final_norm_output.Packed(), forward.logits.Packed(), + vocab_size, model_dim, num_tokens); for (size_t pos = 0; pos < num_tokens; ++pos) { if (config.final_cap > 0.0f) { - Softcap(config.final_cap, forward.logits.data() + pos * vocab_size, + Softcap(config.final_cap, forward.logits.Packed() + pos * vocab_size, vocab_size); } } - memcpy(forward.probs.data(), forward.logits.data(), - num_tokens * vocab_size * sizeof(forward.logits.At(0))); - Softmax(forward.probs.data(), vocab_size, num_tokens); + CopyMat(forward.logits, forward.probs); + Softmax(forward.probs.Packed(), vocab_size, num_tokens); - return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size); + return CrossEntropyLoss(forward.probs.Packed(), prompt, vocab_size); } } // namespace gcpp diff --git a/backprop/optimize_test.cc b/backprop/optimize_test.cc index 6f08bf0..93335bc 100644 --- a/backprop/optimize_test.cc +++ b/backprop/optimize_test.cc @@ -41,11 +41,14 @@ namespace gcpp { TEST(OptimizeTest, GradientDescent) { - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1)); - Allocator::Init(topology); - NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse); - MatMulEnv env(topology, pools); - hwy::ThreadPool& pool = pools.Pool(); + gcpp::ThreadingArgs threading_args; + threading_args.max_packages = 1; + threading_args.max_clusters = 1; + threading_args.pin = Tristate::kFalse; + ThreadingContext2::SetArgs(threading_args); + MatMulEnv env(ThreadingContext2::Get()); + const Allocator2& allocator = env.ctx.allocator; + hwy::ThreadPool& pool = env.ctx.pools.Pool(); std::mt19937 gen(42); const ModelInfo info = { @@ -64,7 +67,7 @@ TEST(OptimizeTest, GradientDescent) { KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16); RowVectorBatch inv_timescale = CreateInvTimescale( - config.layer_configs[0].qkv_dim, + allocator, config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk == PostQKType::HalfRope); Gemma gemma(GemmaTokenizer(), info, env); diff --git a/backprop/optimizer.cc b/backprop/optimizer.cc index 9187bf7..2eac992 100644 --- a/backprop/optimizer.cc +++ b/backprop/optimizer.cc @@ -18,9 +18,9 @@ #include #include "compression/compress.h" -#include "gemma/common.h" #include "gemma/weights.h" #include "util/allocator.h" +#include "util/mat.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -39,11 +39,11 @@ class AdamUpdater { void operator()(const char* name, const MatPtr& grad, MatPtr& weights, MatPtr& grad_m, MatPtr& grad_v) { - const float* HWY_RESTRICT g = grad.data(); - float* HWY_RESTRICT w = weights.data(); - float* HWY_RESTRICT m = grad_m.data(); - float* HWY_RESTRICT v = grad_v.data(); - for (size_t i = 0; i < grad.NumElements(); ++i) { + const float* HWY_RESTRICT g = grad.RowT(0); + float* HWY_RESTRICT w = weights.RowT(0); + float* HWY_RESTRICT m = grad_m.RowT(0); 
+ float* HWY_RESTRICT v = grad_v.RowT(0); + for (size_t i = 0; i < grad.Extents().Area(); ++i) { m[i] *= beta1_; m[i] += cbeta1_ * g[i]; v[i] *= beta2_; diff --git a/backprop/test_util.h b/backprop/test_util.h index a83e3d5..f5aa4fd 100644 --- a/backprop/test_util.h +++ b/backprop/test_util.h @@ -24,21 +24,13 @@ #include #include "gtest/gtest.h" -#include "compression/compress.h" #include "gemma/configs.h" #include "gemma/weights.h" +#include "util/mat.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { -template -void RandInit(MatPtrT& x, T stddev, std::mt19937& gen) { - std::normal_distribution dist(0.0, stddev); - for (size_t i = 0; i < x.NumElements(); ++i) { - x.At(i) = dist(gen); - } -} - // TODO: make a member of Layer. template void RandInit(LayerWeightsPtrs& w, T stddev, std::mt19937& gen) { @@ -62,8 +54,12 @@ void RandInit(ModelWeightsPtrs& w, T stddev, std::mt19937& gen) { template void Complexify(const MatPtrT& x, MatPtrT>& c_x) { - for (size_t i = 0; i < x.NumElements(); ++i) { - c_x.At(i) = std::complex(x.At(i), 0.0); + for (size_t r = 0; r < x.Rows(); ++r) { + const T* row = x.Row(r); + std::complex* c_row = c_x.Row(r); + for (size_t c = 0; c < x.Cols(); ++c) { + c_row[c] = std::complex(row[c], 0.0); + } } } @@ -87,14 +83,14 @@ void Complexify(const ModelWeightsPtrs& w, ModelWeightsPtrs& c_w) { } } -// Somewhat duplicates ModelWeightsStorage, but that has neither double nor +// Somewhat duplicates WeightsOwner, but that has neither double nor // complex types allowed and it would cause code bloat to add them there. template class WeightsWrapper { public: explicit WeightsWrapper(const ModelConfig& config) : pool_(0), weights_(config) { - weights_.Allocate(data_, pool_); + weights_.Allocate(owners_, pool_); } const ModelWeightsPtrs& get() const { return weights_; } @@ -106,7 +102,7 @@ class WeightsWrapper { private: hwy::ThreadPool pool_; - std::vector data_; + std::vector owners_; ModelWeightsPtrs weights_; }; @@ -116,13 +112,18 @@ void TestNear(const MatPtrT& actual, const MatPtrT& expected, double sum0 = 0; double sum1 = 0; double sum01 = 0; - for (size_t i = 0; i < actual.NumElements(); ++i) { - sum0 += actual.At(i) * actual.At(i); - sum1 += expected.At(i) * expected.At(i); - sum01 += actual.At(i) * expected.At(i); - ASSERT_NEAR(actual.At(i), expected.At(i), - std::max(max_abs_err, std::abs(expected.At(i)) * max_rel_err)) - << "line: " << line << " dim=" << expected.NumElements() << " i=" << i; + for (size_t r = 0; r < actual.Rows(); ++r) { + const T* actual_row = actual.Row(r); + const U* expected_row = expected.Row(r); + for (size_t c = 0; c < actual.Cols(); ++c) { + sum0 += actual_row[c] * actual_row[c]; + sum1 += expected_row[c] * expected_row[c]; + sum01 += actual_row[c] * expected_row[c]; + ASSERT_NEAR( + actual_row[c], expected_row[c], + std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err)) + << "line: " << line << " r " << r << " c " << c; + } } if (sum0 > 1e-40) { double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1); @@ -148,15 +149,19 @@ void TestNear(const MatPtrT& actual, const MatPtrT& expected, template void TestGradient(const MatPtrT& grad, MatPtrT>& x, FUNC func, U step, T max_abs_err, T max_rel_err, int line) { - MatStorageT exp_grad("exp_grad", x.Rows(), x.Cols()); + MatStorageT exp_grad = MakePacked("exp_grad", x.Rows(), x.Cols()); const U inv_step = 1.0 / step; - for (size_t i = 0; i < x.NumElements(); ++i) { - const U x0 = std::real(x.At(i)); - const std::complex x1 = std::complex(x0, step); - x.At(i) = x1; - const 
std::complex f1 = func(); - exp_grad.At(i) = std::imag(f1) * inv_step; - x.At(i) = x0; + for (size_t r = 0; r < x.Rows(); ++r) { + std::complex* x_row = x.Row(r); + T* exp_row = exp_grad.Row(r); + for (size_t c = 0; c < x.Cols(); ++c) { + const U x0 = std::real(x_row[c]); + const std::complex x1 = std::complex(x0, step); + x_row[c] = x1; + const std::complex f1 = func(); + exp_row[c] = std::imag(f1) * inv_step; + x_row[c] = x0; + } } TestNear(grad, exp_grad, max_abs_err, max_rel_err, line); } diff --git a/compression/BUILD.bazel b/compression/BUILD.bazel index f12ca59..e5102fe 100644 --- a/compression/BUILD.bazel +++ b/compression/BUILD.bazel @@ -146,6 +146,7 @@ cc_library( ":distortion", "@highway//:hwy", "@highway//:hwy_test_util", + "@highway//:thread_pool", ], ) @@ -209,6 +210,7 @@ cc_library( "//:allocator", "//:basics", "//:common", + "//:mat", "@highway//:hwy", "@highway//:nanobenchmark", "@highway//:stats", @@ -252,22 +254,6 @@ cc_library( ], ) -cc_binary( - name = "compress_weights", - srcs = ["compress_weights.cc"], - deps = [ - ":compress", - ":io", - "//:allocator", - "//:args", - "//:common", - "//:tokenizer", - "//:weights", - "@highway//:hwy", - "@highway//:thread_pool", - ], -) - cc_binary( name = "blob_compare", srcs = ["blob_compare.cc"], @@ -277,9 +263,11 @@ cc_binary( "//:allocator", "//:basics", "//:threading", + "//:threading_context", "@highway//:hwy", "@highway//:hwy_test_util", "@highway//:nanobenchmark", + "@highway//:thread_pool", ], ) @@ -287,7 +275,6 @@ cc_binary( name = "migrate_weights", srcs = ["migrate_weights.cc"], deps = [ - "//:app", "//:args", "//:benchmark_helper", "//:gemma_lib", diff --git a/compression/blob_compare.cc b/compression/blob_compare.cc index c0fe63c..4e465ca 100644 --- a/compression/blob_compare.cc +++ b/compression/blob_compare.cc @@ -25,8 +25,10 @@ #include "util/allocator.h" #include "util/basics.h" // IndexRange #include "util/threading.h" +#include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" +#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/timer.h" namespace gcpp { @@ -202,15 +204,13 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) { if (!CompareKeys(reader1, reader2)) return; // Single allocation, avoid initializing the memory. - BoundedTopology topology; - Allocator::Init(topology); - NestedPools pools(topology); const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2); BytePtr all_blobs = hwy::AllocateAligned(total_bytes); size_t pos = 0; BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos); BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos); + NestedPools& pools = ThreadingContext2::Get().pools; ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools); CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools); diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 8638b5f..0c0cdef 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -29,6 +29,7 @@ #include "compression/compress.h" // IWYU pragma: export #include "compression/distortion.h" #include "gemma/configs.h" +#include "util/mat.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -379,7 +380,7 @@ struct CompressTraits { using Packed = SfpStream; // Callers are responsible for scaling `raw` such that its magnitudes do not - // exceed `SfpStream::kMax`. See CompressedArray::scale(). + // exceed `SfpStream::kMax`. See CompressedArray::Scale(). 
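Several hunks above (MulByConstAndAddT, Complexify, TestNear, TestGradient) replace flat NumElements() indexing with row/column loops over Rows(), Cols() and Row(r), which stays correct when rows carry padding. Below is a small sketch of that access pattern as a generic visitor; the ForEachElement name and the float instantiation in the comment are ours, only the MatPtrT accessors come from the patch.

#include <cstddef>
#include "util/mat.h"  // MatPtrT

namespace gcpp {

// Hypothetical helper illustrating the row-wise access pattern used above.
template <typename T, class Func>
void ForEachElement(MatPtrT<T>& mat, const Func& func) {
  for (size_t r = 0; r < mat.Rows(); ++r) {
    T* row = mat.Row(r);
    for (size_t c = 0; c < mat.Cols(); ++c) {
      func(row[c]);  // never touches padding beyond Cols()
    }
  }
}

// E.g. the RandInit() that test_util.h no longer defines could be written as:
//   std::normal_distribution<double> dist(0.0, stddev);
//   ForEachElement(x, [&](float& v) { v = static_cast<float>(dist(gen)); });

}  // namespace gcpp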
template static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw, size_t num, CompressPerThread& tls, @@ -522,8 +523,7 @@ HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num, CompressWorkingSet& work, MatStorageT& compressed, hwy::ThreadPool& pool) { - Compress(raw, num, work, - MakeSpan(compressed.data(), compressed.NumElements()), + Compress(raw, num, work, compressed.Span(), /*packed_ofs=*/0, pool); } @@ -717,11 +717,9 @@ class Compressor { template void operator()(MatPtrT* compressed, const char* decorated_name, const float* HWY_RESTRICT weights) { - size_t num_weights = compressed->NumElements(); - if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr) - return; - size_t num_compressed = compressed->NumElements(); - PackedSpan packed = MakeSpan(compressed->data(), num_compressed); + size_t num_weights = compressed->Extents().Area(); + if (num_weights == 0 || weights == nullptr || !compressed->HasPtr()) return; + PackedSpan packed = compressed->Span(); fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name, num_weights / (1000 * 1000)); Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0, diff --git a/compression/compress.cc b/compression/compress.cc index e858e15..1818b8f 100644 --- a/compression/compress.cc +++ b/compression/compress.cc @@ -17,6 +17,6 @@ namespace gcpp { -MatPtr::~MatPtr() {} +// TODO: move ScaleWeights here. } // namespace gcpp diff --git a/compression/compress.h b/compression/compress.h index d875c4b..8844601 100644 --- a/compression/compress.h +++ b/compression/compress.h @@ -41,7 +41,8 @@ // IWYU pragma: end_exports #include "gemma/configs.h" #include "util/allocator.h" -#include "hwy/per_target.h" +#include "util/mat.h" +#include "hwy/contrib/thread_pool/thread_pool.h" #if COMPRESS_STATS #include "compression/distortion.h" #include "hwy/stats.h" @@ -49,322 +50,6 @@ namespace gcpp { -// Base class for rank-1 or 2 tensors (vector or matrix). -// Supports both dynamic and compile-time sizing. -// Holds metadata and a non-owning pointer to the data, owned by the derived -// MatStorageT class. -// This class also provides easy conversion from/to a table of contents for a -// BlobStore file, and a templated (compile-time) accessor for a 2-d array of -// fixed inner dimension and type. -// It is designed to be put in a vector, and has default copy and operator=, so -// it is easy to read/write a blob_store file. -class MatPtr : public IFields { - public: - // Full constructor for dynamic sizing. - MatPtr(const std::string& name, Type type, size_t element_size, size_t rows, - size_t cols) - : name_(name), - type_(type), - element_size_(element_size), - num_elements_(rows * cols), - rows_(rows), - cols_(cols), - ptr_(nullptr) { - stride_ = cols; - } - // Default is to leave all fields default-initialized. - MatPtr() = default; - virtual ~MatPtr(); - - // Compatibility interface for CompressedArray. - // TODO: remove. - template - T* data() { - return HWY_RCAST_ALIGNED(T*, ptr_); - } - template - const T* data() const { - return HWY_RCAST_ALIGNED(const T*, ptr_); - } - - const void* Ptr() const { return ptr_; } - void* Ptr() { return ptr_; } - // Sets the pointer from another MatPtr. - void SetPtr(const MatPtr& other) { ptr_ = other.ptr_; } - - // Copying allowed as the metadata is small. - MatPtr(const MatPtr& other) = default; - MatPtr& operator=(const MatPtr& other) = default; - - // Returns the name of the blob. 
- const char* Name() const override { return name_.c_str(); } - void SetName(const std::string& name) { name_ = name; } - - // Returns the type of the blob. - Type GetType() const { return type_; } - - // Returns the size of each element in bytes. - size_t ElementSize() const { return element_size_; } - - // Returns the number of elements in the array. - size_t NumElements() const { return num_elements_; } - - // Returns the number of bytes in the array. - size_t SizeBytes() const { - if (this->GetType() == TypeEnum()) { - return NuqStream::PackedEnd(num_elements_); - } - return num_elements_ * element_size_; - } - - // Returns the number of rows in the 2-d array (outer dimension). - size_t Rows() const { return rows_; } - - // Returns the number of columns in the 2-d array (inner dimension). - size_t Cols() const { return cols_; } - - Extents2D Extents() const { return Extents2D(rows_, cols_); } - - // Currently same as cols, but may differ in the future. This is the offset by - // which to advance pointers to the next row. - size_t Stride() const { return stride_; } - - // Decoded elements should be multiplied by this to restore their original - // range. This is required because SfpStream can only encode a limited range - // of magnitudes. - float scale() const { return scale_; } - void set_scale(float scale) { scale_ = scale; } - - std::string LayerName(int layer) const { - std::string name = name_ + std::to_string(layer); - HWY_ASSERT(name.size() <= sizeof(hwy::uint128_t)); - return name; - } - - // Sets all data to zero. - void ZeroInit() { - if (ptr_ == nullptr) - HWY_ABORT("ptr_ is null on tensor %s\n", name_.c_str()); - hwy::ZeroBytes(ptr_, SizeBytes()); - } - - void VisitFields(IFieldsVisitor& visitor) override { - visitor(name_); - visitor(type_); - visitor(element_size_); - visitor(num_elements_); - visitor(rows_); - visitor(cols_); - visitor(scale_); - visitor(stride_); - } - - // Calls func on the upcasted type. Since MatPtr by design is not templated, - // here we provide a way to get to the derived type, provided that `Type()` - // is one of the strings returned by `TypeName()`. - template - decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args); - - protected: - // Arbitrary name for the array of preferably <= 16 characters. - std::string name_; - // Should be the result of TypeEnum for CallUpcasted() to work. - Type type_; - // sizeof(T) - uint32_t element_size_ = 0; - // Number of elements in the array. - uint32_t num_elements_ = 0; // In element_size units. - // Number of rows in the 2-d array (outer dimension). - uint32_t rows_ = 0; - // Number of columns in the 2-d array (inner dimension). - uint32_t cols_ = 0; - // Scaling to apply to each element. - float scale_ = 1.0f; - // Aligned data array. This is always a borrowed pointer. It should never be - // freed. The underlying memory is owned by a subclass or some external class - // and must outlive this object. - void* ptr_ = nullptr; - - uint32_t stride_; -}; - -// MatPtrT adds a single template argument to MatPtr for an explicit type. -// Use this class as a function argument where the type needs to be known. -// Use MatPtr where the type does not need to be known. -template -class MatPtrT : public MatPtr { - public: - // Full constructor for dynamic sizing. - MatPtrT(const std::string& name, size_t rows, size_t cols) - : MatPtr(name, TypeEnum(), sizeof(MatT), rows, cols) {} - // Construction from TensorIndex entry to remove duplication of sizes. 
- MatPtrT(const std::string& name, const TensorIndex& tensor_index) - : MatPtrT(name, tensor_index.FindName(name)) {} - MatPtrT(const std::string& name, const TensorInfo* tensor) - : MatPtr(name, TypeEnum(), sizeof(MatT), 0, 0) { - if (tensor == nullptr) { - cols_ = 0; - rows_ = 0; - } else { - cols_ = tensor->shape.back(); - rows_ = 1; - if (tensor->cols_take_extra_dims) { - // The columns eat the extra dimensions. - rows_ = tensor->shape[0]; - for (size_t i = 1; i < tensor->shape.size() - 1; ++i) { - cols_ *= tensor->shape[i]; - } - } else { - // The rows eat the extra dimensions. - for (size_t i = 0; i < tensor->shape.size() - 1; ++i) { - rows_ *= tensor->shape[i]; - } - } - } - stride_ = cols_; - num_elements_ = rows_ * cols_; - } - - // Copying allowed as the metadata is small. - MatPtrT(const MatPtr& other) : MatPtr(other) {} - MatPtrT& operator=(const MatPtr& other) { - MatPtr::operator=(other); - return *this; - } - MatPtrT(const MatPtrT& other) = default; - MatPtrT& operator=(const MatPtrT& other) = default; - - std::string CacheName(int layer = -1, char separator = ' ', - int index = -1) const { - // Already used/retired: s, S, n, 1 - const char prefix = hwy::IsSame() ? 'F' - : hwy::IsSame() ? 'B' - : hwy::IsSame() ? '$' - : hwy::IsSame() ? '2' - : '?'; - std::string name = std::string(1, prefix) + name_; - if (layer >= 0 || index >= 0) { - name += '_'; - if (layer >= 0) name += std::to_string(layer); - if (index >= 0) { - name += separator + std::to_string(index); - } - } - return name; - } - - // Sets the number of elements in the array. For use when the number of - // elements is != rows * cols ONLY. - void SetNumElements(size_t num_elements) { - num_elements_ = CompressedArrayElements(num_elements); - } - - // 2-d Accessor for a specific type but with a dynamic inner dimension. - template - const T& At(size_t row, size_t col) const { - size_t index = row * cols_ + col; - HWY_DASSERT(index < num_elements_); - return HWY_RCAST_ALIGNED(const T*, ptr_)[index]; - } - - // 1-d Accessor for a specific type. - // TODO: replace this with a Foreach(), or at least a ForEachRow(). - const MatT& At(size_t index) const { - HWY_DASSERT(index < num_elements_); - return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index]; - } - MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; } - - // Compatibility interface for CompressedArray. - // TODO: remove - template - T* data() { - return HWY_RCAST_ALIGNED(T*, ptr_); - } - template - const T* data() const { - return HWY_RCAST_ALIGNED(const T*, ptr_); - } - // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so - // calling it means "I am sure the scale is 1 and therefore ignore the scale". - // A scale of 0 indicates that the scale has likely never been set, so is - // "implicitly 1". - const MatT* data_scale1() const { - HWY_ASSERT(scale() == 1.f); - return HWY_RCAST_ALIGNED(const MatT*, ptr_); - } -}; - -template -decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) { - if (type_ == TypeEnum()) { - return func(dynamic_cast*>(this), - std::forward(args)...); - } else if (type_ == TypeEnum()) { - return func(dynamic_cast*>(this), - std::forward(args)...); - } else if (type_ == TypeEnum()) { - return func(dynamic_cast*>(this), - std::forward(args)...); - } else if (type_ == TypeEnum()) { - return func(dynamic_cast*>(this), - std::forward(args)...); - } else { - HWY_ABORT("Type %d unknown.", type_); - } -} - -// MatStorageT adds the actual data storage to MatPtrT. 
-// TODO: use Extents2D instead of rows and cols. -template -class MatStorageT : public MatPtrT { - public: - // Full constructor for dynamic sizing. - MatStorageT(const std::string& name, size_t rows, size_t cols) - : MatPtrT(name, rows, cols) { - Allocate(); - } - // Can copy the metadata, from a MatPtr, and allocate later. - MatStorageT(const MatPtr& other) : MatPtrT(other) {} - ~MatStorageT() = default; - - // Move-only because this contains a unique_ptr. - MatStorageT(const MatStorageT& other) = delete; - MatStorageT& operator=(const MatStorageT& other) = delete; - MatStorageT(MatStorageT&& other) = default; - MatStorageT& operator=(MatStorageT&& other) = default; - - // Allocate the memory and copy the pointer to the MatPtr. - // num_elements is in elements. In the default (zero) case, it is computed - // from the current num_elements_ which was set by the constructor from the - // rows and cols. - void Allocate(size_t num_elements = 0) { - if (num_elements == 0) { - num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT)); - } else { - this->num_elements_ = num_elements; - } - // Pad to allow overrunning the last row by 2 BF16 vectors, hence at most - // `2 * VectorBytes / sizeof(BF16)` elements of MatT. - const size_t padding = hwy::VectorBytes(); - data_ = Allocator::Alloc(num_elements + padding); - hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT)); - this->ptr_ = data_.get(); - } - - // Zeros the content. - void ZeroInit() { - HWY_ASSERT(data_ != nullptr); - hwy::ZeroBytes(data_.get(), this->SizeBytes()); - } - - private: - AlignedPtr data_; -}; - -// MatStorage allows heterogeneous tensors to be stored in a single vector. -using MatStorage = MatStorageT; - // Table of contents for a blob store file. Full metadata, but not actual data. class BlobToc { public: @@ -389,7 +74,7 @@ class BlobToc { blob.Read(hwy::Span(toc), consumed); prev_consumed = consumed; consumed = result.pos; - if (blob.NumElements() > 0) { + if (!blob.IsEmpty()) { AddToToc(blob); } } @@ -503,10 +188,11 @@ class WriteToBlobStore { explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {} template - void operator()(MatPtrT* compressed, const char* decorated_name) { - if (compressed->Ptr() == nullptr) return; - writer_.Add(MakeKey(decorated_name), compressed->Ptr(), - compressed->SizeBytes()); + void operator()(MatPtrT* compressed, + const char* decorated_name) const { + if (!compressed->HasPtr()) return; + writer_.Add(MakeKey(decorated_name), compressed->Packed(), + compressed->PackedBytes()); MatPtr renamed_tensor(*compressed); renamed_tensor.SetName(decorated_name); renamed_tensor.AppendTo(toc_); @@ -519,9 +205,8 @@ class WriteToBlobStore { void AddScales(const float* scales, size_t len) { if (len) { - MatPtrT scales_ptr("scales", 0, 1); - writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales, - len * sizeof(scales[0])); + MatPtrT scales_ptr("scales", Extents2D(0, 1)); + writer_.Add(MakeKey(scales_ptr.Name()), scales, len * sizeof(scales[0])); } } @@ -554,9 +239,9 @@ class WriteToBlobStore { hwy::ThreadPool& pool_; private: - std::vector toc_; - BlobWriter writer_; - std::vector config_buffer_; + mutable std::vector toc_; + mutable BlobWriter writer_; + mutable std::vector config_buffer_; }; // Functor called for each tensor, which loads them and their scaling factors @@ -613,6 +298,7 @@ class ReadFromBlobStore { // Called for each tensor, enqueues read requests. 
void operator()(const char* name, hwy::Span tensors) { if (file_toc_.Empty() || file_toc_.Contains(name)) { + HWY_ASSERT(tensors[0]); model_toc_.push_back(tensors[0]); file_keys_.push_back(name); } @@ -622,15 +308,15 @@ class ReadFromBlobStore { for (size_t i = 0; i < len; ++i) { scales[i] = 1.0f; } - MatPtrT scales_ptr("scales", 0, 1); - auto key = MakeKey(scales_ptr.CacheName().c_str()); + MatPtrT scales_ptr("scales", Extents2D(0, 1)); + auto key = MakeKey(scales_ptr.Name()); if (reader_.BlobSize(key) == 0) return 0; return reader_.Enqueue(key, scales, len * sizeof(scales[0])); } // Returns whether all tensors are successfully loaded from cache. BlobError ReadAll(hwy::ThreadPool& pool, - std::vector& model_memory) { + std::vector& model_memory) { // reader_ invalid or any Enqueue failed if (err_ != 0) return err_; // Setup the model_memory. @@ -650,26 +336,27 @@ class ReadFromBlobStore { } std::string name = blob->Name(); *blob = *toc_blob; - blob->SetName(name); + blob->SetName(name.c_str()); } - model_memory.emplace_back(*blob); - model_memory.back().SetName(file_key); + model_memory.push_back(MatOwner()); } // Allocate in parallel using the pool. pool.Run(0, model_memory.size(), [this, &model_memory](uint64_t task, size_t /*thread*/) { - model_memory[task].Allocate(); - model_toc_[task]->SetPtr(model_memory[task]); + model_memory[task].AllocateFor(*model_toc_[task], + MatPadding::kPacked); }); // Enqueue the read requests. - for (auto& blob : model_memory) { - err_ = - reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes()); + for (size_t b = 0; b < model_toc_.size(); ++b) { + err_ = reader_.Enqueue(MakeKey(file_keys_[b].c_str()), + model_toc_[b]->RowT(0), + model_toc_[b]->PackedBytes()); if (err_ != 0) { - fprintf(stderr, - "Failed to read blob %s (error %d) of size %zu x %zu x %zu\n", - blob.Name(), err_, blob.Rows(), blob.Cols(), - blob.ElementSize()); + fprintf( + stderr, + "Failed to read blob %s (error %d) of size %zu x %zu, type %d\n", + file_keys_[b].c_str(), err_, model_toc_[b]->Rows(), + model_toc_[b]->Cols(), static_cast(model_toc_[b]->GetType())); return err_; } } diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc deleted file mode 100644 index cbf7e35..0000000 --- a/compression/compress_weights.cc +++ /dev/null @@ -1,286 +0,0 @@ -// Copyright 2024 Google LLC -// SPDX-License-Identifier: Apache-2.0 -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// https://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Command line tool to create compressed weights. - -// Compiles this file for multiple architectures via "foreach_target.h", to -// which we pass the filename via macro 'argument'. 
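Allocation is likewise split from metadata in this patch: a MatPtrT describes the tensor and a separate MatOwner owns the backing memory, as the ReadFromBlobStore and compression_clif_aux.cc hunks show. A hedged sketch of that pattern follows; the element type, extents, and the meaning attributed to MatPadding::kPacked are assumptions for illustration.

#include <vector>
#include "util/mat.h"  // Extents2D, MatOwner, MatPadding, MatPtrT
#include "hwy/base.h"  // HWY_ASSERT

namespace gcpp {

void ExampleAllocate(std::vector<MatOwner>& owners) {
  // Metadata only; no storage yet.
  MatPtrT<float> mat("example", Extents2D(/*rows=*/4, /*cols=*/256));

  // The owner allocates and points `mat` at the storage. kPacked is the mode
  // used throughout these hunks; it appears to store rows back-to-back
  // without per-row padding, matching the blob-store reads above.
  owners.push_back(MatOwner());
  owners.back().AllocateFor(mat, MatPadding::kPacked);

  HWY_ASSERT(mat.HasPtr());
  ZeroInit(mat);  // free function, replacing the old MatPtr::ZeroInit member
}

}  // namespace gcpp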
-#undef HWY_TARGET_INCLUDE -#define HWY_TARGET_INCLUDE \ - "compression/compress_weights.cc" // NOLINT -#include "hwy/foreach_target.h" // IWYU pragma: keep -#include "hwy/highway.h" -// After highway.h -#include "compression/compress-inl.h" -#include "gemma/configs.h" -#include "gemma/tokenizer.h" - -#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE -#define GEMMA_COMPRESS_WEIGHTS_ONCE - -#include -#include - -#include // std::clamp -#include -#include -#include -#include // NOLINT -#include - -#include "compression/compress.h" -#include "compression/io.h" // Path -#include "compression/shared.h" // PromptWrapping -#include "gemma/common.h" // Model -#include "gemma/weights.h" -#include "util/allocator.h" -#include "util/args.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" - -namespace gcpp { - -namespace { - -} // namespace - -struct Args : public ArgsBase { - static constexpr size_t kDefaultNumThreads = ~size_t{0}; - - void ChooseNumThreads() { - if (num_threads == kDefaultNumThreads) { - // This is a rough heuristic, replace with something better in the future. - num_threads = static_cast(std::clamp( - static_cast(std::thread::hardware_concurrency()) - 2, 1, 18)); - } - } - - public: - Args(int argc, char* argv[]) { - InitAndParse(argc, argv); - ChooseNumThreads(); - } - - // Returns error string or nullptr if OK. - const char* Validate() { - if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_, - prompt_wrapping_)) { - return err; - } - if (const char* err = ParseType(weight_type_str, weight_type_)) { - return err; - } - if (weights.path.empty()) { - return "Missing --weights flag, a file for the uncompressed model."; - } - if (compressed_weights.path.empty()) { - return "Missing --compressed_weights flag, a file for the compressed " - "model."; - } - if (!weights.Exists()) { - return "Can't open file specified with --weights flag."; - } - return nullptr; - } - - Path weights; // uncompressed weights file location - Path compressed_weights; // compressed weights file location - std::string model_type_str; - std::string weight_type_str; - size_t num_threads; - // If non-empty, whether to include the config and TOC in the output file, as - // well as the tokenizer. - Path tokenizer; - - template - void ForEach(const Visitor& visitor) { - visitor(weights, "weights", Path(), - "Path to model weights (.bin) file.\n" - " Required argument."); - visitor(model_type_str, "model", std::string(), - "Model type\n 2b-it = 2B parameters, instruction-tuned\n " - "2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters " - "instruction-tuned\n 7b-pt = 7B parameters, pretrained\n " - "gr2b-it = griffin 2B parameters, instruction-tuned\n " - "gr2b-pt = griffin 2B parameters, pretrained\n " - " Required argument."); - visitor(weight_type_str, "weight_type", std::string("sfp"), - "Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n" - " Required argument."); - visitor(compressed_weights, "compressed_weights", Path(), - "Path name where compressed weights (.sbs) file will be written.\n" - " Required argument."); - visitor(num_threads, "num_threads", - kDefaultNumThreads, // see ChooseNumThreads - "Number of threads to use.\n Default = Estimate of the " - "number of supported concurrent threads.", - 2); - visitor(tokenizer, "tokenizer", Path(), - "Path to tokenizer file. If given, the config and TOC are also " - "added to the output file."); - } - - // Uninitialized before Validate, must call after that. 
- gcpp::Model ModelType() const { return model_type_; } - gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; } - gcpp::Type WeightType() const { return weight_type_; } - - private: - Model model_type_; - PromptWrapping prompt_wrapping_; - Type weight_type_; -}; - -void ShowHelp(gcpp::Args& args) { - std::cerr - << "Usage:\n./compress_weights --weights " - " --model --compressed_weights \n"; - std::cerr << "\n*Arguments*\n\n"; - args.Help(); - std::cerr << "\n"; -} - -} // namespace gcpp -#endif // GEMMA_COMPRESS_WEIGHTS_ONCE - -// SIMD code, compiled once per target. -HWY_BEFORE_NAMESPACE(); -namespace gcpp { -namespace HWY_NAMESPACE { - -template -void CompressWeights(const Path& weights_path, - const Path& compressed_weights_path, Model model_type, - Type weight_type, PromptWrapping wrapping, - const Path& tokenizer_path, hwy::ThreadPool& pool) { - if (!weights_path.Exists()) { - HWY_ABORT("The model weights file '%s' does not exist.", - weights_path.path.c_str()); - } - printf("Compressing weights from %s to %s\n", weights_path.path.c_str(), - compressed_weights_path.path.c_str()); - ModelConfig config = ConfigFromModel(model_type); - config.weight = weight_type; - config.wrapping = wrapping; - std::vector model_storage; - ModelWeightsPtrs c_weights(config); - c_weights.Allocate(model_storage, pool); - ModelWeightsPtrs uc_weights(config); - uc_weights.Allocate(model_storage, pool); - // Get uncompressed weights, compress, and store. - FILE* fptr = fopen(weights_path.path.c_str(), "rb"); - if (fptr == nullptr) { - HWY_ABORT("Failed to open model file %s - does it exist?", - weights_path.path.c_str()); - } - bool ok = true; - uint64_t total_size = 0; - ModelWeightsPtrs::ForEachTensor( - {&uc_weights}, ForEachType::kLoadNoToc, - [&](const char* name, hwy::Span tensors) { - fprintf(stderr, "Loading Parameters (size %zu): %s\n", - tensors[0]->SizeBytes(), name); - ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr); - total_size += tensors[0]->SizeBytes(); - }); - if (!tokenizer_path.path.empty()) { - uc_weights.AllocAndCopyWithTranspose(pool, model_storage); - } - const bool scale_for_compression = config.num_tensor_scales > 0; - std::vector scales; - if (scale_for_compression) { - uc_weights.GetOrApplyScales(scales); - } - Compressor compressor(pool); - ModelWeightsPtrs::ForEachTensor( - {reinterpret_cast*>(&uc_weights), &c_weights}, - tokenizer_path.path.empty() ? ForEachType::kLoadNoToc - : ForEachType::kLoadWithToc, - [&compressor](const char* name, hwy::Span tensors) { - tensors[1]->CallUpcasted( - compressor, name, - reinterpret_cast(tensors[0]->Ptr())); - }); - if (!tokenizer_path.path.empty()) { - std::string tokenizer_proto = ReadFileToString(tokenizer_path); - compressor.AddTokenizer(tokenizer_proto); - } else { - compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0])); - } - compressor.WriteAll(compressed_weights_path, - tokenizer_path.path.empty() ? 
nullptr : &config); -} - -} // namespace HWY_NAMESPACE -} // namespace gcpp -HWY_AFTER_NAMESPACE(); - -#if HWY_ONCE -namespace gcpp { - -void Run(Args& args) { - hwy::ThreadPool pool(args.num_threads); - if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) { - HWY_ABORT("PaliGemma is not supported in compress_weights."); - } - const Model model_type = args.ModelType(); - const Type weight_type = args.WeightType(); - switch (weight_type) { - case Type::kF32: - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, weight_type, - args.PromptWrappingType(), args.tokenizer, pool); - break; - case Type::kBF16: - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, weight_type, - args.PromptWrappingType(), args.tokenizer, pool); - break; - case Type::kSFP: - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, weight_type, - args.PromptWrappingType(), args.tokenizer, pool); - break; - case Type::kNUQ: - HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights) - (args.weights, args.compressed_weights, model_type, weight_type, - args.PromptWrappingType(), args.tokenizer, pool); - break; - default: - HWY_ABORT("Weight type %d unsupported.", static_cast(weight_type)); - } -} - -} // namespace gcpp - -int main(int argc, char** argv) { - gcpp::Args args(argc, argv); - - if (gcpp::HasHelp(argc, argv)) { - gcpp::ShowHelp(args); - return 0; - } - - if (const char* error = args.Validate()) { - gcpp::ShowHelp(args); - HWY_ABORT("\nInvalid args: %s", error); - } - - gcpp::Run(args); - - return 0; -} - -#endif // HWY_ONCE diff --git a/compression/migrate_weights.cc b/compression/migrate_weights.cc index 97e6343..fea1ee5 100644 --- a/compression/migrate_weights.cc +++ b/compression/migrate_weights.cc @@ -57,6 +57,6 @@ int main(int argc, char** argv) { } gcpp::GemmaEnv env(argc, argv); hwy::ThreadPool pool(0); - env.GetModel()->Save(args.output_weights, pool); + env.GetGemma()->Save(args.output_weights, pool); return 0; } diff --git a/compression/python/BUILD.bazel b/compression/python/BUILD.bazel index 8bfb391..b2b376b 100644 --- a/compression/python/BUILD.bazel +++ b/compression/python/BUILD.bazel @@ -16,7 +16,9 @@ cc_library( deps = [ "@abseil-cpp//absl/types:span", "//:common", + "//:mat", "//:tokenizer", + "//:weights", "//compression:compress", "//compression:io", "@highway//:hwy", @@ -30,7 +32,6 @@ pybind_extension( deps = [ ":compression_clif_aux", "@abseil-cpp//absl/types:span", - "//:common", "//compression:sfp", ], ) diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc index 2705756..d9c2750 100644 --- a/compression/python/compression_clif_aux.cc +++ b/compression/python/compression_clif_aux.cc @@ -22,7 +22,8 @@ #include "compression/compress.h" #include "compression/shared.h" -#include "hwy/aligned_allocator.h" +#include "gemma/weights.h" +#include "util/mat.h" #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE \ @@ -81,30 +82,23 @@ class SbsWriterImpl : public WriterInterface { template void AllocateAndCompress(const std::string& name, absl::Span weights) { - MatPtrT storage(name, 1, weights.size()); - model_memory_.push_back(storage); - model_memory_.back().Allocate(); - storage.SetPtr(model_memory_.back()); - std::string decorated_name = storage.CacheName(); + MatPtrT storage(name.c_str(), Extents2D(1, weights.size())); + model_memory_.push_back(MatOwner()); + 
model_memory_.back().AllocateFor(storage, MatPadding::kPacked); + std::string decorated_name = CacheName(storage); compressor_(&storage, decorated_name.c_str(), weights.data()); } template void AllocateWithShape(const std::string& name, absl::Span weights, const TensorInfo& tensor_info, float scale) { - MatPtrT storage(name, &tensor_info); - storage.set_scale(scale); + MatPtrT storage(name.c_str(), &tensor_info); + storage.SetScale(scale); - // Don't reset num_elements for NUQ. - if (!hwy::IsSame, NuqStream>()) { - storage.SetNumElements(CompressedArrayElements(weights.size())); - } - - model_memory_.push_back(storage); + model_memory_.push_back(MatOwner()); if (mode_ == CompressorMode::kTEST_ONLY) return; - model_memory_.back().Allocate(); - storage.SetPtr(model_memory_.back()); - std::string decorated_name = storage.CacheName(); + model_memory_.back().AllocateFor(storage, MatPadding::kPacked); + std::string decorated_name = CacheName(storage); compressor_(&storage, decorated_name.c_str(), weights.data()); } @@ -176,7 +170,7 @@ class SbsWriterImpl : public WriterInterface { hwy::ThreadPool pool_; Compressor compressor_; CompressWorkingSet working_set_; - std::vector model_memory_; + std::vector model_memory_; std::vector scales_; CompressorMode mode_; }; diff --git a/compression/shared.h b/compression/shared.h index a5c87ae..8b6fb82 100644 --- a/compression/shared.h +++ b/compression/shared.h @@ -201,15 +201,24 @@ inline bool EnumValid(PromptWrapping type) { // Tensor types for loading weights. Note that not all types are supported as // weights for a model, but can be used for other purposes, such as types for -// ModelWeightsPtrs. When adding a new type that is supported, also +// `WeightsPtrs`. When adding a new type that is supported, also // update gemma.cc, weights.*, and add instantiations/new_one.cc. enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 }; -constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp", - "nuq", "f64", "c64", "u128"}; +static constexpr const char* kTypeStrings[] = { + "unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"}; +static constexpr size_t kNumTypes = + sizeof(kTypeStrings) / sizeof(kTypeStrings[0]); +static constexpr size_t kTypeBits[] = {0, + 8 * sizeof(float), + 8 * sizeof(BF16), + 8 * sizeof(SfpStream), + 4 /* NuqStream, actually 4.5 */, + 8 * sizeof(double), + 8 * sizeof(std::complex), + 8 * sizeof(hwy::uint128_t)}; -inline bool EnumValid(Type type) { - return static_cast(type) >= 0 && - static_cast(type) <= static_cast(Type::kU128); +static inline bool EnumValid(Type type) { + return static_cast(type) < kNumTypes; } // Returns a Type enum for the type of the template parameter. @@ -236,10 +245,16 @@ Type TypeEnum() { } } -// Returns a string name for the type of the template parameter. +static inline size_t TypeBits(Type type) { + return kTypeBits[static_cast(type)]; +} + +static inline const char* TypeName(Type type) { + return kTypeStrings[static_cast(type)]; +} template const char* TypeName() { - return kTypeStrings[static_cast(TypeEnum())]; + return TypeName(TypeEnum()); } template @@ -248,7 +263,9 @@ constexpr bool IsCompressed() { } // Returns the number of `MatT` elements required to store `capacity` values, -// which must not be zero. +// which must not be zero. This is only intended to support the extra tables +// required for NUQ. `capacity` includes any padding and is `rows * stride`. +// Deprecated, replaced by fixup within `MatPtr`. Only used by tests. 
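The shared.h hunk above adds per-Type metadata (kTypeBits plus the TypeBits, TypeName, and EnumValid helpers), so callers can report element sizes without switching on the enum. A small usage sketch under that assumption; PrintTypeInfo itself is illustrative only.

#include <cstdio>
#include "compression/shared.h"  // Type, kNumTypes, TypeBits, TypeName

namespace gcpp {

// Illustrative only: prints the name and storage bits of each tensor type.
void PrintTypeInfo() {
  for (size_t i = 0; i < kNumTypes; ++i) {
    const Type type = static_cast<Type>(i);
    if (!EnumValid(type)) continue;
    // Note kTypeBits for NUQ is 4, an approximation; NuqStream is ~4.5 bits.
    printf("%s: %zu bits per element\n", TypeName(type), TypeBits(type));
  }
}

}  // namespace gcpp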
template <typename MatT> constexpr size_t CompressedArrayElements(size_t capacity) { if constexpr (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) { diff --git a/compression/test_util-inl.h b/compression/test_util-inl.h index 860644a..f7a887f 100644 --- a/compression/test_util-inl.h +++ b/compression/test_util-inl.h @@ -18,10 +18,13 @@ #define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ // IWYU pragma: begin_exports -#include "compression/compress.h" #include "compression/distortion.h" +#include "util/mat.h" // IWYU pragma: end_exports +#include "compression/compress.h" +#include "hwy/contrib/thread_pool/thread_pool.h" + #endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_ // Include guard for (potentially) SIMD code. @@ -62,6 +65,52 @@ void ForeachPackedAndRawType() { ForeachRawType(); } +// Generates inputs: deterministic, within max SfpStream range. +template <typename MatT> +MatStorageT<MatT> GenerateMat(const Extents2D& extents, hwy::ThreadPool& pool) { + gcpp::CompressWorkingSet ws; + MatStorageT<float> raw("raw", extents, MatPadding::kPacked); + MatStorageT<MatT> compressed("mat", extents, MatPadding::kPacked); + const float scale = SfpStream::kMax / extents.Area(); + pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { + float* HWY_RESTRICT row = raw.Row(r); + for (size_t c = 0; c < extents.cols; c++) { + float f = static_cast<float>(r * extents.cols + c) * scale; + if ((r + c) & 1) f = -f; // Also generate some negative values. + row[c] = f; + } + }); + + Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(), + /*packed_ofs=*/0, pool); + compressed.SetScale(0.6f); // Arbitrary value, different from 1. + return compressed; +} + +// `extents` describes the transposed matrix. +template <typename MatT> +MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents, + hwy::ThreadPool& pool) { + gcpp::CompressWorkingSet ws; + MatStorageT<float> raw("raw", extents, MatPadding::kPacked); + MatStorageT<MatT> compressed("trans", extents, MatPadding::kPacked); + const float scale = SfpStream::kMax / extents.Area(); + pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { + float* HWY_RESTRICT row = raw.Row(r); + for (size_t c = 0; c < extents.cols; c++) { + float f = static_cast<float>(c * extents.rows + r) * scale; + if ((r + c) & 1) f = -f; // Also generate some negative values. + row[c] = f; + } + }); + + Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(), + /*packed_ofs=*/0, pool); + // Arbitrary value, different from 1, must match `GenerateMat`.
+ compressed.SetScale(0.6f); + return compressed; +} + // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE } // namespace gcpp diff --git a/evals/benchmark.cc b/evals/benchmark.cc index 8682189..579a64f 100644 --- a/evals/benchmark.cc +++ b/evals/benchmark.cc @@ -128,10 +128,10 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text, size_t num_tokens = std::min(prompt.size() - pos, batch_tokens); std::vector prompt_slice(prompt.begin() + pos, prompt.begin() + pos + num_tokens); - KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(), + KVCache kv_cache = KVCache::Create(env.GetGemma()->GetModelConfig(), env.MutableConfig().prefill_tbatch_size); float entropy = ComputeCrossEntropy( - *env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); + *env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity()); total_entropy += entropy; LogSpeedStats(time_start, pos + num_tokens); std::string text_slice = env.StringFromTokens(prompt_slice); @@ -186,8 +186,8 @@ int main(int argc, char** argv) { if (!benchmark_args.goldens.Empty()) { const std::string golden_path = benchmark_args.goldens.path + "/" + - gcpp::ModelString(env.GetModel()->Info().model, - env.GetModel()->Info().wrapping) + + gcpp::ModelString(env.GetGemma()->Info().model, + env.GetGemma()->Info().wrapping) + ".txt"; return BenchmarkGoldens(env, golden_path); } else if (!benchmark_args.summarize_text.Empty()) { diff --git a/evals/benchmark_helper.cc b/evals/benchmark_helper.cc index 2daebdf..82eda29 100644 --- a/evals/benchmark_helper.cc +++ b/evals/benchmark_helper.cc @@ -18,27 +18,20 @@ #include #include -#include #include -#include #include #include #include #include -// Placeholder for internal header, do not modify. -#include "compression/compress.h" // TypeName +#include "compression/shared.h" // TypeName #include "evals/cross_entropy.h" -#include "gemma/common.h" // StringFromType #include "gemma/gemma.h" -#include "gemma/kv_cache.h" -#include "util/app.h" -#include "util/args.h" -#include "util/threading.h" -#include "hwy/base.h" -#include "hwy/contrib/thread_pool/topology.h" +#include "gemma/gemma_args.h" +#include "ops/matmul.h" // MatMulEnv +#include "util/threading_context.h" #include "hwy/highway.h" -#include "hwy/per_target.h" // VectorBytes +#include "hwy/per_target.h" // DispatchedTarget #include "hwy/timer.h" namespace gcpp { @@ -54,11 +47,9 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) { } } -GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, - const AppArgs& app) - : topology_(CreateTopology(app)), - pools_(CreatePools(topology_, app)), - env_(topology_, pools_) { +GemmaEnv::GemmaEnv(const ThreadingArgs& threading_args, + const LoaderArgs& loader, const InferenceArgs& inference) + : env_(MakeMatMulEnv(threading_args)) { InferenceArgs mutable_inference = inference; AbortIfInvalidArgs(mutable_inference); LoaderArgs mutable_loader = loader; @@ -67,10 +58,10 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, fprintf(stderr, "Skipping model load because: %s\n", err); } else { fprintf(stderr, "Loading model...\n"); - model_ = AllocateGemma(mutable_loader, env_); + gemma_ = AllocateGemma(mutable_loader, env_); // Only allocate one for starters because GenerateBatch might not be called. 
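The `GemmaEnv` constructor in this hunk now takes `ThreadingArgs` in place of the old `AppArgs` and builds its `MatMulEnv` through `MakeMatMulEnv`. A minimal usage sketch, assuming the headers introduced or renamed elsewhere in this patch (the exact include paths are a guess based on the BUILD deps):

```c++
#include "evals/benchmark_helper.h"   // GemmaEnv
#include "gemma/gemma_args.h"         // LoaderArgs, InferenceArgs
#include "util/threading_context.h"   // ThreadingArgs

int main(int argc, char** argv) {
  gcpp::ThreadingArgs threading(argc, argv);
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);
  // New argument order: threading first, then loader, then inference.
  gcpp::GemmaEnv env(threading, loader, inference);
  // Or equivalently, let GemmaEnv parse all three Args itself:
  // gcpp::GemmaEnv env2(argc, argv);
  return 0;
}
```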
kv_caches_.resize(1); - kv_caches_[0] = KVCache::Create(model_->GetModelConfig(), + kv_caches_[0] = KVCache::Create(gemma_->GetModelConfig(), inference.prefill_tbatch_size); } InitGenerator(inference, gen_); @@ -78,24 +69,13 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, .max_generated_tokens = inference.max_generated_tokens, .temperature = inference.temperature, .gen = &gen_, - .verbosity = app.verbosity, + .verbosity = inference.verbosity, }; } -// Internal init must run before the GemmaEnv ctor above, hence it cannot occur -// in the argv ctor below because its body runs *after* the delegating ctor. -// This helper function takes care of the init, and could be applied to any of -// the *Args classes, it does not matter which. -static AppArgs MakeAppArgs(int argc, char** argv) { - { // So that indentation matches expectations. - // Placeholder for internal init, do not modify. - } - return AppArgs(argc, argv); -} - GemmaEnv::GemmaEnv(int argc, char** argv) - : GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv), - MakeAppArgs(argc, argv)) {} + : GemmaEnv(ThreadingArgs(argc, argv), LoaderArgs(argc, argv), + InferenceArgs(argc, argv)) {} QueryResult GemmaEnv::QueryModel(const std::vector& tokens) { QueryResult result; @@ -117,7 +97,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector& tokens) { } gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity }; runtime_config_.batch_stream_token = batch_stream_token; - model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], + gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], timing_info); return result; } @@ -127,7 +107,7 @@ void GemmaEnv::QueryModel( gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity }; const StreamFunc previous_stream_token = runtime_config_.stream_token; runtime_config_.stream_token = stream_token; - model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], + gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0], timing_info); runtime_config_.stream_token = previous_stream_token; } @@ -142,7 +122,7 @@ std::vector GemmaEnv::BatchQueryModel( int token, float) { std::string token_text; HWY_ASSERT( - model_->Tokenizer().Decode(std::vector{token}, &token_text)); + gemma_->Tokenizer().Decode(std::vector{token}, &token_text)); res[query_index].response.append(token_text); res[query_index].tokens_generated += 1; if (res[query_index].tokens_generated == @@ -164,7 +144,7 @@ std::vector GemmaEnv::BatchQueryModel( } for (size_t i = 1; i < num_queries; ++i) { if (kv_caches_[i].seq_len == 0) { - kv_caches_[i] = KVCache::Create(model_->GetModelConfig(), + kv_caches_[i] = KVCache::Create(gemma_->GetModelConfig(), runtime_config_.prefill_tbatch_size); } } @@ -172,7 +152,7 @@ std::vector GemmaEnv::BatchQueryModel( gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity}; runtime_config_.batch_stream_token = batch_stream_token; std::vector queries_pos(num_queries, 0); - model_->GenerateBatch(runtime_config_, queries_prompt, + gemma_->GenerateBatch(runtime_config_, queries_prompt, QueriesPos(queries_pos.data(), num_queries), KVCaches(&kv_caches_[0], num_queries), timing_info); return res; @@ -203,7 +183,7 @@ std::vector GemmaEnv::BatchQueryModel( float GemmaEnv::CrossEntropy(const std::string& input) { std::vector prompt = Tokenize(input); prompt.insert(prompt.begin(), BOS_ID); - return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt, + return 
ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt, MutableKVCache(), /*verbosity=*/0) / static_cast(input.size()); @@ -236,17 +216,36 @@ std::string CacheString() { return buf; } -void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, - const BoundedTopology& topology, NestedPools& pools) { - loader.Print(app.verbosity); - inference.Print(app.verbosity); - app.Print(app.verbosity); +static constexpr const char* CompiledConfig() { + if constexpr (HWY_IS_ASAN) { + return "asan"; + } else if constexpr (HWY_IS_MSAN) { + return "msan"; + } else if constexpr (HWY_IS_TSAN) { + return "tsan"; + } else if constexpr (HWY_IS_HWASAN) { + return "hwasan"; + } else if constexpr (HWY_IS_UBSAN) { + return "ubsan"; + } else if constexpr (HWY_IS_DEBUG_BUILD) { + return "dbg"; + } else { + return "opt"; + } +} - if (app.verbosity >= 2) { +void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader, + InferenceArgs& inference) { + threading.Print(inference.verbosity); + loader.Print(inference.verbosity); + inference.Print(inference.verbosity); + + if (inference.verbosity >= 2) { time_t now = time(nullptr); char* dt = ctime(&now); // NOLINT char cpu100[100] = "unknown"; (void)hwy::platform::GetCpuString(cpu100); + const ThreadingContext2& ctx = ThreadingContext2::Get(); fprintf(stderr, "Date & Time : %s" // dt includes \n @@ -254,16 +253,18 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, "CPU topology : %s, %s, %s\n" "Instruction set : %s (%zu bits)\n" "Compiled config : %s\n" - "Weight Type : %s\n" - "EmbedderInput Type : %s\n", - dt, cpu100, topology.TopologyString(), pools.PinString(), + "Memory MiB : %4zu, %4zu free\n" + "Weight Type : %s\n", + dt, cpu100, ctx.topology.TopologyString(), ctx.pools.PinString(), CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()), - hwy::VectorBytes() * 8, CompiledConfig(), - StringFromType(loader.Info().weight), TypeName()); + ctx.allocator.VectorBytes() * 8, CompiledConfig(), + ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB(), + StringFromType(loader.Info().weight)); } } -void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { +void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader, + InferenceArgs& inference) { std::cerr << "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n" "==========================================================\n\n" @@ -272,16 +273,16 @@ void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { " --tokenizer\n" " --weights\n" " --model,\n" - " or with the newer weights format, specify just:\n" + " or with the single-file weights format, specify just:\n" " --weights\n"; std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm " "--weights 2b-it-sfp.sbs --model 2b-it\n"; + std::cerr << "\n*Threading Arguments*\n\n"; + threading.Help(); std::cerr << "\n*Model Loading Arguments*\n\n"; loader.Help(); std::cerr << "\n*Inference Arguments*\n\n"; inference.Help(); - std::cerr << "\n*Application Arguments*\n\n"; - app.Help(); std::cerr << "\n"; } diff --git a/evals/benchmark_helper.h b/evals/benchmark_helper.h index 9d4773a..8aaefe1 100644 --- a/evals/benchmark_helper.h +++ b/evals/benchmark_helper.h @@ -24,9 +24,9 @@ #include #include "gemma/gemma.h" +#include "gemma/gemma_args.h" #include "ops/matmul.h" -#include "util/app.h" -#include "util/threading.h" +#include "util/threading_context.h" #include "hwy/base.h" namespace gcpp { @@ -46,8 +46,10 @@ class GemmaEnv { public: // Calls the other constructor 
with *Args arguments initialized from argv. GemmaEnv(int argc, char** argv); - GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference, - const AppArgs& app); + GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader, + const InferenceArgs& inference); + + MatMulEnv& Env() { return env_; } size_t MaxGeneratedTokens() const { return runtime_config_.max_generated_tokens; @@ -58,7 +60,7 @@ class GemmaEnv { std::vector Tokenize(const std::string& input) const { std::vector tokens; - HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens)); + HWY_ASSERT(gemma_->Tokenizer().Encode(input, &tokens)); return tokens; } @@ -69,13 +71,13 @@ class GemmaEnv { } std::vector WrapAndTokenize(std::string& input) const { - return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(), - model_->Info(), 0, input); + return gcpp::WrapAndTokenize(gemma_->Tokenizer(), gemma_->ChatTemplate(), + gemma_->Info(), 0, input); } std::string StringFromTokens(const std::vector& tokens) const { std::string string; - HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string)); + HWY_ASSERT(gemma_->Tokenizer().Decode(tokens, &string)); return string; } @@ -99,7 +101,7 @@ class GemmaEnv { float CrossEntropy(const std::string& input); // Returns nullptr if the model failed to load. - Gemma* GetModel() const { return model_.get(); } + Gemma* GetGemma() const { return gemma_.get(); } int Verbosity() const { return runtime_config_.verbosity; } RuntimeConfig& MutableConfig() { return runtime_config_; } @@ -107,11 +109,9 @@ class GemmaEnv { KVCache& MutableKVCache() { return kv_caches_[0]; } private: - BoundedTopology topology_; - NestedPools pools_; // Thread pool. MatMulEnv env_; std::mt19937 gen_; // Random number generator. - std::unique_ptr model_; + std::unique_ptr gemma_; std::vector kv_caches_; // Same number as query batch. RuntimeConfig runtime_config_; }; @@ -119,9 +119,10 @@ class GemmaEnv { // Logs the inference speed in tokens/sec. void LogSpeedStats(double time_start, size_t total_tokens); -void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app, - const BoundedTopology& topology, NestedPools& pools); -void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app); +void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader, + InferenceArgs& inference); +void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader, + InferenceArgs& inference); } // namespace gcpp diff --git a/evals/gemma_batch_bench.cc b/evals/gemma_batch_bench.cc index 44b803f..c92194c 100644 --- a/evals/gemma_batch_bench.cc +++ b/evals/gemma_batch_bench.cc @@ -51,8 +51,8 @@ class GemmaTest : public ::testing::Test { // Using the turn structure worsens results sometimes. // However, some models need the turn structure to work. // It would be good to make these tests more consistent. 
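Call sites in the tests below follow the accessor changes declared in `benchmark_helper.h` above: `GetModel()` is now `GetGemma()`, and `Env()` exposes the `MatMulEnv`. A small sketch of the new accessors; the helper function itself is illustrative:

```c++
#include <vector>

#include "evals/benchmark_helper.h"

// `env` is assumed to be constructed elsewhere, e.g. GemmaEnv env(argc, argv).
void CheckEnv(gcpp::GemmaEnv& env) {
  gcpp::Gemma* gemma = env.GetGemma();      // was GetModel(); nullptr if load failed
  if (gemma == nullptr) return;
  gcpp::MatMulEnv& matmul_env = env.Env();  // accessor added in this patch
  const std::vector<int> tokens = env.Tokenize("Hello world");
  (void)matmul_env;
  (void)tokens;
}
```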
- if (s_env->GetModel()->Info().model == Model::GEMMA2_27B || - s_env->GetModel()->Info().model == Model::GRIFFIN_2B) { + if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B || + s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) { for (QueryResult result : s_env->BatchQueryModel(inputs)) { replies.push_back(result.response); } @@ -76,7 +76,7 @@ class GemmaTest : public ::testing::Test { } void GenerateTokens(std::vector &kQA, size_t num_questions) { - ASSERT_NE(s_env->GetModel(), nullptr); + ASSERT_NE(s_env->GetGemma(), nullptr); std::vector inputs; for (size_t i = 0; i < num_questions; ++i) { diff --git a/evals/gemma_test.cc b/evals/gemma_test.cc index c73bec6..dcfffa2 100644 --- a/evals/gemma_test.cc +++ b/evals/gemma_test.cc @@ -50,8 +50,8 @@ class GemmaTest : public ::testing::Test { // Using the turn structure worsens results sometimes. // However, some models need the turn structure to work. // It would be good to make these tests more consistent. - if (s_env->GetModel()->Info().model == Model::GEMMA2_27B || - s_env->GetModel()->Info().model == Model::GRIFFIN_2B) { + if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B || + s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) { std::string mutable_prompt = prompt; QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns. return result.response; @@ -71,8 +71,8 @@ class GemmaTest : public ::testing::Test { // Using the turn structure worsens results sometimes. // However, some models need the turn structure to work. // It would be good to make these tests more consistent. - if (s_env->GetModel()->Info().model == Model::GEMMA2_27B || - s_env->GetModel()->Info().model == Model::GRIFFIN_2B) { + if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B || + s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) { for (QueryResult result : s_env->BatchQueryModel(inputs)) { replies.push_back(result.response); } @@ -96,7 +96,7 @@ class GemmaTest : public ::testing::Test { } void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) { - ASSERT_NE(s_env->GetModel(), nullptr); + HWY_ASSERT(s_env->GetGemma() != nullptr); if (batch) { std::vector inputs; for (size_t i = 0; i < num_questions; ++i) { @@ -155,8 +155,8 @@ TEST_F(GemmaTest, Arithmetic) { } TEST_F(GemmaTest, Multiturn) { - Gemma* model = s_env->GetModel(); - ASSERT_NE(model, nullptr); + Gemma* model = s_env->GetGemma(); + HWY_ASSERT(model != nullptr); size_t abs_pos = 0; std::string response; auto stream_token = [&](int token, float) { @@ -239,12 +239,12 @@ static const char kGettysburg[] = { "people, for the people, shall not perish from the earth.\n"}; TEST_F(GemmaTest, CrossEntropySmall) { - ASSERT_NE(s_env->GetModel(), nullptr); + HWY_ASSERT(s_env->GetGemma() != nullptr); static const char kSmall[] = "The capital of Hungary is Budapest which is located in Europe."; float entropy = s_env->CrossEntropy(kSmall); fprintf(stderr, "per-token entropy: %f\n", entropy); - switch (s_env->GetModel()->Info().model) { + switch (s_env->GetGemma()->Info().model) { case gcpp::Model::GEMMA_2B: // 2B v.1 and v.1.1 produce slightly different results. 
EXPECT_NEAR(entropy, 2.6f, 0.2f); @@ -272,10 +272,10 @@ TEST_F(GemmaTest, CrossEntropySmall) { } TEST_F(GemmaTest, CrossEntropyJingleBells) { - ASSERT_NE(s_env->GetModel(), nullptr); + HWY_ASSERT(s_env->GetGemma() != nullptr); float entropy = s_env->CrossEntropy(kJingleBells); fprintf(stderr, "per-token entropy: %f\n", entropy); - switch (s_env->GetModel()->Info().model) { + switch (s_env->GetGemma()->Info().model) { case gcpp::Model::GEMMA_2B: // 2B v.1 and v.1.1 produce slightly different results. EXPECT_NEAR(entropy, 1.9f, 0.2f); @@ -303,10 +303,10 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) { } TEST_F(GemmaTest, CrossEntropyGettysburg) { - ASSERT_NE(s_env->GetModel(), nullptr); + HWY_ASSERT(s_env->GetGemma() != nullptr); float entropy = s_env->CrossEntropy(kGettysburg); fprintf(stderr, "per-token entropy: %f\n", entropy); - switch (s_env->GetModel()->Info().model) { + switch (s_env->GetGemma()->Info().model) { case gcpp::Model::GEMMA_2B: // 2B v.1 and v.1.1 produce slightly different results. EXPECT_NEAR(entropy, 1.1f, 0.1f); diff --git a/evals/run_mmlu.cc b/evals/run_mmlu.cc index 77c9dcd..a266d9d 100644 --- a/evals/run_mmlu.cc +++ b/evals/run_mmlu.cc @@ -89,7 +89,7 @@ void Run(GemmaEnv& env, JsonArgs& json) { "A", "B", "C", "D", // " A", " B", " C", " D", // "**", "**:", ":**", "The", "Answer", "is", ":", "."}; - const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings); + const TokenSet accept_set(env.GetGemma()->Tokenizer(), accept_strings); for (auto sample : json_data["samples"]) { const int id = sample["i"]; @@ -131,7 +131,7 @@ void Run(GemmaEnv& env, JsonArgs& json) { .verbosity = env.Verbosity(), .stream_token = stream_token, }; - env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0, + env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0, env.MutableKVCache(), timing_info); std::string output_string = env.StringFromTokens(predicted_token_ids); diff --git a/examples/hello_world/BUILD.bazel b/examples/hello_world/BUILD.bazel index 3160103..440e824 100644 --- a/examples/hello_world/BUILD.bazel +++ b/examples/hello_world/BUILD.bazel @@ -10,13 +10,11 @@ cc_binary( name = "hello_world", srcs = ["run.cc"], deps = [ - # Placeholder for internal dep, do not remove., - "//:app", "//:args", + "//:gemma_args", "//:gemma_lib", - "//:threading", + "//:threading_context", "//:tokenizer", "@highway//:hwy", - "@highway//:thread_pool", ], ) diff --git a/examples/hello_world/README.md b/examples/hello_world/README.md index f396c05..acfaa48 100644 --- a/examples/hello_world/README.md +++ b/examples/hello_world/README.md @@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for example: ```sh -./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it +./hello_world --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it ``` Should print a greeting to the terminal: diff --git a/examples/hello_world/run.cc b/examples/hello_world/run.cc index 8f65b15..05ce222 100644 --- a/examples/hello_world/run.cc +++ b/examples/hello_world/run.cc @@ -23,23 +23,17 @@ #include #include -// Placeholder for internal header, do not modify. 
#include "gemma/gemma.h" +#include "gemma/gemma_args.h" // LoaderArgs #include "gemma/tokenizer.h" -#include "util/app.h" // LoaderArgs #include "util/args.h" -#include "util/threading.h" +#include "util/threading_context.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" int main(int argc, char** argv) { - { - // Placeholder for internal init, do not modify. - } - + gcpp::ThreadingArgs threading(argc, argv); gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); - gcpp::AppArgs app(argc, argv); if (gcpp::HasHelp(argc, argv)) { loader.Help(); return 0; @@ -53,14 +47,14 @@ int main(int argc, char** argv) { for (int arg = 0; arg < argc; ++arg) { // Find a --reject flag and consume everything after it. if (strcmp(argv[arg], "--reject") == 0) { - while (++arg < argc) reject_tokens.insert(atoi(argv[arg])); + while (++arg < argc) { + reject_tokens.insert(atoi(argv[arg])); // NOLINT + } } } // Instantiate model and KV Cache - gcpp::BoundedTopology topology(gcpp::CreateTopology(app)); - gcpp::NestedPools pools = gcpp::CreatePools(topology, app); - gcpp::MatMulEnv env(topology, pools); + gcpp::MatMulEnv env(MakeMatMulEnv(threading)); gcpp::Gemma model = gcpp::CreateGemma(loader, env); gcpp::KVCache kv_cache = gcpp::KVCache::Create(model.GetModelConfig(), diff --git a/examples/simplified_gemma/BUILD.bazel b/examples/simplified_gemma/BUILD.bazel index bedb322..2678ada 100644 --- a/examples/simplified_gemma/BUILD.bazel +++ b/examples/simplified_gemma/BUILD.bazel @@ -10,10 +10,10 @@ cc_library( name = "gemma", hdrs = ["gemma.hpp"], deps = [ - "//:app", + "//:gemma_args", "//:gemma_lib", "//:ops", - "//:threading", + "//:threading_context", "//:tokenizer", "@highway//:hwy", ], @@ -24,15 +24,6 @@ cc_binary( srcs = ["run.cc"], deps = [ ":gemma", - # Placeholder for internal dep, do not remove., - "//:app", - "//:args", - "//:common", - "//:gemma_lib", - "//:ops", - "//:threading", - "//:tokenizer", - "@highway//:hwy", - "@highway//:thread_pool", + "//:gemma_args", ], ) diff --git a/examples/simplified_gemma/README.md b/examples/simplified_gemma/README.md index d8f9394..37b4f71 100644 --- a/examples/simplified_gemma/README.md +++ b/examples/simplified_gemma/README.md @@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for example: ```sh -./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it +./simplified_gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it ``` Should print a greeting to the terminal: diff --git a/examples/simplified_gemma/gemma.hpp b/examples/simplified_gemma/gemma.hpp index 5047866..33bd9c0 100644 --- a/examples/simplified_gemma/gemma.hpp +++ b/examples/simplified_gemma/gemma.hpp @@ -24,39 +24,22 @@ #include #include "third_party/gemma_cpp/gemma/gemma.h" +#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs #include "third_party/gemma_cpp/gemma/tokenizer.h" #include "third_party/gemma_cpp/ops/matmul.h" -#include "third_party/gemma_cpp/util/app.h" // LoaderArgs -#include "third_party/gemma_cpp/util/threading.h" +#include "third_party/gemma_cpp/util/threading_context.h" #include "third_party/highway/hwy/base.h" class SimplifiedGemma { public: SimplifiedGemma(const gcpp::LoaderArgs& loader, - const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(), - const gcpp::AppArgs& app = gcpp::AppArgs()) + const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(), + const gcpp::InferenceArgs& inference = gcpp::InferenceArgs()) : 
loader_(loader), + threading_(threading), inference_(inference), - app_(app), - topology_(gcpp::CreateTopology(app_)), - pools_(gcpp::CreatePools(topology_, app_)), - env_(topology_, pools_), + env_(MakeMatMulEnv(threading_)), model_(gcpp::CreateGemma(loader_, env_)) { - Init(); - } - - SimplifiedGemma(int argc, char** argv) - : loader_(argc, argv, /*validate=*/true), - inference_(argc, argv), - app_(argc, argv), - topology_(gcpp::CreateTopology(app_)), - pools_(gcpp::CreatePools(topology_, app_)), - env_(topology_, pools_), - model_(gcpp::CreateGemma(loader_, env_)) { - Init(); - } - - void Init() { // Instantiate model and KV Cache kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(), inference_.prefill_tbatch_size); @@ -66,6 +49,11 @@ class SimplifiedGemma { gen_.seed(rd()); } + SimplifiedGemma(int argc, char** argv) + : SimplifiedGemma(gcpp::LoaderArgs(argc, argv, /*validate=*/true), + gcpp::ThreadingArgs(argc, argv), + gcpp::InferenceArgs(argc, argv)) {} + void Generate(std::string& prompt, size_t max_generated_tokens = 1024, float temperature = 0.7, const std::set& reject_tokens = {}) { @@ -107,10 +95,8 @@ class SimplifiedGemma { private: gcpp::LoaderArgs loader_; + gcpp::ThreadingArgs threading_; gcpp::InferenceArgs inference_; - gcpp::AppArgs app_; - gcpp::BoundedTopology topology_; - gcpp::NestedPools pools_; gcpp::MatMulEnv env_; gcpp::Gemma model_; gcpp::KVCache kv_cache_; diff --git a/examples/simplified_gemma/run.cc b/examples/simplified_gemma/run.cc index f73ddb5..0b7d865 100644 --- a/examples/simplified_gemma/run.cc +++ b/examples/simplified_gemma/run.cc @@ -17,15 +17,10 @@ #include -// Placeholder for internal header, do not modify. #include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp" -#include "util/app.h" // LoaderArgs +#include "gemma/gemma_args.h" // LoaderArgs int main(int argc, char** argv) { - { - // Placeholder for internal init, do not modify. - } - // Standard usage: LoaderArgs takes argc and argv as input, then parses // necessary flags. gcpp::LoaderArgs loader(argc, argv, /*validate=*/true); @@ -35,12 +30,12 @@ int main(int argc, char** argv) { // gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights", // "model_identifier"); - // Optional: InferenceArgs and AppArgs can be passed in as well. If not + // Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not // specified, default values will be used. 
// // gcpp::InferenceArgs inference(argc, argv); - // gcpp::AppArgs app(argc, argv); - // SimplifiedGemma gemma(loader, inference, app); + // gcpp::ThreadingArgs threading(argc, argv); + // SimplifiedGemma gemma(loader, threading, inference); SimplifiedGemma gemma(loader); std::string prompt = "Write a greeting to the world."; diff --git a/gemma/activations.h b/gemma/activations.h index 86345e2..89ca1f6 100644 --- a/gemma/activations.h +++ b/gemma/activations.h @@ -18,14 +18,12 @@ #include -#include "compression/shared.h" // BF16 -#include "gemma/configs.h" -#include "ops/matmul.h" // MatMulEnv -#include "ops/ops.h" // CreateInvTimescale -#include "util/allocator.h" // RowVectorBatch -#include "util/threading.h" -#include "hwy/base.h" // HWY_DASSERT -#include "hwy/contrib/thread_pool/thread_pool.h" +#include "gemma/configs.h" // ModelConfig +#include "ops/matmul.h" // MatMulEnv +#include "ops/ops.h" // CreateInvTimescale +#include "util/allocator.h" // Allocator +#include "util/basics.h" // BF16 +#include "util/mat.h" // RowVectorBatch namespace gcpp { @@ -74,6 +72,8 @@ struct Activations { size_t cache_pos_size = 0; void Allocate(size_t batch_size, MatMulEnv* env) { + const Allocator2& allocator = env->ctx.allocator; + post_qk = layer_config.post_qk; const size_t model_dim = weights_config.model_dim; const size_t ff_hidden_dim = layer_config.ff_hidden_dim; @@ -81,36 +81,45 @@ struct Activations { const size_t qkv_dim = layer_config.qkv_dim; const size_t heads = layer_config.heads; - x = RowVectorBatch(Extents2D(batch_size, model_dim)); + x = RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); q = RowVectorBatch( - Extents2D(batch_size, heads * layer_config.QStride())); + allocator, Extents2D(batch_size, heads * layer_config.QStride())); if (vocab_size > 0) { - logits = RowVectorBatch(Extents2D(batch_size, vocab_size)); + logits = + RowVectorBatch(allocator, Extents2D(batch_size, vocab_size)); } - pre_att_rms_out = RowVectorBatch(Extents2D(batch_size, model_dim)); + pre_att_rms_out = + RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); att = RowVectorBatch( - Extents2D(batch_size, heads * weights_config.seq_len)); - att_out = RowVectorBatch(Extents2D(batch_size, heads * qkv_dim)); - att_sums = RowVectorBatch(Extents2D(batch_size, model_dim)); + allocator, Extents2D(batch_size, heads * weights_config.seq_len)); + att_out = RowVectorBatch(allocator, + Extents2D(batch_size, heads * qkv_dim)); + att_sums = + RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); - bf_pre_ffw_rms_out = RowVectorBatch(Extents2D(batch_size, model_dim)); - C1 = RowVectorBatch(Extents2D(batch_size, ff_hidden_dim)); - C2 = RowVectorBatch(Extents2D(batch_size, ff_hidden_dim)); - ffw_out = RowVectorBatch(Extents2D(batch_size, model_dim)); + bf_pre_ffw_rms_out = + RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); + C1 = RowVectorBatch(allocator, Extents2D(batch_size, ff_hidden_dim)); + C2 = RowVectorBatch(allocator, Extents2D(batch_size, ff_hidden_dim)); + ffw_out = + RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) { - griffin_x = RowVectorBatch(Extents2D(batch_size, model_dim)); - griffin_y = RowVectorBatch(Extents2D(batch_size, model_dim)); - griffin_gate_x = RowVectorBatch(Extents2D(batch_size, model_dim)); + griffin_x = + RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); + griffin_y = + RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); + griffin_gate_x = + 
RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); griffin_multiplier = - RowVectorBatch(Extents2D(batch_size, model_dim)); + RowVectorBatch(allocator, Extents2D(batch_size, model_dim)); } - inv_timescale = CreateInvTimescale(layer_config.qkv_dim, + inv_timescale = CreateInvTimescale(allocator, layer_config.qkv_dim, post_qk == PostQKType::HalfRope); - inv_timescale_global = - CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0); + inv_timescale_global = CreateInvTimescale( + allocator, qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0); this->env = env; } diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h index ccb34f0..5f0c3cc 100644 --- a/gemma/gemma-inl.h +++ b/gemma/gemma-inl.h @@ -17,13 +17,13 @@ #include // sqrtf #include +#include #include #include // std::min #include #include -#include "compression/compress.h" #include "gemma/activations.h" #include "gemma/common.h" #include "gemma/configs.h" @@ -32,11 +32,9 @@ #include "gemma/weights.h" #include "paligemma/image.h" #include "util/allocator.h" -#include "util/basics.h" -#include "util/threading.h" -#include "hwy/aligned_allocator.h" -#include "hwy/base.h" -#include "hwy/bit_set.h" +#include "util/mat.h" +#include "util/threading_context.h" +#include "hwy/aligned_allocator.h" // Span #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" #include "hwy/timer.h" @@ -82,7 +80,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, const KVCaches& kv_caches) { PROFILER_ZONE("Gen.Griffin"); KVCache& kv_cache = kv_caches[0]; - hwy::ThreadPool& pool = activations.env->parallel.Pools().Pool(0); + hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; const size_t model_dim = layer_weights->layer_config.model_dim; @@ -96,8 +94,8 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, TwoMatVecAdd(layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0, model_dim, model_dim, activations.pre_att_rms_out.Batch(batch_idx), - /*add0=*/layer_weights->griffin.linear_x_biases.data_scale1(), - /*add1=*/layer_weights->griffin.linear_y_biases.data_scale1(), + /*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(), + /*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(), /*out0=*/x, /*out1=*/y, pool); Gelu(y, model_dim); } @@ -121,15 +119,15 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) { auto xv = hn::Load(df, x + i); auto accum0 = - hn::Load(df, layer_weights->griffin.conv_biases.data_scale1() + i); + hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i); auto accum1 = hn::Zero(df); HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even"); for (size_t l = 0; 2 * l < conv_1d_width; l++) { auto wv0 = - hn::Load(df, layer_weights->griffin.conv_w.data_scale1() + + hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + (conv_1d_width - 1 - 2 * l) * model_dim + i); auto wv1 = - hn::Load(df, layer_weights->griffin.conv_w.data_scale1() + + hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() + (conv_1d_width - 2 - 2 * l) * model_dim + i); accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0); accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1); @@ -156,9 +154,9 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, TwoOfsMatVecAddLoop( layer_weights->griffin.gate_w, kMatrixSize * head, kMatrixSize * 
(heads + head), kHeadDim, kHeadDim, x + head_offset, - /*add0=*/layer_weights->griffin.gate_biases.data_scale1() + + /*add0=*/layer_weights->griffin.gate_biases.PackedScale1() + head_offset, - /*add1=*/layer_weights->griffin.gate_biases.data_scale1() + + /*add1=*/layer_weights->griffin.gate_biases.PackedScale1() + model_dim + head_offset, /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset); Sigmoid(gate_x + head_offset, kHeadDim); @@ -166,7 +164,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, const auto fn_mul = [](D d, hn::Vec x, hn::Vec gate_x) HWY_ATTR { return hn::Mul(x, gate_x); }; hn::Transform1(D(), a + head_offset, kHeadDim, - layer_weights->griffin.a.data_scale1() + head_offset, + layer_weights->griffin.a.PackedScale1() + head_offset, fn_mul); hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset, fn_mul); @@ -198,7 +196,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens, float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx); float* out_ptr = activations.att_sums.Batch(batch_idx); MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x, - layer_weights->griffin.linear_out_biases.data_scale1(), out_ptr, + layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr, pool); } } @@ -253,7 +251,7 @@ class GemmaAttention { const auto pre_att_rms_out = ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out); - auto w_q1 = layer_weights_.qkv_einsum_w.data() + auto w_q1 = layer_weights_.qkv_einsum_w.HasPtr() ? ConstMatFromWeights(layer_weights_.qkv_einsum_w) : ConstMatFromWeights(layer_weights_.qkv_einsum_w1); // The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim, @@ -265,15 +263,20 @@ class GemmaAttention { const size_t w1_rows = heads * layer_config_.QStride(); w_q1.ShrinkRows(w1_rows); MatMul(pre_att_rms_out, w_q1, - /*add=*/nullptr, *activations_.env, RowPtrFromBatch(activations_.q)); + /*add=*/nullptr, *activations_.env, + RowPtrFromBatch(allocator_, activations_.q)); if (is_mha_) { // Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already. } else { - auto w_q2 = layer_weights_.qkv_einsum_w.data() - ? ConstMatFromWeights(layer_weights_.qkv_einsum_w, - w1_rows * model_dim) - : ConstMatFromWeights(layer_weights_.qkv_einsum_w2); + decltype(w_q1) w_q2; + if (layer_weights_.qkv_einsum_w.HasPtr()) { + w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w); + // Skip first half of the matrix. + w_q2.ofs = w_q2.Row(w1_rows); + } else { + w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w2); + } // KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v). 
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim; w_q2.ShrinkRows(w_rows_kv_cols); @@ -285,7 +288,7 @@ class GemmaAttention { const size_t kv_ofs = queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_; float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs; - RowPtrF kv_rows(kv, w_rows_kv_cols); + RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols); kv_rows.SetStride(cache_pos_size_); MatMul(pre_att_rms_out, w_q2, /*add=*/nullptr, *activations_.env, kv_rows); @@ -302,7 +305,7 @@ class GemmaAttention { const size_t kv_offset = cache_pos * cache_pos_size_ + layer_ * cache_layer_size_; float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset; - if (layer_weights_.qkv_einsum_w.data()) { + if (layer_weights_.qkv_einsum_w.HasPtr()) { MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim, w_rows_kv_cols, model_dim, x, kv, pool_); } else { @@ -336,8 +339,8 @@ class GemmaAttention { } // Apply further processing to K. - if (layer_weights_.key_norm_scale.data()) { - RMSNormInplace(layer_weights_.key_norm_scale.data(), kv, + if (layer_weights_.key_norm_scale.HasPtr()) { + RMSNormInplace(layer_weights_.key_norm_scale.Row(0), kv, qkv_dim); } PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f); @@ -427,8 +430,8 @@ class GemmaAttention { // Apply rope and scaling to Q. const size_t pos = queries_pos_[query_idx] + batch_idx; - if (layer_weights_.query_norm_scale.data()) { - RMSNormInplace(layer_weights_.query_norm_scale.data(), q, + if (layer_weights_.query_norm_scale.HasPtr()) { + RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q, qkv_dim); } PositionalEncodingQK(q, pos, layer_, query_scale); @@ -473,17 +476,18 @@ class GemmaAttention { HWY_DASSERT(layer_config_.model_dim > 0); HWY_DASSERT(layer_config_.heads > 0); HWY_DASSERT(layer_config_.qkv_dim > 0); - HWY_DASSERT(layer_weights_.att_weights.data() != nullptr); + HWY_DASSERT(layer_weights_.att_weights.HasPtr()); HWY_DASSERT(activations_.att_out.All() != nullptr); HWY_DASSERT(activations_.att_sums.All() != nullptr); const float* add = layer_weights_.layer_config.softmax_attn_output_biases - ? layer_weights_.attention_output_biases.data_scale1() + ? 
layer_weights_.attention_output_biases.PackedScale1() : nullptr; MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out), ConstMatFromWeights(layer_weights_.att_weights), add, - *activations_.env, RowPtrFromBatch(activations_.att_sums)); + *activations_.env, + RowPtrFromBatch(allocator_, activations_.att_sums)); } public: @@ -533,7 +537,8 @@ class GemmaAttention { layer_weights_(*layer_weights), div_seq_len_(div_seq_len), kv_caches_(kv_caches), - pool_(activations.env->parallel.Pools().Pool(0)) { + allocator_(activations.env->ctx.allocator), + pool_(activations.env->ctx.pools.Pool(0)) { HWY_DASSERT(num_queries_ <= kv_caches_.size()); HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0, "query heads must be a multiple of key-value heads"); @@ -562,6 +567,7 @@ class GemmaAttention { const LayerWeightsPtrs& layer_weights_; const hwy::Divisor& div_seq_len_; const KVCaches& kv_caches_; + const Allocator2& allocator_; hwy::ThreadPool& pool_; }; @@ -606,8 +612,8 @@ class VitAttention { HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim); MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out), ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w), - layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env, - RowPtrFromBatch(qkv)); + layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env, + RowPtrFromBatch(allocator_, qkv)); } // TODO(philculliton): transition fully to MatMul. @@ -621,10 +627,10 @@ class VitAttention { // Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents) RowVectorBatch Q = - AllocateAlignedRows(Extents2D(num_tokens_, qkv_dim)); + AllocateAlignedRows(allocator_, Extents2D(num_tokens_, qkv_dim)); RowVectorBatch K = - AllocateAlignedRows(Extents2D(seq_len, qkv_dim)); - RowVectorBatch C(Extents2D(num_tokens_, seq_len)); + AllocateAlignedRows(allocator_, Extents2D(seq_len, qkv_dim)); + RowVectorBatch C(allocator_, Extents2D(num_tokens_, seq_len)); // Initialize att_out to zero prior to head loop. hwy::ZeroBytes(activations_.att_out.All(), @@ -650,7 +656,7 @@ class VitAttention { // this produces C, a (num_tokens_, seq_len) matrix of dot products MatMul(ConstMatFromBatch(Q.BatchSize(), Q), ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env, - RowPtrFromBatch(C)); + RowPtrFromBatch(allocator_, C)); pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR { float* HWY_RESTRICT c = C.Batch(task); @@ -712,13 +718,13 @@ class VitAttention { // head_dim (`qkv_dim`) into output (`att_sums`). HWY_NOINLINE void SumHeads() { PROFILER_ZONE("Gen.VitAttention.SumHeads"); - auto* bias = layer_weights_.vit.attn_out_b.data_scale1(); + auto* bias = layer_weights_.vit.attn_out_b.PackedScale1(); // att_weights and att_out are concatenated heads, each of length // qkv_dim. Thus the [num_tokens_, layer_config_.model_dim] // matmul output is the sum over heads. 
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out); auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w); - auto att_sums = RowPtrFromBatch(activations_.att_sums); + auto att_sums = RowPtrFromBatch(allocator_, activations_.att_sums); MatMul(att_out, att_weights, bias, *activations_.env, att_sums); } @@ -730,7 +736,8 @@ class VitAttention { activations_(activations), layer_weights_(*layer_weights), layer_config_(layer_weights->layer_config), - pool_(activations.env->parallel.Pools().Pool(0)) {} + allocator_(activations.env->ctx.allocator), + pool_(activations.env->ctx.pools.Pool(0)) {} HWY_INLINE void operator()() { ComputeQKV(); @@ -748,6 +755,7 @@ class VitAttention { Activations& activations_; const LayerWeightsPtrs& layer_weights_; const LayerConfig& layer_config_; + const Allocator2& allocator_; hwy::ThreadPool& pool_; }; @@ -779,32 +787,35 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved, const bool add_bias = layer_weights->layer_config.ff_biases; const float* bias1 = - add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr; + add_bias ? layer_weights->ffw_gating_biases.PackedScale1() : nullptr; const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr; const float* output_bias = - add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr; + add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr; // Define slightly more readable names for the weights and activations. const auto x = ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out); - auto hidden_activations = RowPtrFromBatch(activations.C1); - auto multiplier = RowPtrFromBatch(activations.C2); - auto ffw_out = RowPtrFromBatch(activations.ffw_out); + const Allocator2& allocator = activations.env->ctx.allocator; + auto hidden_activations = RowPtrFromBatch(allocator, activations.C1); + auto multiplier = RowPtrFromBatch(allocator, activations.C2); + auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out); // gating_einsum_w holds two half-matrices. We plan to change the importer to // avoid this confusion by splitting into gating_einsum_w1 and // gating_einsum_w2. - const bool split = !!layer_weights->gating_einsum_w.data(); + const bool split = layer_weights->gating_einsum_w.HasPtr(); auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w) : ConstMatFromWeights(layer_weights->gating_einsum_w1); - auto w2 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w, - model_dim * ffh_hidden_dim) - : ConstMatFromWeights(layer_weights->gating_einsum_w2); + decltype(w1) w2; if (split) { + w2 = ConstMatFromWeights(layer_weights->gating_einsum_w); + w2.ofs = w2.Row(ffh_hidden_dim); // Ensure that B.Extents().row matches C.Cols() because MatMul checks that. w1.ShrinkRows(ffh_hidden_dim); w2.ShrinkRows(ffh_hidden_dim); + } else { + w2 = ConstMatFromWeights(layer_weights->gating_einsum_w2); } auto w_output = ConstMatFromWeights(layer_weights->linear_w); @@ -835,16 +846,17 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, const bool add_bias = layer_weights->layer_config.ff_biases; const float* bias1 = - add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr; + add_bias ? layer_weights->vit.linear_0_b.PackedScale1() : nullptr; const float* output_bias = - add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr; + add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr; // Define slightly more readable names for the weights and activations. 
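These hunks consistently replace the old weight-tensor accessors: `data()` null checks become `HasPtr()`, `data_scale1()` becomes `PackedScale1()`, and `set_scale`/`scale` become `SetScale`/`Scale`. A tiny sketch of the new accessors on a float tensor, with the wrapper function purely illustrative:

```c++
#include "util/mat.h"  // MatPtrT (per this patch)

// Returns the first bias element, or 0 if the tensor has no allocation.
float FirstBiasOrZero(const gcpp::MatPtrT<float>& w) {
  if (!w.HasPtr()) return 0.0f;  // previously: if (!w.data())
  return w.PackedScale1()[0];    // previously: w.data_scale1()[0]
}
```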
const auto x = ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out); - auto hidden_activations = RowPtrFromBatch(activations.C1); - auto ffw_out = RowPtrFromBatch(activations.ffw_out); + const Allocator2& allocator = activations.env->ctx.allocator; + auto hidden_activations = RowPtrFromBatch(allocator, activations.C1); + auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out); auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w); auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w); @@ -853,7 +865,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved, MatMul(x, w1, bias1, *activations.env, hidden_activations); // Activation (Gelu), store in act. - RowPtrF multiplier = RowPtrF(nullptr, 0); + RowPtrF multiplier = RowPtrF(allocator, nullptr, 0); Activation(layer_weights->layer_config.activation, hidden_activations.Row(0), multiplier.Row(0), ff_hidden_dim * num_interleaved); @@ -905,11 +917,9 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos, HWY_DASSERT(token < static_cast(vocab_size)); const hn::ScalableTag df; - DecompressAndZeroPad( - df, - MakeSpan(weights.embedder_input_embedding.data(), vocab_size * model_dim), - token * model_dim, x.Batch(batch_idx), model_dim); - MulByConst(emb_scaling * weights.embedder_input_embedding.scale(), + DecompressAndZeroPad(df, weights.embedder_input_embedding.Span(), + token * model_dim, x.Batch(batch_idx), model_dim); + MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(), x.Batch(batch_idx), model_dim); if (weights.weights_config.absolute_pe) { AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos); @@ -943,9 +953,10 @@ HWY_NOINLINE void ResidualConnection( template void PostNorm(PostNormType post_norm, size_t num_interleaved, const WeightT& weights, InOutT* inout) { + HWY_DASSERT(weights.Rows() == 1); if (post_norm == PostNormType::Scale) { - RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout, - weights.NumElements()); + RMSNormInplaceBatched(num_interleaved, weights.PackedScale1(), inout, + weights.Cols()); } } @@ -962,7 +973,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos, auto type = layer_weights->layer_config.type; RMSNormBatched(num_interleaved, activations.x.All(), - layer_weights->pre_attention_norm_scale.data_scale1(), + layer_weights->pre_attention_norm_scale.PackedScale1(), activations.pre_att_rms_out.All(), model_dim); Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx, @@ -976,7 +987,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos, activations.x.All(), layer_weights, /*is_attention=*/true); RMSNormBatched(num_interleaved, activations.x.All(), - layer_weights->pre_ffw_norm_scale.data_scale1(), + layer_weights->pre_ffw_norm_scale.PackedScale1(), activations.bf_pre_ffw_rms_out.All(), model_dim); if (layer_weights->layer_config.type == LayerAttentionType::kVit) { @@ -1014,8 +1025,8 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer, // y = nn.LayerNorm()(x) // y ~ pre_att_rms_out LayerNormBatched(num_tokens, x.All(), - layer_weights->vit.layer_norm_0_scale.data_scale1(), - layer_weights->vit.layer_norm_0_bias.data_scale1(), + layer_weights->vit.layer_norm_0_scale.PackedScale1(), + layer_weights->vit.layer_norm_0_bias.PackedScale1(), activations.pre_att_rms_out.All(), model_dim); // y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y) @@ -1028,8 +1039,8 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, 
size_t layer, // y = nn.LayerNorm()(x) // y ~ bf_pre_ffw_rms_out LayerNormBatched(num_tokens, x.All(), - layer_weights->vit.layer_norm_1_scale.data_scale1(), - layer_weights->vit.layer_norm_1_bias.data_scale1(), + layer_weights->vit.layer_norm_1_scale.PackedScale1(), + layer_weights->vit.layer_norm_1_bias.PackedScale1(), activations.bf_pre_ffw_rms_out.All(), model_dim); // y = out["mlp"] = MlpBlock(...)(y) @@ -1161,8 +1172,8 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, const size_t patch_width = weights.weights_config.vit_config.patch_width; const size_t seq_len = weights.weights_config.vit_config.seq_len; const size_t patch_size = patch_width * patch_width * 3; - HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() == - patch_size * model_dim); + HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim); + HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size); HWY_DASSERT(activations.x.Cols() == model_dim); std::vector> image_patches(seq_len); for (size_t i = 0; i < seq_len; ++i) { @@ -1178,20 +1189,20 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image, // MatMul( // MatFromBatch(kVitSeqLen, image_patches), // MatFromWeights(weights.vit_img_embedding_kernel), - // weights.vit_img_embedding_bias.data_scale1(), *activations.env, + // weights.vit_img_embedding_bias.PackedScale1(), *activations.env, // RowPtrF(activations.x.All(), kVitModelDim)); // However, MatMul currently requires that // A.cols % (2 * hn::Lanes(hn::ScalableTag())) == 0 // which is not the case here. We should relax that requirement on MatMul and // then use the above. For now, we rely on MatVecAdd instead. for (size_t i = 0; i < seq_len; ++i) { - MatVecAdd( - weights.vit_img_embedding_kernel, 0, model_dim, patch_size, - image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(), - activations.x.Batch(i), activations.env->parallel.Pools().Pool(0)); + MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size, + image_patches[i].get(), + weights.vit_img_embedding_bias.PackedScale1(), + activations.x.Batch(i), activations.env->ctx.pools.Pool(0)); } // Add position embeddings. - AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(), + AddFrom(weights.vit_img_pos_embedding.PackedScale1(), activations.x.All(), seq_len * model_dim); } @@ -1216,23 +1227,23 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs& weights, } // Final Layernorm. LayerNormBatched(num_tokens, activations.x.All(), - weights.vit_encoder_norm_scale.data_scale1(), - weights.vit_encoder_norm_bias.data_scale1(), + weights.vit_encoder_norm_scale.PackedScale1(), + weights.vit_encoder_norm_bias.PackedScale1(), activations.x.All(), vit_model_dim); if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) { activations.x = AvgPool4x4(activations.x); // Apply soft embedding norm before input projection. - RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(), + RMSNormInplace(weights.mm_embed_norm.PackedScale1(), activations.x.All(), vit_model_dim); } // Apply head embedding into image_tokens of size of the LLM kModelDim. MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x), ConstMatFromWeights(weights.vit_img_head_kernel), - weights.vit_img_head_bias.data_scale1(), *activations.env, - RowPtrFromBatch(image_tokens)); + weights.vit_img_head_bias.PackedScale1(), *activations.env, + RowPtrFromBatch(activations.env->ctx.allocator, image_tokens)); } // Generates one token for each query. 
`queries_token` is the previous token @@ -1274,7 +1285,7 @@ HWY_NOINLINE void Transformer( } } - RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(), + RMSNormInplaceBatched(num_queries, weights.final_norm_scale.PackedScale1(), activations.x.All(), model_dim); if (activations_observer) { @@ -1374,7 +1385,7 @@ bool DecodeStepT(const ModelWeightsPtrs& weights, MatMul(ConstMatFromBatch(num_queries, activations.x), ConstMatFromWeights(weights.embedder_input_embedding), /*add=*/nullptr, *activations.env, - RowPtrFromBatch(activations.logits)); + RowPtrFromBatch(activations.env->ctx.allocator, activations.logits)); } PROFILER_ZONE("Gen.Softcap+Sample+Stream"); for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) { diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 658ff66..51cf5f4 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -27,22 +27,33 @@ #include // std::move #include -#include "compression/io.h" // Path +// Placeholder for internal header, do not modify. #include "compression/shared.h" #include "gemma/common.h" +#include "gemma/configs.h" +#include "gemma/tokenizer.h" #include "gemma/weights.h" -#include "ops/ops-inl.h" +#include "ops/matmul.h" #include "paligemma/image.h" -#include "util/threading.h" -#include "hwy/highway.h" +#include "util/threading_context.h" +#include "hwy/base.h" namespace gcpp { +// Internal init must run before I/O; calling it from `GemmaEnv()` is too late. +// This helper function takes care of the internal init plus calling `SetArgs`. +MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) { + // Placeholder for internal init, do not modify. + + ThreadingContext2::SetArgs(threading_args); + return MatMulEnv(ThreadingContext2::Get()); +} + Gemma::Gemma(const Path& tokenizer_path, const Path& weights, const ModelInfo& info, MatMulEnv& env) : env_(env), tokenizer_(tokenizer_path) { model_.Load(weights, info.model, info.weight, info.wrapping, - env_.parallel.Pools().Pool(0), + env_.ctx.pools.Pool(0), /*tokenizer_proto=*/nullptr); chat_template_.Init(tokenizer_, model_.Config().model); } @@ -50,7 +61,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) { std::string tokenizer_proto; model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT, - env_.parallel.Pools().Pool(0), &tokenizer_proto); + env_.ctx.pools.Pool(0), &tokenizer_proto); tokenizer_.Deserialize(tokenizer_proto); chat_template_.Init(tokenizer_, model_.Config().model); } @@ -60,7 +71,7 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env) tokenizer_(std::move(tokenizer)), chat_template_(tokenizer_, info.model) { HWY_ASSERT(info.weight == Type::kF32); - model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0)); + model_.Allocate(info.model, info.weight, env_.ctx.pools.Pool(0)); } Gemma::~Gemma() { @@ -130,12 +141,12 @@ struct GenerateImageTokensT { void Gemma::Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt, size_t pos, size_t prefix_end, KVCache& kv_cache, TimingInfo& timing_info) { - env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); + env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight( runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info); - env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); + env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateBatch(const RuntimeConfig& 
runtime_config, @@ -152,23 +163,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config, QueriesPos(prefix_end_vec.data(), prefix_end_vec.size()); } - env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); + env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight( runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end, kv_caches, &env_, timing_info); - env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); + env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config, const Image& image, ImageTokens& image_tokens) { - env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning); + env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning); model_.CallForModelWeight(runtime_config, image, image_tokens, &env_); - env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning); + env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning); } // Non-template functions moved from gemma-inl.h to avoid ODR violations. diff --git a/gemma/gemma.h b/gemma/gemma.h index de0cba1..77cdf58 100644 --- a/gemma/gemma.h +++ b/gemma/gemma.h @@ -16,6 +16,8 @@ #ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ #define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_ +#include + #include #include #include @@ -31,8 +33,9 @@ #include "gemma/weights.h" #include "ops/matmul.h" // MatMulEnv #include "paligemma/image.h" -#include "util/allocator.h" // RowVectorBatch -#include "util/basics.h" // TokenAndProb +#include "util/basics.h" // TokenAndProb +#include "util/mat.h" // RowVectorBatch +#include "util/threading_context.h" #include "hwy/timer.h" // IWYU pragma: end_exports #include "hwy/aligned_allocator.h" // Span @@ -193,6 +196,10 @@ struct TimingInfo { size_t tokens_generated = 0; }; +// Internal init must run before I/O; calling it from GemmaEnv() is too late. +// This helper function takes care of the internal init plus calling `SetArgs`. +MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args); + class Gemma { public: // Reads old format weights file and tokenizer file. @@ -206,7 +213,9 @@ class Gemma { Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env); ~Gemma(); + MatMulEnv& Env() const { return env_; } const ModelConfig& GetModelConfig() const { return model_.Config(); } + // DEPRECATED ModelInfo Info() const { return ModelInfo({.model = model_.Config().model, .wrapping = model_.Config().wrapping, diff --git a/util/app.h b/gemma/gemma_args.h similarity index 71% rename from util/app.h rename to gemma/gemma_args.h index a66dd3d..4fe2d33 100644 --- a/util/app.h +++ b/gemma/gemma_args.h @@ -15,8 +15,8 @@ // Shared between various frontends. 
-#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ -#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ +#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ +#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ #include #include @@ -31,103 +31,10 @@ #include "ops/matmul.h" #include "util/args.h" #include "util/basics.h" // Tristate -#include "util/threading.h" -#include "hwy/base.h" // HWY_IS_ASAN +#include "hwy/base.h" // HWY_ABORT namespace gcpp { -static inline const char* CompiledConfig() { - if (HWY_IS_ASAN) { - return "asan"; - } else if (HWY_IS_MSAN) { - return "msan"; - } else if (HWY_IS_TSAN) { - return "tsan"; - } else if (HWY_IS_HWASAN) { - return "hwasan"; - } else if (HWY_IS_UBSAN) { - return "ubsan"; - } else if (HWY_IS_DEBUG_BUILD) { - return "dbg"; - } else { - return "opt"; - } -} - -class AppArgs : public ArgsBase { - public: - AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } - AppArgs() { Init(); }; - - int verbosity; - - size_t max_threads; // divided among the detected clusters - Tristate pin; // pin threads? - Tristate spin; // use spin waits? - - // For BoundedSlice: - size_t skip_packages; - size_t max_packages; - size_t skip_clusters; - size_t max_clusters; - size_t skip_lps; - size_t max_lps; - - std::string eot_line; - - template - void ForEach(const Visitor& visitor) { - visitor(verbosity, "verbosity", 1, - "Show verbose developer information\n 0 = only print generation " - "output\n 1 = standard user-facing terminal ui\n 2 = show " - "developer/debug info).\n Default = 1.", - 2); - - // The exact meaning is more subtle: see the comment at NestedPools ctor. - visitor(max_threads, "num_threads", size_t{0}, - "Maximum number of threads to use; default 0 = unlimited.", 2); - visitor(pin, "pin", Tristate::kDefault, - "Pin threads? -1 = auto, 0 = no, 1 = yes.", 2); - visitor(spin, "spin", Tristate::kDefault, - "Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2); - // These can be used to partition CPU sockets/packages and their - // clusters/CCXs across several program instances. The default is to use - // all available resources. - visitor(skip_packages, "skip_packages", size_t{0}, - "Index of the first socket to use; default 0 = unlimited.", 2); - visitor(max_packages, "max_packages", size_t{0}, - "Maximum number of sockets to use; default 0 = unlimited.", 2); - visitor(skip_clusters, "skip_clusters", size_t{0}, - "Index of the first CCX to use; default 0 = unlimited.", 2); - visitor(max_clusters, "max_clusters", size_t{0}, - "Maximum number of CCXs to use; default 0 = unlimited.", 2); - // These are only used when CPU topology is unknown. - visitor(skip_lps, "skip_lps", size_t{0}, - "Index of the first LP to use; default 0 = unlimited.", 2); - visitor(max_lps, "max_lps", size_t{0}, - "Maximum number of LPs to use; default 0 = unlimited.", 2); - - visitor( - eot_line, "eot_line", std::string(""), - "End of turn line. 
" - "When you specify this, the prompt will be all lines " - "before the line where only the given string appears.\n Default = " - "When a newline is encountered, that signals the end of the turn.", - 2); - } -}; - -static inline BoundedTopology CreateTopology(const AppArgs& app) { - return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages), - BoundedSlice(app.skip_clusters, app.max_clusters), - BoundedSlice(app.skip_lps, app.max_lps)); -} -static inline NestedPools CreatePools(const BoundedTopology& topology, - const AppArgs& app) { - Allocator::Init(topology); - return NestedPools(topology, app.max_threads, app.pin); -} - struct LoaderArgs : public ArgsBase { LoaderArgs(int argc, char* argv[], bool validate = true) { InitAndParse(argc, argv); @@ -154,15 +61,6 @@ struct LoaderArgs : public ArgsBase { // Returns error string or nullptr if OK. const char* Validate() { - if (!compressed_weights.path.empty()) { - if (weights.path.empty()) { - weights = compressed_weights; - } else { - return "Only one of --weights and --compressed_weights can be " - "specified. To create compressed weights use the " - "compress_weights tool."; - } - } if (weights.path.empty()) { return "Missing --weights flag, a file for the model weights."; } @@ -250,6 +148,8 @@ struct InferenceArgs : public ArgsBase { InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); } InferenceArgs() { Init(); }; + int verbosity; + size_t max_generated_tokens; size_t prefill_tbatch_size; @@ -261,6 +161,8 @@ struct InferenceArgs : public ArgsBase { bool multiturn; Path image_file; + std::string eot_line; + // Returns error string or nullptr if OK. const char* Validate() const { if (max_generated_tokens > gcpp::kSeqLen) { @@ -272,6 +174,12 @@ struct InferenceArgs : public ArgsBase { template void ForEach(const Visitor& visitor) { + visitor(verbosity, "verbosity", 1, + "Show verbose developer information\n 0 = only print generation " + "output\n 1 = standard user-facing terminal ui\n 2 = show " + "developer/debug info).\n Default = 1.", + 2); + visitor(max_generated_tokens, "max_generated_tokens", size_t{2048}, "Maximum number of tokens to generate."); @@ -291,6 +199,14 @@ struct InferenceArgs : public ArgsBase { " Default : 0 (conversation " "resets every turn)"); visitor(image_file, "image_file", Path(), "Image file to load."); + + visitor( + eot_line, "eot_line", std::string(""), + "End of turn line. " + "When you specify this, the prompt will be all lines " + "before the line where only the given string appears.\n Default = " + "When a newline is encountered, that signals the end of the turn.", + 2); } void CopyTo(RuntimeConfig& runtime_config) const { @@ -317,4 +233,4 @@ struct InferenceArgs : public ArgsBase { } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_ +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_ diff --git a/gemma/run.cc b/gemma/run.cc index a437ae0..5170b6e 100644 --- a/gemma/run.cc +++ b/gemma/run.cc @@ -15,23 +15,23 @@ // Command line text interface to gemma. +#include + #include #include #include #include #include -// Placeholder for internal header, do not modify. 
#include "compression/shared.h" // PromptWrapping #include "evals/benchmark_helper.h" #include "gemma/common.h" #include "gemma/gemma.h" // Gemma -#include "ops/matmul.h" // MatMulEnv +#include "gemma/gemma_args.h" // LoaderArgs +#include "ops/matmul.h" // MatMulEnv #include "paligemma/image.h" -#include "util/app.h" #include "util/args.h" // HasHelp -#include "util/threading.h" -#include "hwy/base.h" +#include "util/threading_context.h" #include "hwy/highway.h" #include "hwy/profiler.h" @@ -78,35 +78,37 @@ std::string GetPrompt(std::istream& input, int verbosity, } // The main Read-Eval-Print Loop. -void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, - const InferenceArgs& args, const AcceptFunc& accept_token, - std::string& eot_line) { +void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference, + Gemma& model, KVCache& kv_cache) { PROFILER_ZONE("Gen.misc"); size_t abs_pos = 0; // across turns size_t tokens_generated_this_turn = 0; // differentiates prefill from reply size_t prompt_size = 0; std::mt19937 gen; - InitGenerator(args, gen); + InitGenerator(inference, gen); - const bool have_image = !args.image_file.path.empty(); + const bool have_image = !inference.image_file.path.empty(); Image image; ImageTokens image_tokens; if (have_image) { size_t pool_dim = model.GetModelConfig().vit_config.pool_dim; - image_tokens = ImageTokens(Extents2D( - model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim), - model.GetModelConfig().model_dim)); + image_tokens = + ImageTokens(model.Env().ctx.allocator, + Extents2D(model.GetModelConfig().vit_config.seq_len / + (pool_dim * pool_dim), + model.GetModelConfig().model_dim)); HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA || model.Info().wrapping == PromptWrapping::GEMMA_VLM); - HWY_ASSERT(image.ReadPPM(args.image_file.path)); + HWY_ASSERT(image.ReadPPM(inference.image_file.path)); const size_t image_size = model.GetModelConfig().vit_config.image_size; image.Resize(image_size, image_size); - RuntimeConfig runtime_config = { - .gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin}; + RuntimeConfig runtime_config = {.gen = &gen, + .verbosity = inference.verbosity, + .use_spinning = threading.spin}; double image_tokens_start = hwy::platform::Now(); model.GenerateImageTokens(runtime_config, image, image_tokens); - if (app.verbosity >= 1) { + if (inference.verbosity >= 1) { double image_tokens_duration = hwy::platform::Now() - image_tokens_start; fprintf(stderr, "\n\n[ Timing info ] Image token generation took: %d ms\n", @@ -121,12 +123,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, const bool first_response_token = tokens_generated_this_turn == prompt_size; ++tokens_generated_this_turn; if (in_prompt) { - if (app.verbosity >= 1) { + if (inference.verbosity >= 1) { std::cerr << "." << std::flush; } return true; } else if (model.GetModelConfig().IsEOS(token)) { - if (app.verbosity >= 2) { + if (inference.verbosity >= 2) { std::cout << "\n[ End ]\n"; } return true; @@ -135,7 +137,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, HWY_ASSERT(model.Tokenizer().Decode(std::vector{token}, &token_text)); if (first_response_token) { token_text.erase(0, token_text.find_first_not_of(" \t\n")); - if (app.verbosity >= 1) { + if (inference.verbosity >= 1) { std::cout << "\n\n"; } } @@ -147,7 +149,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, tokens_generated_this_turn = 0; // Read prompt and handle special commands. 
- std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line); + std::string prompt_string = + GetPrompt(std::cin, inference.verbosity, inference.eot_line); if (!std::cin) return; // If !eot_line.empty(), we append \n, so only look at the first 2 chars. if (prompt_string.size() >= 2 && prompt_string[0] == '%') { @@ -163,13 +166,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, } // Set up runtime config. - TimingInfo timing_info = {.verbosity = app.verbosity}; + TimingInfo timing_info = {.verbosity = inference.verbosity}; RuntimeConfig runtime_config = {.gen = &gen, - .verbosity = app.verbosity, + .verbosity = inference.verbosity, .stream_token = stream_token, - .accept_token = accept_token, - .use_spinning = app.spin}; - args.CopyTo(runtime_config); + .use_spinning = threading.spin}; + inference.CopyTo(runtime_config); size_t prefix_end = 0; std::vector prompt; @@ -197,7 +199,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, } // Generate until EOS or max_generated_tokens. - if (app.verbosity >= 1) { + if (inference.verbosity >= 1) { std::cerr << "\n[ Reading prompt ] " << std::flush; } model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache, @@ -205,9 +207,10 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, std::cout << "\n\n"; // Prepare for the next turn. Works only for PaliGemma. - if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) { + if (!inference.multiturn || + model.Info().wrapping == PromptWrapping::PALIGEMMA) { abs_pos = 0; // Start a new turn at position 0. - InitGenerator(args, gen); + InitGenerator(inference, gen); } else { // The last token was either EOS, then it should be ignored because it is // never part of the dialog, see Table 5 in the Gemma-2 paper: @@ -223,20 +226,19 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app, } } -void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { +void Run(ThreadingArgs& threading, LoaderArgs& loader, + InferenceArgs& inference) { PROFILER_ZONE("Run.misc"); // Note that num_threads is an upper bound; we also limit to the number of // detected and enabled cores. - const BoundedTopology topology = CreateTopology(app); - NestedPools pools = CreatePools(topology, app); - MatMulEnv env(topology, pools); - if (app.verbosity >= 2) env.print_best = true; + MatMulEnv env(MakeMatMulEnv(threading)); + if (inference.verbosity >= 2) env.print_best = true; Gemma model = CreateGemma(loader, env); KVCache kv_cache = KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size); - if (app.verbosity >= 1) { + if (inference.verbosity >= 1) { std::string instructions = "*Usage*\n" " Enter an instruction and press enter (%C resets conversation, " @@ -259,11 +261,11 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) { std::cout << "\033[2J\033[1;1H" // clear screen << kAsciiArtBanner << "\n\n"; - ShowConfig(loader, inference, app, topology, pools); + ShowConfig(threading, loader, inference); std::cout << "\n" << instructions << "\n"; } - ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line); + ReplGemma(threading, inference, model, kv_cache); } } // namespace gcpp @@ -272,31 +274,29 @@ int main(int argc, char** argv) { { PROFILER_ZONE("Startup.misc"); - // Placeholder for internal init, do not modify. 
- + gcpp::ThreadingArgs threading(argc, argv); gcpp::LoaderArgs loader(argc, argv); gcpp::InferenceArgs inference(argc, argv); - gcpp::AppArgs app(argc, argv); if (gcpp::HasHelp(argc, argv)) { std::cerr << gcpp::kAsciiArtBanner; - gcpp::ShowHelp(loader, inference, app); + gcpp::ShowHelp(threading, loader, inference); return 0; } if (const char* error = loader.Validate()) { std::cerr << gcpp::kAsciiArtBanner; - gcpp::ShowHelp(loader, inference, app); + gcpp::ShowHelp(threading, loader, inference); HWY_ABORT("\nInvalid args: %s", error); } if (const char* error = inference.Validate()) { std::cerr << gcpp::kAsciiArtBanner; - gcpp::ShowHelp(loader, inference, app); + gcpp::ShowHelp(threading, loader, inference); HWY_ABORT("\nInvalid args: %s", error); } - gcpp::Run(loader, inference, app); + gcpp::Run(threading, loader, inference); } PROFILER_PRINT_RESULTS(); // Must call outside the zone above. return 0; diff --git a/gemma/tensor_index.cc b/gemma/tensor_index.cc index 4308c9d..db37218 100644 --- a/gemma/tensor_index.cc +++ b/gemma/tensor_index.cc @@ -562,11 +562,12 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx, if (llm_layer_idx < 0 && img_layer_idx < 0) { tensors_ = ModelTensors(config); } else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx && - img_layer_idx < config.vit_config.layer_configs.size()) { + img_layer_idx < + static_cast(config.vit_config.layer_configs.size())) { const auto& layer_config = config.vit_config.layer_configs[img_layer_idx]; tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx); } else if (0 <= llm_layer_idx && - llm_layer_idx < config.layer_configs.size()) { + llm_layer_idx < static_cast(config.layer_configs.size())) { const auto& layer_config = config.layer_configs[llm_layer_idx]; tensors_ = LLMLayerTensors(config, layer_config, reshape_att); } diff --git a/gemma/tensor_index.h b/gemma/tensor_index.h index dc6b86c..a1da249 100644 --- a/gemma/tensor_index.h +++ b/gemma/tensor_index.h @@ -54,6 +54,28 @@ struct TensorInfo { bool cols_take_extra_dims = false; }; +// Collapses/expands the tensor dims into 2D extents, which may be 0, 0 for +// not-present tensors such as ViT in a text-only model. +static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) { + if (tensor == nullptr) return Extents2D(0, 0); + + size_t cols = tensor->shape.back(); + size_t rows = 1; + if (tensor->cols_take_extra_dims) { + rows = tensor->shape[0]; + for (size_t i = 1; i < tensor->shape.size() - 1; ++i) { + cols *= tensor->shape[i]; + } + } else { // rows take extra dims + for (size_t i = 0; i < tensor->shape.size() - 1; ++i) { + rows *= tensor->shape[i]; + } + } + // Sometimes only one of rows or cols is zero; set both for consistency. + if (rows == 0 || cols == 0) rows = cols = 0; + return Extents2D(rows, cols); +} + // Universal index of tensor information, which can be built for a specific // layer_idx. 
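The ExtentsFromInfo helper added above collapses an N-dimensional tensor shape into 2D extents. A self-contained mirror of that logic with a pared-down TensorInfo (only the two fields the function reads), plus a worked example:

#include <cstddef>
#include <vector>

struct MiniTensorInfo {  // simplified stand-in for TensorInfo
  std::vector<size_t> shape;
  bool cols_take_extra_dims = false;
};
struct MiniExtents { size_t rows, cols; };

MiniExtents ExtentsFromInfoSketch(const MiniTensorInfo* tensor) {
  if (tensor == nullptr) return {0, 0};
  size_t cols = tensor->shape.back();
  size_t rows = 1;
  if (tensor->cols_take_extra_dims) {  // rows = leading dim, cols absorb the rest
    rows = tensor->shape[0];
    for (size_t i = 1; i + 1 < tensor->shape.size(); ++i) cols *= tensor->shape[i];
  } else {  // rows absorb all leading dims
    for (size_t i = 0; i + 1 < tensor->shape.size(); ++i) rows *= tensor->shape[i];
  }
  if (rows == 0 || cols == 0) rows = cols = 0;  // both zero for consistency
  return {rows, cols};
}

// Example: shape {heads=8, model_dim=2048, qkv_dim=256} with
// cols_take_extra_dims == false yields rows = 8 * 2048 = 16384, cols = 256.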
class TensorIndex { @@ -96,6 +118,16 @@ class TensorIndex { std::unordered_map name_map_; }; +static inline TensorIndex TensorIndexLLM(const ModelConfig& config, + size_t llm_layer_idx) { + return TensorIndex(config, static_cast(llm_layer_idx), -1, false); +} + +static inline TensorIndex TensorIndexImg(const ModelConfig& config, + size_t img_layer_idx) { + return TensorIndex(config, -1, static_cast(img_layer_idx), false); +} + } // namespace gcpp #endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_ diff --git a/gemma/weights.cc b/gemma/weights.cc index d281391..bef76ae 100644 --- a/gemma/weights.cc +++ b/gemma/weights.cc @@ -29,6 +29,7 @@ #include "compression/shared.h" #include "gemma/common.h" #include "gemma/configs.h" +#include "util/mat.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" // HWY_ABORT #include "hwy/contrib/thread_pool/thread_pool.h" @@ -118,7 +119,7 @@ struct TensorSaver { weights.ForEachTensor( {&weights}, fet, [&writer](const char* name, hwy::Span tensors) { - tensors[0]->CallUpcasted(writer, name); + CallUpcasted(tensors[0]->GetType(), tensors[0], writer, name); }); } }; @@ -155,11 +156,11 @@ class WeightInitializer { WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {} void operator()(const char* name, hwy::Span tensors) { - float* data = tensors[0]->data(); - for (size_t i = 0; i < tensors[0]->NumElements(); ++i) { + float* data = tensors[0]->RowT(0); + for (size_t i = 0; i < tensors[0]->Extents().Area(); ++i) { data[i] = dist_(gen_); } - tensors[0]->set_scale(1.0f); + tensors[0]->SetScale(1.0f); } private: @@ -226,11 +227,11 @@ void ModelWeightsStorage::LogWeightStats() { {float_weights_.get()}, ForEachType::kInitNoToc, [&total_weights](const char* name, hwy::Span tensors) { const MatPtr& tensor = *tensors[0]; - if (tensor.scale() != 1.0f) { - printf("[scale=%f] ", tensor.scale()); + if (tensor.Scale() != 1.0f) { + printf("[scale=%f] ", tensor.Scale()); } - LogVec(name, tensor.data(), tensor.NumElements()); - total_weights += tensor.NumElements(); + LogVec(name, tensor.RowT(0), tensor.Extents().Area()); + total_weights += tensor.Extents().Area(); }); printf("%-20s %12zu\n", "Total", total_weights); } @@ -258,8 +259,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type, } template <> -void LayerWeightsPtrs::Reshape(MatStorage* storage) { - if (attn_vec_einsum_w.data() == nullptr) return; +void LayerWeightsPtrs::Reshape(MatOwner* storage) { + if (!attn_vec_einsum_w.HasPtr()) return; const size_t model_dim = layer_config.model_dim; const size_t heads = layer_config.heads; @@ -267,8 +268,7 @@ void LayerWeightsPtrs::Reshape(MatStorage* storage) { // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim]. 
if (storage != nullptr) { - storage->Allocate(); - att_weights.SetPtr(*storage); + storage->AllocateFor(att_weights, MatPadding::kPacked); } const hwy::HWY_NAMESPACE::ScalableTag df; @@ -279,7 +279,7 @@ void LayerWeightsPtrs::Reshape(MatStorage* storage) { hwy::AllocateAligned(model_dim * heads * qkv_dim); HWY_NAMESPACE::DecompressAndZeroPad( - df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0, + df, MakeSpan(attn_vec_einsum_w.Packed(), model_dim * heads * qkv_dim), 0, attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim); for (size_t m = 0; m < model_dim; ++m) { @@ -296,10 +296,10 @@ void LayerWeightsPtrs::Reshape(MatStorage* storage) { HWY_NAMESPACE::Compress( att_weights_tmp.get(), model_dim * heads * qkv_dim, work, - MakeSpan(att_weights.data(), model_dim * heads * qkv_dim), + MakeSpan(att_weights.Packed(), model_dim * heads * qkv_dim), /*packed_ofs=*/0, pool); - att_weights.set_scale(attn_vec_einsum_w.scale()); + att_weights.SetScale(attn_vec_einsum_w.Scale()); } } // namespace gcpp diff --git a/gemma/weights.h b/gemma/weights.h index 5fd544b..3cb025e 100644 --- a/gemma/weights.h +++ b/gemma/weights.h @@ -31,12 +31,32 @@ #include "gemma/common.h" #include "gemma/configs.h" #include "gemma/tensor_index.h" +#include "util/mat.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" namespace gcpp { +static inline std::string CacheName(const MatPtr& mat, int layer = -1, + char separator = ' ', int index = -1) { + // Already used/retired: s, S, n, 1 + const char prefix = mat.GetType() == Type::kF32 ? 'F' + : mat.GetType() == Type::kBF16 ? 'B' + : mat.GetType() == Type::kSFP ? '$' + : mat.GetType() == Type::kNUQ ? '2' + : '?'; + std::string name = std::string(1, prefix) + mat.Name(); + if (layer >= 0 || index >= 0) { + name += '_'; + if (layer >= 0) name += std::to_string(layer); + if (index >= 0) { + name += separator + std::to_string(index); + } + } + return name; +} + // Different tensors need to appear in a ForEachTensor, according to what is // happening. enum class ForEachType { @@ -181,10 +201,10 @@ struct LayerWeightsPtrs { // Initializes att_weights from attn_vec_einsum_w, hence this must be called // after loading weights via ForEachTensor. // TODO: update compression/convert_weights to bake this in. - void Reshape(MatStorage* storage) { + void Reshape(MatOwner* storage) { static_assert(!hwy::IsSame()); - if (attn_vec_einsum_w.data() == nullptr) return; + if (!attn_vec_einsum_w.HasPtr()) return; const size_t model_dim = layer_config.model_dim; const size_t heads = layer_config.heads; @@ -192,33 +212,33 @@ struct LayerWeightsPtrs { // Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim]. 
if (storage != nullptr) { - storage->Allocate(); - att_weights.SetPtr(*storage); + storage->AllocateFor(att_weights, MatPadding::kPacked); } for (size_t m = 0; m < model_dim; ++m) { - Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim; + Weight* HWY_RESTRICT out_row = + att_weights.template RowT(0) + m * heads * qkv_dim; for (size_t h = 0; h < heads; ++h) { - hwy::CopyBytes( - attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim, - out_row + h * qkv_dim, qkv_dim * sizeof(Weight)); + hwy::CopyBytes(attn_vec_einsum_w.template RowT(0) + + h * model_dim * qkv_dim + m * qkv_dim, + out_row + h * qkv_dim, qkv_dim * sizeof(Weight)); } } - att_weights.set_scale(attn_vec_einsum_w.scale()); + att_weights.SetScale(attn_vec_einsum_w.Scale()); } ArrayT key_norm_scale; ArrayT query_norm_scale; // Used by ForEachTensor for per-layer tensors. -#define GEMMA_CALL_FUNC(member) \ - { \ - for (int i = 0; i < ptrs.size(); ++i) { \ - tensors[i] = &ptrs[i]->member; \ - } \ - if (tensors[0]->Ptr() != nullptr || fet != ForEachType::kIgnoreNulls) { \ - func(ptrs[0]->member.CacheName(layer_idx, sep, sep_index).c_str(), \ - hwy::Span(tensors.data(), ptrs.size())); \ - } \ +#define GEMMA_CALL_FUNC(member) \ + { \ + for (int i = 0; i < ptrs.size(); ++i) { \ + tensors[i] = &ptrs[i]->member; \ + } \ + if (tensors[0]->HasPtr() || fet != ForEachType::kIgnoreNulls) { \ + func(CacheName(ptrs[0]->member, layer_idx, sep, sep_index).c_str(), \ + hwy::Span(tensors.data(), ptrs.size())); \ + } \ } template @@ -307,19 +327,18 @@ struct LayerWeightsPtrs { void ZeroInit(int layer_idx) { ForEachTensor({this}, layer_idx, ForEachType::kIgnoreNulls, [](const char*, hwy::Span tensors) { - tensors[0]->ZeroInit(); + gcpp::ZeroInit(*tensors[0]); }); } // Allocates memory for all the tensors in the layer. // Note that this is slow and only used for a stand-alone layer. - void Allocate(std::vector& layer_storage) { + void Allocate(std::vector& layer_storage) { ForEachTensor( {this}, /*layer_idx=*/0, ForEachType::kInitNoToc, [&layer_storage](const char* name, hwy::Span tensors) { - layer_storage.emplace_back(*tensors[0]); - layer_storage.back().Allocate(); - tensors[0]->SetPtr(layer_storage.back()); + layer_storage.push_back(MatOwner()); + layer_storage.back().AllocateFor(*tensors[0], MatPadding::kPacked); }); } }; @@ -393,11 +412,9 @@ struct ModelWeightsPtrs { // Called by weights.cc after Loading, before att_w has been allocated. 
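Both Reshape specializations above perform the same layout change: the attention output projection is stored as [heads, model_dim, qkv_dim] but consumed as [model_dim, heads * qkv_dim]. A standalone scalar version of that transpose (allocation, padding and scale propagation omitted):

#include <cstddef>
#include <vector>

template <typename T>
std::vector<T> ReshapeAttWeights(const std::vector<T>& attn_vec_einsum_w,
                                 size_t heads, size_t model_dim, size_t qkv_dim) {
  std::vector<T> att_weights(model_dim * heads * qkv_dim);
  for (size_t m = 0; m < model_dim; ++m) {
    for (size_t h = 0; h < heads; ++h) {
      for (size_t q = 0; q < qkv_dim; ++q) {
        // out[m][h * qkv_dim + q] = in[h][m][q]
        att_weights[m * heads * qkv_dim + h * qkv_dim + q] =
            attn_vec_einsum_w[h * model_dim * qkv_dim + m * qkv_dim + q];
      }
    }
  }
  return att_weights;
}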
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool, - std::vector& model_storage) { + std::vector& model_storage) { size_t storage_index = model_storage.size(); - for (auto& layer : c_layers) { - model_storage.emplace_back(layer.att_weights); - } + model_storage.resize(model_storage.size() + c_layers.size()); pool.Run(0, c_layers.size(), [this, &model_storage, storage_index](uint64_t layer, size_t /*thread*/) { @@ -412,8 +429,8 @@ struct ModelWeightsPtrs { } void ZeroInit() { - embedder_input_embedding.ZeroInit(); - final_norm_scale.ZeroInit(); + gcpp::ZeroInit(embedder_input_embedding); + gcpp::ZeroInit(final_norm_scale); for (size_t i = 0; i < c_layers.size(); ++i) { c_layers[i].ZeroInit(i); } @@ -430,21 +447,21 @@ struct ModelWeightsPtrs { return &vit_layers[layer]; } - void Allocate(std::vector& model_storage, hwy::ThreadPool& pool) { + void Allocate(std::vector& model_storage, hwy::ThreadPool& pool) { std::vector model_toc; ForEachTensor( {this}, ForEachType::kInitNoToc, [&model_toc, &model_storage](const char*, hwy::Span tensors) { model_toc.push_back(tensors[0]); - model_storage.emplace_back(*tensors[0]); + model_storage.push_back(MatOwner()); }); // Allocate in parallel using the pool. pool.Run(0, model_toc.size(), [&model_toc, &model_storage](uint64_t task, size_t /*thread*/) { // model_storage may have had content before we started. size_t idx = task + model_storage.size() - model_toc.size(); - model_storage[idx].Allocate(); - model_toc[task]->SetPtr(model_storage[idx]); + model_storage[idx].AllocateFor(*model_toc[task], + MatPadding::kPacked); }); } @@ -453,8 +470,7 @@ struct ModelWeightsPtrs { ForEachTensor({this, const_cast*>(&other)}, ForEachType::kIgnoreNulls, [](const char*, hwy::Span tensors) { - hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(), - tensors[1]->SizeBytes()); + CopyMat(*tensors[1], *tensors[0]); }); } @@ -467,10 +483,10 @@ struct ModelWeightsPtrs { [&scales, &scale_pos, this](const char*, hwy::Span tensors) { if (this->scale_names.count(tensors[0]->Name())) { if (scale_pos < scales.size()) { - tensors[0]->set_scale(scales[scale_pos]); + tensors[0]->SetScale(scales[scale_pos]); } else { - float scale = ScaleWeights(tensors[0]->data(), - tensors[0]->NumElements()); + float scale = ScaleWeights(tensors[0]->RowT(0), + tensors[0]->Extents().Area()); scales.push_back(scale); } ++scale_pos; @@ -615,9 +631,9 @@ class ModelWeightsStorage { std::unique_ptr> sfp_weights_; std::unique_ptr> nuq_weights_; // Storage for all the matrices and vectors. 
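The migration from MatStorage to MatPtr/MatOwner throughout this file follows one pattern: metadata-only views describe a tensor, and a separate owner allocates backing memory for them via AllocateFor. A conceptual, heavily simplified sketch of that split (float only; the real classes also carry type, stride/padding and scale):

#include <cstddef>
#include <memory>

struct MiniMatPtr {  // non-owning view
  size_t rows = 0, cols = 0;
  float* ptr = nullptr;
  bool HasPtr() const { return ptr != nullptr; }
};

class MiniMatOwner {  // owns the backing storage
 public:
  void AllocateFor(MiniMatPtr& mat) {  // "packed": stride == cols
    storage_ = std::make_unique<float[]>(mat.rows * mat.cols);
    mat.ptr = storage_.get();
  }
 private:
  std::unique_ptr<float[]> storage_;
};

This is why Allocate and AllocAndCopyWithTranspose now push default-constructed MatOwner objects and call AllocateFor(*tensor, MatPadding::kPacked), instead of constructing MatStorage from the tensor and calling SetPtr.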
- std::vector model_storage_; + std::vector model_storage_; }; } // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ \ No newline at end of file +#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_ diff --git a/ops/bench_matmul.cc b/ops/bench_matmul.cc index 30ec634..956c5be 100644 --- a/ops/bench_matmul.cc +++ b/ops/bench_matmul.cc @@ -31,15 +31,12 @@ #include #include -#include #include -#include "compression/compress.h" #include "compression/shared.h" #include "ops/matmul.h" -#include "util/allocator.h" #include "util/basics.h" -#include "util/threading.h" +#include "util/threading_context.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/nanobenchmark.h" #include "hwy/profiler.h" @@ -53,8 +50,8 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#include "compression/test_util-inl.h" #include "ops/matmul-inl.h" -#include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -63,59 +60,6 @@ extern int64_t first_target; namespace HWY_NAMESPACE { -using FloatPtr = hwy::AlignedFreeUniquePtr; - -template -using MatStoragePtr = std::unique_ptr>; - -// Generates inputs: deterministic, within max SfpStream range. -template -MatStoragePtr GenerateMat(const Extents2D extents, - hwy::ThreadPool& pool) { - gcpp::CompressWorkingSet ws; - auto mat = - std::make_unique>("mat", extents.rows, extents.cols); - FloatPtr content = hwy::AllocateAligned(mat->NumElements()); - HWY_ASSERT(content); - const float scale = - SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1); - pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { - for (size_t c = 0; c < extents.cols; c++) { - float f = static_cast(r * extents.cols + c) * scale; - if ((r + c) & 1) f = -f; // Also generate some negative values. - content[r * extents.cols + c] = f; - } - }); - - CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); - mat->set_scale(0.6f); // Arbitrary value, different from 1. - return mat; -} - -// extents describes the transposed matrix. -template -MatStoragePtr GenerateTransposedMat(const Extents2D extents, - hwy::ThreadPool& pool) { - gcpp::CompressWorkingSet ws; - auto mat = - std::make_unique>("trans", extents.rows, extents.cols); - FloatPtr content = hwy::AllocateAligned(mat->NumElements()); - const float scale = - SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1); - pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { - for (size_t c = 0; c < extents.cols; c++) { - float f = static_cast(c * extents.rows + r) * scale; - if ((r + c) & 1) f = -f; // Also generate some negative values. - content[r * extents.cols + c] = f; - } - }); - - CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); - // Arbitrary value, different from 1, must match GenerateMat. - mat->set_scale(0.6f); - return mat; -} - void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, std::vector& times, MMPerKey* per_key) { std::sort(times.begin(), times.end()); @@ -135,7 +79,8 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents, // M = A rows, K = A cols, N = C cols. 
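The deleted GenerateMat/GenerateTransposedMat produced deterministic inputs within the SfpStream range; judging by the new include, such helpers now come from compression/test_util-inl.h. The input pattern itself is simple; a sketch that mirrors it (kMax stands in for SfpStream::kMax; compression, the 0.6f scale and the Unpredictable1() anti-constant-folding trick are omitted):

#include <cstddef>
#include <vector>

std::vector<float> DeterministicInputs(size_t rows, size_t cols, float kMax) {
  std::vector<float> content(rows * cols);
  const float scale = kMax / static_cast<float>(rows * cols);
  for (size_t r = 0; r < rows; ++r) {
    for (size_t c = 0; c < cols; ++c) {
      float f = static_cast<float>(r * cols + c) * scale;
      if ((r + c) & 1) f = -f;  // also exercise negative values
      content[r * cols + c] = f;
    }
  }
  return content;
}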
template void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { - hwy::ThreadPool& pool = env.parallel.Pools().Pool(0); + const Allocator2& allocator = env.ctx.allocator; + hwy::ThreadPool& pool = env.ctx.pools.Pool(0); if (env.print_config || env.print_measurement) { fprintf(stderr, "\n"); } @@ -147,24 +92,23 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { const Extents2D B_extents(N, K); // already transposed const Extents2D C_extents(M, N); - RowVectorBatch c_slow_batch = AllocateAlignedRows(C_extents); - RowVectorBatch c_batch = AllocateAlignedRows(C_extents); + RowVectorBatch c_slow_batch = + AllocateAlignedRows(allocator, C_extents); + RowVectorBatch c_batch = AllocateAlignedRows(allocator, C_extents); - std::unique_ptr> add_storage; + MatStorageT add_storage("add", Extents2D(), MatPadding::kPacked); if (add) { add_storage = GenerateMat(Extents2D(1, N), pool); - HWY_ASSERT(add_storage); - add_storage->set_scale(1.0f); + add_storage.SetScale(1.0f); } - MatStoragePtr a = GenerateMat(A_extents, pool); - MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); - HWY_ASSERT(a && b_trans); - const auto A = ConstMatFromWeights(*a); - const auto B = ConstMatFromWeights(*b_trans); + MatStorageT a = GenerateMat(A_extents, pool); + MatStorageT b_trans = GenerateTransposedMat(B_extents, pool); + const auto A = ConstMatFromWeights(a); + const auto B = ConstMatFromWeights(b_trans); - const float* add_row = add ? add_storage->data_scale1() : nullptr; - const RowPtr C = RowPtrFromBatch(c_batch); + const float* add_row = add ? add_storage.PackedScale1() : nullptr; + const RowPtr C = RowPtrFromBatch(allocator, c_batch); // Fewer reps for large batch sizes, which take longer. const size_t num_samples = M < 32 ? 20 : 12; @@ -174,11 +118,11 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { // Ensure usage conditions are set before autotuning. Both binding and // spinning may materially affect the choice of config. No harm in calling // BindB/C if there is a single package: they will be a no-op. - BindB(B_extents.rows, sizeof(TC), B, env.parallel); - BindC(A_extents.rows, C, env.parallel); + BindB(allocator, B_extents.rows, sizeof(TC), B, env.parallel); + BindC(allocator, A_extents.rows, C, env.parallel); Tristate use_spinning = Tristate::kDefault; - env.parallel.Pools().MaybeStartSpinning(use_spinning); + env.ctx.pools.MaybeStartSpinning(use_spinning); // env.print_config = true; // env.print_measurement = true; @@ -198,7 +142,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) { if (per_key->autotune.Best()) times.push_back(elapsed); } hwy::PreventElision(keep); - env.parallel.Pools().MaybeStopSpinning(use_spinning); + env.ctx.pools.MaybeStopSpinning(use_spinning); PrintSpeed(A_extents, B_extents, times, per_key); } @@ -216,17 +160,11 @@ void BenchAllMatMul() { return; } - const size_t max_threads = 0; // no limit - const BoundedSlice package_slice; // all packages/sockets - const BoundedSlice cluster_slice; // all clusters/CCX - const BoundedSlice lp_slice; // default to all cores (per package). 
- const BoundedTopology topology(package_slice, cluster_slice, lp_slice); - Allocator::Init(topology, /*enable_bind=*/true); - NestedPools pools(topology, max_threads, Tristate::kDefault); - fprintf(stderr, "BenchAllMatMul %s %s\n", topology.TopologyString(), - pools.PinString()); + ThreadingContext2& ctx = ThreadingContext2::Get(); + fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(), + ctx.pools.PinString()); - MatMulEnv env(topology, pools); + MatMulEnv env(ctx); for (size_t batch_size : {1, 4, 128, 512}) { constexpr bool kAdd = false; diff --git a/ops/dot-inl.h b/ops/dot-inl.h index f5282f2..08a5ca8 100644 --- a/ops/dot-inl.h +++ b/ops/dot-inl.h @@ -16,6 +16,7 @@ #include #include "compression/compress.h" +#include "util/mat.h" #include "hwy/base.h" #include "hwy/profiler.h" @@ -379,10 +380,7 @@ template HWY_INLINE float Dot(const MatPtrT& w, size_t w_ofs, const VT* vec_aligned, size_t num) { const hn::ScalableTag d; - return w.scale() * Dot(d, - MakeConstSpan(reinterpret_cast(w.Ptr()), - w.NumElements()), - w_ofs, vec_aligned, num); + return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/dot_test.cc b/ops/dot_test.cc index 6aa970a..02c8d50 100644 --- a/ops/dot_test.cc +++ b/ops/dot_test.cc @@ -28,9 +28,8 @@ #include "compression/shared.h" #include "util/allocator.h" -#include "util/app.h" #include "util/test_util.h" -#include "util/threading.h" +#include "util/threading_context.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" @@ -805,7 +804,7 @@ class DotStats { ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_l1s[kAddTwoProd].Mean(), 8E-4); ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.1E-3f); // Updating Kahan's FastTwoSums to TwoSums does help a bit. - ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.2E-4); + ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.8E-4); ASSERT_INSIDE(kPairwise, 4.5E-4, s_l1s[kPairwise].Mean(), 4E-3); ASSERT_INSIDE(kPairwise, 1.1E-3f, s_l1s[kPairwise].Max(), 1E-2f); @@ -1000,9 +999,7 @@ struct TestShortDotsT { const size_t N = hn::Lanes(d); const hn::ScalableTag df; // for CallDot - const AppArgs app; - BoundedTopology topology(CreateTopology(app)); - NestedPools pools = CreatePools(topology, app); + const Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator; CompressWorkingSet work; std::mt19937 rng; rng.seed(12345); @@ -1014,14 +1011,14 @@ struct TestShortDotsT { // hence they require padding to one vector. 
const size_t padded_num = hwy::RoundUpTo(num, N); const size_t packed_num = CompressedArrayElements(num); - RowVectorBatch raw_w(Extents2D(1, padded_num)); - RowVectorBatch raw_v(Extents2D(1, padded_num)); - RowVectorBatch weights(Extents2D(1, packed_num)); + RowVectorBatch raw_w(allocator, Extents2D(1, padded_num)); + RowVectorBatch raw_v(allocator, Extents2D(1, padded_num)); + RowVectorBatch weights(allocator, Extents2D(1, packed_num)); const PackedSpan w(weights.Batch(0), packed_num); - RowVectorBatch vectors(Extents2D(1, num)); + RowVectorBatch vectors(allocator, Extents2D(1, num)); const PackedSpan v(vectors.Batch(0), num); - RowVectorBatch bufs(Extents2D(1, num)); + RowVectorBatch bufs(allocator, Extents2D(1, num)); double* HWY_RESTRICT buf = bufs.Batch(0); for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) { @@ -1099,10 +1096,21 @@ void TestAllDot() { return; } + constexpr size_t kMaxWorkers = 15; + + // Reset with cap on workers because we only support `kMaxWorkers`. + ThreadingContext2::ThreadHostileInvalidate(); + ThreadingArgs threading_args; + threading_args.max_packages = 1; + threading_args.max_clusters = 1; + threading_args.max_lps = kMaxWorkers - 1; + ThreadingContext2::SetArgs(threading_args); + ThreadingContext2& ctx = ThreadingContext2::Get(); + const Allocator2& allocator = ctx.allocator; + { // ensure no profiler zones are active const hn::ScalableTag df; - constexpr size_t kMaxWorkers = 15; std::mt19937 rngs[kMaxWorkers]; for (size_t i = 0; i < kMaxWorkers; ++i) { rngs[i].seed(12345 + 65537 * i); @@ -1110,44 +1118,43 @@ void TestAllDot() { constexpr size_t kReps = hn::AdjustedReps(40); const size_t num = 24 * 1024; - const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1), - BoundedSlice()); - NestedPools pools(topology, kMaxWorkers - 1, /*pin=*/Tristate::kDefault); - RowVectorBatch a(Extents2D(kMaxWorkers, num)); - RowVectorBatch b(Extents2D(kMaxWorkers, num)); - RowVectorBatch bufs(Extents2D(kMaxWorkers, num)); + RowVectorBatch a(allocator, Extents2D(kMaxWorkers, num)); + RowVectorBatch b(allocator, Extents2D(kMaxWorkers, num)); + RowVectorBatch bufs(allocator, Extents2D(kMaxWorkers, num)); std::array all_stats; - pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) { - float* HWY_RESTRICT pa = a.Batch(thread); - float* HWY_RESTRICT pb = b.Batch(thread); - double* HWY_RESTRICT buf = bufs.Batch(thread); - const PackedSpan a_span(pa, num); - DotStats& stats = all_stats[thread]; - const double cond = - GenerateIllConditionedInputs(num, pa, pb, rngs[thread]); + ctx.pools.Cluster(0, 0).Run( + 0, kReps, [&](const uint32_t rep, size_t thread) { + float* HWY_RESTRICT pa = a.Batch(thread); + float* HWY_RESTRICT pb = b.Batch(thread); + double* HWY_RESTRICT buf = bufs.Batch(thread); + const PackedSpan a_span(pa, num); + DotStats& stats = all_stats[thread]; + const double cond = + GenerateIllConditionedInputs(num, pa, pb, rngs[thread]); - const float dot_exact = ExactDot(pa, pb, num, buf); + const float dot_exact = ExactDot(pa, pb, num, buf); - float dots[kVariants] = {}; - double times[kVariants] = {}; - for (size_t variant = 0; variant < kVariants; ++variant) { - constexpr size_t kTimeReps = hn::AdjustedReps(10); - std::array elapsed; - for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) { - const double start = hwy::platform::Now(); - dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num); - hwy::PreventElision(*pa); - elapsed[time_rep] = hwy::platform::Now() - start; - } - dots[variant] /= kTimeReps; - times[variant] 
= TrimmedMean(elapsed.data(), kTimeReps); - } + float dots[kVariants] = {}; + double times[kVariants] = {}; + for (size_t variant = 0; variant < kVariants; ++variant) { + constexpr size_t kTimeReps = hn::AdjustedReps(10); + std::array elapsed; + for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) { + const double start = hwy::platform::Now(); + dots[variant] += + CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num); + hwy::PreventElision(*pa); + elapsed[time_rep] = hwy::platform::Now() - start; + } + dots[variant] /= kTimeReps; + times[variant] = TrimmedMean(elapsed.data(), kTimeReps); + } - stats.NotifyTimes(times); - stats.NotifyRep(num, cond, dot_exact, dots); - stats.NotifyRatios(); - }); + stats.NotifyTimes(times); + stats.NotifyRep(num, cond, dot_exact, dots); + stats.NotifyRatios(); + }); DotStats& stats = all_stats[0]; for (size_t i = 1; i < kMaxWorkers; ++i) { diff --git a/ops/gemma_matvec_test.cc b/ops/gemma_matvec_test.cc index 6982b20..c862e28 100644 --- a/ops/gemma_matvec_test.cc +++ b/ops/gemma_matvec_test.cc @@ -25,7 +25,7 @@ #include // std::abs #include -#include "compression/compress.h" +#include "util/mat.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -37,6 +37,7 @@ #include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/highway.h" // After highway.h +#include "compression/compress-inl.h" #include "ops/matvec-inl.h" #include "hwy/tests/test_util-inl.h" @@ -48,18 +49,18 @@ using FloatPtr = hwy::AlignedFreeUniquePtr; FloatPtr SimpleMatVecAdd(const MatStorageT& mat, const FloatPtr& vec, const FloatPtr& add) { - FloatPtr raw_mat = hwy::AllocateAligned(mat.NumElements()); + const size_t num = mat.Rows() * mat.Cols(); + FloatPtr raw_mat = hwy::AllocateAligned(num); FloatPtr out = hwy::AllocateAligned(mat.Rows()); HWY_ASSERT(raw_mat && out); const hn::ScalableTag df; - DecompressAndZeroPad(df, MakeSpan(mat.data(), mat.NumElements()), 0, - raw_mat.get(), mat.NumElements()); + DecompressAndZeroPad(df, mat.Span(), 0, raw_mat.get(), num); for (size_t idx_row = 0; idx_row < mat.Rows(); idx_row++) { out[idx_row] = 0.0f; for (size_t idx_col = 0; idx_col < mat.Cols(); idx_col++) { out[idx_row] += raw_mat[mat.Cols() * idx_row + idx_col] * vec[idx_col]; } - out[idx_row] *= mat.scale(); + out[idx_row] *= mat.Scale(); out[idx_row] += add[idx_row]; } return out; @@ -69,8 +70,10 @@ template std::unique_ptr> GenerateMat(size_t offset, hwy::ThreadPool& pool) { gcpp::CompressWorkingSet ws; - auto mat = std::make_unique>("TestMat", kOuter, kInner); - FloatPtr raw_mat = hwy::AllocateAligned(mat->NumElements()); + const Extents2D extents(kOuter, kInner); + auto mat = std::make_unique>("TestMat", extents, + MatPadding::kPacked); + FloatPtr raw_mat = hwy::AllocateAligned(extents.Area()); HWY_ASSERT(raw_mat); const float scale = 1.0f / kInner; pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) { @@ -80,8 +83,8 @@ std::unique_ptr> GenerateMat(size_t offset, } }); - CompressScaled(raw_mat.get(), mat->NumElements(), ws, *mat, pool); - mat->set_scale(1.9f); // Arbitrary value, different from 1. + CompressScaled(raw_mat.get(), extents.Area(), ws, *mat, pool); + mat->SetScale(1.9f); // Arbitrary value, different from 1. 
return mat; } diff --git a/ops/matmul-inl.h b/ops/matmul-inl.h index 2ff959d..e6491ee 100644 --- a/ops/matmul-inl.h +++ b/ops/matmul-inl.h @@ -15,6 +15,7 @@ #include #include +#include #include @@ -22,9 +23,8 @@ #include "ops/matmul.h" // IWYU pragma: export #include "util/allocator.h" #include "util/basics.h" -#include "util/threading.h" +#include "util/threading_context.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/profiler.h" #include "hwy/timer.h" @@ -866,6 +866,8 @@ class MMPerPackage { const IndexRange& range_np) : args_(args), pkg_idx_(pkg_idx), + // May be overwritten with a view of A, if already BF16. + A_(args_.env->storage.A(args.env->ctx.allocator, pkg_idx, A.Extents())), range_np_(range_np), mr_(config.MR()), ranges_mc_(config.RangesOfMC(A.Extents().rows)), @@ -873,14 +875,11 @@ class MMPerPackage { ranges_nc_(config.RangesOfNC(range_np)), order_(config.Order()), inner_tasks_(config.InnerTasks()), - out_(config.Out()) { - // May be overwritten with a view of A, if already BF16. - A_ = args_.env->storage.A(pkg_idx, A.Extents()); - { - MMZone zone; - zone.MaybeEnter("MM.DecompressA", args_); - A_ = DecompressA(A); - } + out_(config.Out()), + line_bytes_(args.env->ctx.allocator.LineBytes()) { + MMZone zone; + zone.MaybeEnter("MM.DecompressA", args_); + A_ = DecompressA(A); } // B is decompressed several call layers lower, but not all member functions @@ -909,14 +908,14 @@ class MMPerPackage { // Compute size of per-worker storage for `kNR` row ranges of B. Stack // allocation avoids passing a worker index. static constexpr size_t B_stride_max_ = - StrideForCyclicOffsets(MMStorage::kMaxKC); + MaxStrideForCyclicOffsets(MMStorage::kMaxKC); static constexpr size_t B_storage_max_ = - kNR * B_stride_max_ + Allocator::MaxQuantumBytes() / sizeof(BF16); + kNR * B_stride_max_ + Allocator2::MaxQuantum(); // Granularity of `ForNP`. B rows produce C columns, so we // want a multiple of the line size to prevent false sharing. - static size_t MultipleNP(size_t sizeof_TC) { - return HWY_MAX(kNR, Allocator::LineBytes() / sizeof_TC); + size_t MultipleNP(size_t sizeof_TC) const { + return HWY_MAX(kNR, line_bytes_ / sizeof_TC); } // Single M and K, parallel N. Fills all of C directly. @@ -931,14 +930,16 @@ class MMPerPackage { const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K); - const size_t B_stride = StrideForCyclicOffsets(K); + const size_t B_stride = + StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum()); // Similar to `loop_nc` below, but here we hoisted `A_view`. 
args_.env->parallel.ForNP( range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_, [&](const IndexRange& range_nc) HWY_ATTR { HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_view(B_storage, K, B_stride); + const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K, + B_stride); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { @@ -972,7 +973,9 @@ class MMPerPackage { auto out_tag) HWY_ATTR { const size_t kc = range_kc.Num(); const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc); - const RowPtrBF B_view(B_storage, kc, StrideForCyclicOffsets(kc)); + const RowPtrBF B_view( + args_.env->ctx.allocator, B_storage, kc, + StrideForCyclicOffsets(kc, args_.env->ctx.allocator.Quantum())); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { @@ -1027,7 +1030,8 @@ class MMPerPackage { HWY_DASSERT(ranges_kc_.NumTasks() == 1); const IndexRange& range_K = ranges_kc_.Range(0); const size_t K = range_K.Num(); - const size_t B_stride = StrideForCyclicOffsets(K); + const size_t B_stride = + StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum()); // Sequential loop over NC/MC/KC, similar to `loop_nc` below // except for the profiler strings and `out_tag`. @@ -1036,7 +1040,8 @@ class MMPerPackage { [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR { const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K); HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_view(B_storage, K, B_stride); + const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K, + B_stride); for (size_t row_b = range_nc.begin(); row_b < range_nc.end(); row_b += kNR) { @@ -1062,7 +1067,8 @@ class MMPerPackage { zone.MaybeEnter("MM.NT_MT_K", args_); const size_t kc_max = ranges_kc_.TaskSize(); HWY_DASSERT(kc_max <= MMStorage::kMaxKC); - const size_t B_stride = StrideForCyclicOffsets(kc_max); + const size_t B_stride = StrideForCyclicOffsets( + kc_max, args_.env->ctx.allocator.Quantum()); // Sequential loop over NC/MC/KC, for when the M/N loops are // already parallel. This is B3A2C0 in MOMMS terminology: we read // `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`. @@ -1088,7 +1094,8 @@ class MMPerPackage { ranges_mc_, ranges_nc_, pkg_idx_, [&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR { HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS - const RowPtrBF B_view(B_storage, kc_max, B_stride); + const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, kc_max, + B_stride); // Peel off the first iteration of the kc loop: avoid // zero-initializing `partial` by writing into it. @@ -1151,8 +1158,7 @@ class MMPerPackage { // At least one vector, otherwise DecompressAndZeroPad will add // padding, which might overwrite neighboring tasks. Also a whole cache // line to avoid false sharing. - const size_t multiple_K = - HWY_MAX(NBF, Allocator::LineBytes() / sizeof(BF16)); + const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16)); args_.env->parallel.ForNP( all_K, multiple_K, inner_tasks, pkg_idx_, @@ -1170,6 +1176,7 @@ class MMPerPackage { // Autotuning wrapper for `DoDecompressA`. 
template HWY_INLINE RowPtrBF DecompressA(const ConstMat& A) const { + const Allocator2& allocator = args_.env->ctx.allocator; MMAutoTune& autotune = args_.per_key->autotune_par_a[pkg_idx_]; // If already BF16, maybe return a view: if constexpr (hwy::IsSame()) { @@ -1177,7 +1184,8 @@ class MMPerPackage { const size_t NBF = hn::Lanes(hn::ScalableTag()); if (HWY_LIKELY(A.extents.cols % NBF == 0)) { const BF16* pos = A.ptr + A.Row(0); - return RowPtrBF(const_cast(pos), A.extents.cols, A.Stride()); + return RowPtrBF(allocator, const_cast(pos), A.extents.cols, + A.Stride()); } } @@ -1251,6 +1259,7 @@ class MMPerPackage { const MMOrder order_; const size_t inner_tasks_; const MMOut out_; + const size_t line_bytes_; }; // MMPerPackage // Stateless, wraps member functions. @@ -1308,6 +1317,7 @@ template HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, const float* HWY_RESTRICT add, MatMulEnv& env, const RowPtr& C) { + const Allocator2& allocator = env.ctx.allocator; const size_t M = A.Extents().rows; const size_t K = A.Extents().cols; const size_t N = B.Extents().rows; @@ -1315,11 +1325,11 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, intptr_t index = MMImpl::IndexOfKey(key, env.keys); // First time we see this shape/key. if (HWY_UNLIKELY(index < 0)) { - env.keys.Append(key); + env.keys.Append(key, allocator); size_t max_packages = MMParallel::kMaxPackages; // For low-batch, multiple sockets only help if binding is enabled. - if (!Allocator::ShouldBind() && M <= 4) { + if (!allocator.ShouldBind() && M <= 4) { max_packages = 1; } @@ -1351,8 +1361,9 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat& A, const ConstMat& B, HWY_ASSERT(N % kNR == 0); // Negligible CPU time. - tuner.SetCandidates(MMCandidates(M, K, N, sizeof(TC), MMKernel::kMaxMR, kNR, - per_key.ranges_np, env.print_config)); + tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC), + MMKernel::kMaxMR, kNR, per_key.ranges_np, + env.print_config)); } const MMConfig& cfg = tuner.NextConfig(); diff --git a/ops/matmul.cc b/ops/matmul.cc index edca38c..0131bc6 100644 --- a/ops/matmul.cc +++ b/ops/matmul.cc @@ -60,10 +60,11 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim, // and holds most of their arguments in member variables. class GenerateCandidates { public: - GenerateCandidates(size_t M, size_t K, size_t N, size_t sizeof_TC, - size_t max_mr, size_t nr, + GenerateCandidates(const Allocator2& allocator, size_t M, size_t K, size_t N, + size_t sizeof_TC, size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config) - : M_(M), + : allocator_(allocator), + M_(M), K_(K), N_(N), sizeof_TC_(sizeof_TC), @@ -73,8 +74,8 @@ class GenerateCandidates { // `RangesOf*`. Must be a vector multiple. The previous/next cache line // is likely still in L1, but we expect K > 1000 and might as well round // up to the line size. - kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))), - nc_multiple_(Allocator::StepBytes() / sizeof_TC), + kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))), + nc_multiple_(allocator.StepBytes() / sizeof_TC), ranges_np_(ranges_np), print_config_(print_config) {} @@ -172,7 +173,7 @@ class GenerateCandidates { // subtract the output and buf, and allow using more than the actual L1 // size. This results in an overestimate, and the loop below will propose // the next few smaller values for the autotuner to evaluate. 
- const size_t bytes_ab = Allocator::L1Bytes() * 3; + const size_t bytes_ab = allocator_.L1Bytes() * 3; const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16); size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes); kc_max = @@ -220,8 +221,8 @@ class GenerateCandidates { // packed B. We want `mc * kc` elements of A to fit in L2, alongside // `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of // partial. - const size_t bytes_per_mc = kc * sizeof(BF16) + Allocator::LineBytes(); - size_t mc_max = hwy::DivCeil(Allocator::L2Bytes() - bytes_b, bytes_per_mc); + const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes(); + size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc); mc_max = HWY_MIN(mc_max, MMStorage::kMaxM); HWY_DASSERT(mc_max != 0); mc_max = HWY_MIN(mc_max, M_); @@ -264,7 +265,7 @@ class GenerateCandidates { // Otherwise, leave it unbounded. if (M_ > mr) { const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes); - nc_max = hwy::DivCeil(Allocator::L3Bytes(), bytes_per_nc); + nc_max = hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc); nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max); } HWY_DASSERT(nc_max != 0); @@ -351,6 +352,7 @@ class GenerateCandidates { } } + const Allocator2& allocator_; const size_t M_; const size_t K_; const size_t N_; @@ -370,25 +372,26 @@ class GenerateCandidates { } // namespace // Facade to avoid exposing `GenerateCandidates` in the header. -std::vector MMCandidates(size_t M, size_t K, size_t N, - size_t sizeof_TC, size_t max_mr, size_t nr, +std::vector MMCandidates(const Allocator2& allocator, size_t M, + size_t K, size_t N, size_t sizeof_TC, + size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config) { - return GenerateCandidates(M, K, N, sizeof_TC, max_mr, nr, ranges_np, - print_config)(); + return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr, + ranges_np, print_config)(); } // Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote // memory accesses or false sharing, unless there are insufficient per-package // rows for that. -static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr, - size_t num_packages) { - size_t np_multiple = Allocator::QuantumBytes() / sizeof_TC; +static size_t NPMultiple(const Allocator2& allocator, size_t N, + size_t sizeof_TC, size_t nr, size_t num_packages) { + size_t np_multiple = allocator.QuantumBytes() / sizeof_TC; // If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For // `N` < 4096, this can cause significant load imbalance. If split unevenly, // choose a smaller multiple. 
if (N % (np_multiple * num_packages)) { - const size_t min_multiple = Allocator::LineBytes() / sizeof_TC; + const size_t min_multiple = allocator.LineBytes() / sizeof_TC; np_multiple = PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple); if (HWY_UNLIKELY(np_multiple == 0)) { @@ -408,16 +411,14 @@ static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr, IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr) const { - const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages()); - return StaticPartition(IndexRange(0, N), num_packages, - NPMultiple(N, sizeof_TC, nr, num_packages)); + const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages()); + return StaticPartition( + IndexRange(0, N), num_packages, + NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages)); } -MatMulEnv::MatMulEnv(const BoundedTopology& topology, NestedPools& pools) - : parallel(topology, pools), storage(parallel) { - // Ensure Allocator:Init was called. - HWY_ASSERT(Allocator::LineBytes() != 0 && Allocator::VectorBytes() != 0); - +MatMulEnv::MatMulEnv(ThreadingContext2& ctx) + : ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) { char cpu100[100]; have_timer_stop = hwy::platform::HaveTimerStop(cpu100); } diff --git a/ops/matmul.h b/ops/matmul.h index dc375d0..768573b 100644 --- a/ops/matmul.h +++ b/ops/matmul.h @@ -24,11 +24,9 @@ #include // IWYU pragma: begin_exports -#include "compression/compress.h" -#include "util/allocator.h" #include "util/basics.h" -#include "util/threading.h" -#include "util/topology.h" +#include "util/mat.h" +#include "util/threading_context.h" #include "hwy/aligned_allocator.h" // Span #include "hwy/base.h" #include "hwy/bit_set.h" @@ -51,33 +49,30 @@ class MMParallel { public: static constexpr size_t kMaxPackages = 4; - // Both references must outlive this object. - MMParallel(const BoundedTopology& topology, NestedPools& pools) - : topology_(topology), pools_(pools) { - HWY_DASSERT(pools_.NumPackages() <= kMaxPackages); + // `ctx` must outlive this object. + MMParallel(ThreadingContext2& ctx) : ctx_(ctx) { + HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages); } - // Used by tests. - NestedPools& Pools() { return pools_; } - // Initial static partitioning of B rows across packages. IndexRangePartition RangesOfNP(size_t max_packages, size_t N, size_t sizeof_TC, size_t nr) const; // For `BindB` and `BindC`. size_t Node(size_t pkg_idx) const { - return topology_.GetCluster(pkg_idx, 0).Node(); + return ctx_.topology.GetCluster(pkg_idx, 0).Node(); } // Calls `func(pkg_idx)` for each package in parallel. template void ForPkg(const size_t max_packages, const Func& func) { - pools_.AllPackages().Run(0, HWY_MIN(max_packages, pools_.NumPackages()), - [&](uint64_t task, size_t pkg_idx) { - HWY_DASSERT(task == pkg_idx); - (void)task; - func(pkg_idx); - }); + ctx_.pools.AllPackages().Run( + 0, HWY_MIN(max_packages, ctx_.pools.NumPackages()), + [&](uint64_t task, size_t pkg_idx) { + HWY_DASSERT(task == pkg_idx); + (void)task; + func(pkg_idx); + }); } // Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is @@ -87,10 +82,10 @@ class MMParallel { size_t pkg_idx, const Func& func) { HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4); // Single cluster: parallel-for over static partition of `range_np`. 
- hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx); + hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx); const size_t num_clusters = all_clusters.NumWorkers(); if (num_clusters == 1) { - hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, 0); + hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0); const IndexRangePartition worker_ranges = StaticPartition( range_np, cluster.NumWorkers() * inner_tasks, nx_multiple); return ParallelizeOneRange( @@ -106,7 +101,7 @@ class MMParallel { ParallelizeOneRange( nx_ranges, all_clusters, [&](const IndexRange& nx_range, const size_t cluster_idx) { - hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); // Parallel-for over sub-ranges of `cluster_range` within the cluster. const IndexRangePartition worker_ranges = StaticPartition( nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple); @@ -122,14 +117,14 @@ class MMParallel { void ForRangesMC_NC(const IndexRangePartition& ranges_mc, const IndexRangePartition& ranges_nc, size_t pkg_idx, const Func& func) { - hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx); + hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx); // `all_clusters` is a pool with one worker per cluster in a package. const size_t num_clusters = all_clusters.NumWorkers(); // Single (big) cluster: collapse two range indices into one parallel-for // to reduce the number of fork-joins. if (num_clusters == 1) { const size_t cluster_idx = 0; - hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); // Low-batch: avoid Divide/Remainder. if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) { return ParallelizeOneRange( @@ -150,7 +145,7 @@ class MMParallel { ParallelizeOneRange( ranges_nc, all_clusters, [&](const IndexRange range_nc, size_t cluster_idx) { - hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx); + hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx); ParallelizeOneRange( ranges_mc, cluster, [&](const IndexRange& range_mc, size_t /*thread*/) { @@ -163,33 +158,33 @@ class MMParallel { template void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx, const Func& func) { - pools_.Pool(pkg_idx).Run( + ctx_.pools.Pool(pkg_idx).Run( range_mc.begin(), range_mc.end(), [&](uint64_t row_a, size_t /*thread*/) { func(row_a); }); } private: - const BoundedTopology& topology_; - NestedPools& pools_; + ThreadingContext2& ctx_; }; template // BF16/float for C, double for partial -void BindC(size_t M, const RowPtr& C, MMParallel& parallel) { - if (!Allocator::ShouldBind()) return; +void BindC(const Allocator2& allocator, size_t M, const RowPtr& C, + MMParallel& parallel) { + if (!allocator.ShouldBind()) return; const IndexRangePartition ranges_np = parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR); - const size_t quantum = Allocator::QuantumBytes() / sizeof(TC); + const size_t quantum = allocator.Quantum(); bool ok = true; for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { const IndexRange& cols_c = ranges_np.Range(pkg_idx); const size_t node = parallel.Node(pkg_idx); for (size_t im = 0; im < M; ++im) { - // BindRowsToPackageNodes may not be page-aligned. + // `BindMemory` requires page alignment. 
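      // (Example, assuming quantum = 1024 elements: a column range
      // [1000, 5000) shrinks to [1024, 4096), so only whole quanta are moved
      // to the package's node; the unaligned fringe stays on whichever node
      // first touched it.)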
const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum); const size_t end = hwy::RoundDownTo(cols_c.end(), quantum); - ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC), - node); + ok &= allocator.BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC), + node); } } if (HWY_UNLIKELY(!ok)) { @@ -212,38 +207,42 @@ class MMStorage { // of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`. static constexpr size_t kMaxKC = 8 * 1024; - explicit MMStorage(MMParallel& parallel) { + MMStorage(const Allocator2& allocator, MMParallel& parallel) + // Per-worker copies of `partial` would be wasteful. We instead allocate + // one instance of the maximum matrix extents because threads write at + // false-sharing-free granularity. + : partial_storage_( + AllocateAlignedRows(allocator, Extents2D(kMaxM, kMaxN))), + // Same stride independent of the actual C.Cols() so we can pre-bind. + partial_(allocator, partial_storage_.All(), kMaxN, + StrideForCyclicOffsets(kMaxN, allocator.Quantum())) { // Per-package allocation so each can decompress A into its own copy. parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) { - pkg_A_[pkg_idx] = AllocateAlignedRows(Extents2D(kMaxM, kMaxK)); + pkg_A_[pkg_idx] = + AllocateAlignedRows(allocator, Extents2D(kMaxM, kMaxK)); - if (Allocator::ShouldBind()) { + if (allocator.ShouldBind()) { const size_t node = parallel.Node(pkg_idx); - if (!Allocator::BindMemory(pkg_A_[pkg_idx].All(), - pkg_A_[pkg_idx].NumBytes(), node)) { + if (!allocator.BindMemory(pkg_A_[pkg_idx].All(), + pkg_A_[pkg_idx].NumBytes(), node)) { HWY_WARN("Failed to bind memory for package %zu", pkg_idx); } } }); - // Per-worker copies of `partial` would be wasteful. We instead allocate - // one instance of the maximum matrix extents because threads write at - // false-sharing-free granularity. - partial_storage_ = AllocateAlignedRows(Extents2D(kMaxM, kMaxN)); - // Same stride independent of the actual C.Cols() so we can pre-bind. - partial_ = RowPtrD(partial_storage_.All(), kMaxN, - StrideForCyclicOffsets(kMaxN)); // Avoid cross-package accesses. - BindC(kMaxM, partial_, parallel); + BindC(allocator, kMaxM, partial_, parallel); } // Returns per-package matrix view. Non-const so that `RowVectorBatch` is // non-const, because `RowPtr` requires a non-const pointer. - RowPtrBF A(size_t pkg_idx, const Extents2D& extents) { + RowPtrBF A(const Allocator2& allocator, size_t pkg_idx, + const Extents2D& extents) { HWY_DASSERT(extents.rows <= kMaxM); HWY_DASSERT(extents.cols <= kMaxK); - const size_t stride = StrideForCyclicOffsets(extents.cols); - return RowPtrBF(pkg_A_[pkg_idx].All(), extents.cols, stride); + const size_t stride = + StrideForCyclicOffsets(extents.cols, allocator.Quantum()); + return RowPtrBF(allocator, pkg_A_[pkg_idx].All(), extents.cols, stride); } RowPtrD Partial() const { return partial_; } @@ -431,13 +430,15 @@ class MMConfig { static_assert(sizeof(MMConfig) == 32); // for faster indexing #pragma pack(pop) -std::vector MMCandidates(size_t M, size_t K, size_t N, - size_t sizeof_TC, size_t max_mr, size_t nr, +std::vector MMCandidates(const Allocator2& allocator, size_t M, + size_t K, size_t N, size_t sizeof_TC, + size_t max_mr, size_t nr, const IndexRangePartition& ranges_np, bool print_config); // State machine for choosing the best `TConfig`, which is `MMConfig` for the // main MatMul autotuner. +// TODO: replace with hwy/auto_tune.h. 
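// (Sketch of how the declarations above fit together; the exact call shape is
// an assumption, not part of this change. `kMaxMR`/`kNR` are the register
// tiling constants referenced elsewhere in this header:
//   const IndexRangePartition ranges_np =
//       env.parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof(TC), kNR);
//   const std::vector<MMConfig> candidates = MMCandidates(
//       env.ctx.allocator, M, K, N, sizeof(TC), kMaxMR, kNR, ranges_np,
//       /*print_config=*/false);
// The state machine below then decides which candidate each MatMul call runs.)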
template class MMAutoTune { public: @@ -560,11 +561,11 @@ class MMKeys { } // Must only be called if not already present in `Keys()`. - void Append(Key key) { + void Append(Key key, const Allocator2& allocator) { // Dynamic allocation because the test checks many more dimensions than // would be reasonable to pre-allocate. DIY for alignment and padding. if (HWY_UNLIKELY(num_unique_ >= capacity_)) { - const size_t NU64 = Allocator::VectorBytes() / sizeof(Key); + const size_t NU64 = allocator.VectorBytes() / sizeof(Key); // Start at one vector so the size is always a multiple of N. if (HWY_UNLIKELY(capacity_ == 0)) { capacity_ = hwy::DivCeil(NU64, 2); // will be doubled below @@ -604,10 +605,12 @@ struct MMPerKey { MMAutoTune autotune_par_a[MMParallel::kMaxPackages]; }; -// Stores state shared across MatMul calls. Non-copyable. +// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive +// `MatMulEnv`. struct MatMulEnv { - explicit MatMulEnv(const BoundedTopology& topology, NestedPools& pools); + explicit MatMulEnv(ThreadingContext2& ctx); + ThreadingContext2& ctx; bool have_timer_stop = false; // Enable binding: disabled in Gemma until tensors support it, enabled in @@ -684,8 +687,9 @@ struct MMZone { // `ofs` required for compressed T. template struct ConstMat { - ConstMat(const T* ptr, Extents2D extents, size_t stride, size_t ofs = 0) - : ptr(ptr), extents(extents), stride(stride), ofs(ofs) { + ConstMat() = default; + ConstMat(const T* ptr, Extents2D extents, size_t stride) + : ptr(ptr), extents(extents), stride(stride), ofs(0) { HWY_DASSERT(ptr != nullptr); HWY_DASSERT(stride >= extents.cols); } @@ -717,15 +721,17 @@ struct ConstMat { float scale = 1.0f; // Offset to add to `ptr`; separate because T=NuqStream does not support - // pointer arithmetic. + // pointer arithmetic. This is in units of weights, and does not have anything + // to do with the interleaved NUQ tables. It should be computed via `Row()` + // to take into account the stride. size_t ofs; }; // For deducing T. template -ConstMat MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, size_t stride, - size_t ofs = 0) { - return ConstMat(ptr, extents, stride, ofs); +ConstMat MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, + size_t stride) { + return ConstMat(ptr, extents, stride); } // For A argument to MatMul (activations). 
@@ -739,21 +745,21 @@ ConstMat ConstMatFromBatch(size_t batch_size, } template -ConstMat ConstMatFromWeights(const MatPtrT& m, size_t ofs = 0) { +ConstMat ConstMatFromWeights(const MatPtrT& m) { ConstMat mat = - MakeConstMat(const_cast(m.data()), m.Extents(), m.Stride(), ofs); - mat.scale = m.scale(); + MakeConstMat(const_cast(m.Row(0)), m.Extents(), m.Stride()); + mat.scale = m.Scale(); return mat; } template -void BindB(size_t N, size_t sizeof_TC, const ConstMat& B, - MMParallel& parallel) { - if (!Allocator::ShouldBind()) return; +void BindB(const Allocator2& allocator, size_t N, size_t sizeof_TC, + const ConstMat& B, MMParallel& parallel) { + if (!allocator.ShouldBind()) return; const IndexRangePartition ranges_np = parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR); - const size_t quantum = Allocator::QuantumBytes() / sizeof(TB); + const size_t quantum = allocator.Quantum(); for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) { const IndexRange& rows_b = ranges_np.Range(pkg_idx); const size_t node = parallel.Node(pkg_idx); @@ -765,7 +771,7 @@ void BindB(size_t N, size_t sizeof_TC, const ConstMat& B, begin = hwy::RoundUpTo(begin, quantum); end = hwy::RoundDownTo(end, quantum); if (HWY_LIKELY(begin != end)) { - Allocator::BindMemory(reinterpret_cast(begin), end - begin, node); + allocator.BindMemory(reinterpret_cast(begin), end - begin, node); } } } diff --git a/ops/matmul_test.cc b/ops/matmul_test.cc index aaf3bc1..552f3d9 100644 --- a/ops/matmul_test.cc +++ b/ops/matmul_test.cc @@ -15,29 +15,25 @@ // End to end test of MatMul, comparing against a reference implementation. -#include "hwy/detect_compiler_arch.h" +#include "hwy/detect_compiler_arch.h" // IWYU pragma: keep #ifndef HWY_DISABLED_TARGETS // Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require -// double-precision support. +// double-precision support, and older x86 to speed up builds. #if HWY_ARCH_ARM_V7 #define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON) #else -#define HWY_DISABLED_TARGETS HWY_SCALAR +#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SSSE3 | HWY_SSE4) #endif #endif #include #include -#include - -#include "compression/compress.h" #include "compression/shared.h" #include "ops/matmul.h" -#include "util/allocator.h" #include "util/basics.h" -#include "util/threading.h" -#include "hwy/base.h" +#include "util/mat.h" +#include "util/threading_context.h" #include "hwy/contrib/thread_pool/thread_pool.h" // clang-format off @@ -48,9 +44,9 @@ #include "hwy/highway.h" // After highway.h #include "compression/compress-inl.h" +#include "compression/test_util-inl.h" #include "ops/dot-inl.h" #include "ops/matmul-inl.h" -#include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -60,57 +56,6 @@ extern int64_t first_target; namespace HWY_NAMESPACE { namespace hn = hwy::HWY_NAMESPACE; -using FloatPtr = hwy::AlignedFreeUniquePtr; - -template -using MatStoragePtr = std::unique_ptr>; - -// Generates inputs: deterministic, within max SfpStream range. 
-template -MatStoragePtr GenerateMat(const Extents2D& extents, - hwy::ThreadPool& pool) { - gcpp::CompressWorkingSet ws; - auto mat = - std::make_unique>("mat", extents.rows, extents.cols); - FloatPtr content = hwy::AllocateAligned(mat->NumElements()); - HWY_ASSERT(content); - const float scale = SfpStream::kMax / (mat->NumElements()); - pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { - for (size_t c = 0; c < extents.cols; c++) { - float f = static_cast(r * extents.cols + c) * scale; - if ((r + c) & 1) f = -f; // Also generate some negative values. - content[r * extents.cols + c] = f; - } - }); - - CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); - mat->set_scale(0.6f); // Arbitrary value, different from 1. - return mat; -} - -// extents describes the transposed matrix. -template -MatStoragePtr GenerateTransposedMat(const Extents2D extents, - hwy::ThreadPool& pool) { - gcpp::CompressWorkingSet ws; - auto mat = - std::make_unique>("trans", extents.rows, extents.cols); - FloatPtr content = hwy::AllocateAligned(mat->NumElements()); - const float scale = SfpStream::kMax / (mat->NumElements()); - pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) { - for (size_t c = 0; c < extents.cols; c++) { - float f = static_cast(c * extents.rows + r) * scale; - if ((r + c) & 1) f = -f; // Also generate some negative values. - content[r * extents.cols + c] = f; - } - }); - - CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool); - // Arbitrary value, different from 1, must match GenerateMat. - mat->set_scale(0.6f); - return mat; -} - // Returns 1-norm, used for estimating tolerable numerical differences. double MaxRowAbsSum(const RowVectorBatch& a) { double max_row_abs_sum = 0.0; @@ -141,16 +86,19 @@ float MaxAbs(const RowVectorBatch& a) { template void AssertClose(const ConstMat& A, const ConstMat& B, const RowPtr& C_slow, const RowPtr& C, int line) { + const Allocator2& allocator = ThreadingContext2::Get().allocator; const hn::ScalableTag df; const size_t cols = A.extents.cols; const size_t B_rows = B.extents.rows; // Round up for DecompressAndZeroPad. 
- RowVectorBatch a_batch = AllocateAlignedRows(A.extents); - RowVectorBatch b_trans_batch = AllocateAlignedRows(B.extents); + RowVectorBatch a_batch = + AllocateAlignedRows(allocator, A.extents); + RowVectorBatch b_trans_batch = + AllocateAlignedRows(allocator, B.extents); RowVectorBatch c_batch = - AllocateAlignedRows(Extents2D(A.extents.rows, B_rows)); + AllocateAlignedRows(allocator, Extents2D(A.extents.rows, B_rows)); RowVectorBatch c_slow_batch = - AllocateAlignedRows(Extents2D(A.extents.rows, B_rows)); + AllocateAlignedRows(allocator, Extents2D(A.extents.rows, B_rows)); HWY_ASSERT(A.ofs == 0 && B.ofs == 0); for (size_t m = 0; m < A.extents.rows; ++m) { DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0, @@ -224,7 +172,7 @@ HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, const IndexRange all_rows_c(0, A.Extents().rows); const IndexRange all_cols_c(0, C.Cols()); - NestedPools& pools = env.parallel.Pools(); + NestedPools& pools = env.ctx.pools; hwy::ThreadPool& all_packages = pools.AllPackages(); const IndexRangePartition get_row_c = StaticPartition(all_rows_c, all_packages.NumWorkers(), 1); @@ -232,7 +180,7 @@ HWY_INLINE void MatMulSlow(const ConstMat A, const ConstMat B, get_row_c, all_packages, [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR { hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx); - const size_t multiple = Allocator::QuantumBytes() / sizeof(TB); + const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB); const IndexRangePartition get_col_c = StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple); ParallelizeOneRange( @@ -262,7 +210,8 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents, template void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, MatMulEnv& env, int line) { - hwy::ThreadPool& pool = env.parallel.Pools().Pool(); + const Allocator2& allocator = env.ctx.allocator; + hwy::ThreadPool& pool = env.ctx.pools.Pool(); fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n", rows_ac, cols_a_rows_b, cols_bc, add, TypeName(), TypeName(), TypeName()); @@ -274,24 +223,22 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add, const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed const Extents2D C_extents(rows_ac, cols_bc); - MatStoragePtr a = GenerateMat(A_extents, pool); - MatStoragePtr b_trans = GenerateTransposedMat(B_extents, pool); - RowVectorBatch c_slow_batch = AllocateAlignedRows(C_extents); - RowVectorBatch c_batch = AllocateAlignedRows(C_extents); - HWY_ASSERT(a && b_trans); + MatStorageT a(GenerateMat(A_extents, pool)); + MatStorageT b_trans(GenerateTransposedMat(B_extents, pool)); + RowVectorBatch c_slow_batch = + AllocateAlignedRows(allocator, C_extents); + RowVectorBatch c_batch = AllocateAlignedRows(allocator, C_extents); - std::unique_ptr> add_storage; - if (add) { - add_storage = GenerateMat(Extents2D(1, cols_bc), pool); - HWY_ASSERT(add_storage); - add_storage->set_scale(1.0f); - } + MatStorageT add_storage = + add ? GenerateMat(Extents2D(1, cols_bc), pool) + : MatStorageT("add", Extents2D(), MatPadding::kPacked); + add_storage.SetScale(1.0f); - const auto A = ConstMatFromWeights(*a); - const auto B = ConstMatFromWeights(*b_trans); - const float* add_row = add ? 
add_storage->data_scale1() : nullptr; - const RowPtr C_slow = RowPtrFromBatch(c_slow_batch); - const RowPtr C = RowPtrFromBatch(c_batch); + const auto A = ConstMatFromWeights(a); + const auto B = ConstMatFromWeights(b_trans); + const float* add_row = add ? add_storage.PackedScale1() : nullptr; + const RowPtr C_slow = RowPtrFromBatch(allocator, c_slow_batch); + const RowPtr C = RowPtrFromBatch(allocator, c_batch); MatMulSlow(A, B, add_row, env, C_slow); // A few reps to get coverage of the various autotuned code paths. @@ -312,22 +259,24 @@ void TestTiny() { if (HWY_TARGET != first_target) return; for (size_t max_packages : {1, 2}) { - const BoundedTopology topology(BoundedSlice(0, max_packages)); - Allocator::Init(topology, /*enable_bind=*/true); - const size_t max_threads = 0; // no limit - NestedPools pools(topology, max_threads, Tristate::kDefault); + ThreadingContext2::ThreadHostileInvalidate(); + ThreadingArgs threading_args; + threading_args.bind = Tristate::kTrue; + threading_args.max_packages = max_packages; + ThreadingContext2::SetArgs(threading_args); + MatMulEnv env(ThreadingContext2::Get()); + NestedPools& pools = env.ctx.pools; + #if GEMMA_DISABLE_TOPOLOGY if (max_packages == 2) break; // we only have one package #else // If less than the limit, we have already tested all num_packages. - if (topology.FullTopology().packages.size() < max_packages) break; + if (env.ctx.topology.FullTopology().packages.size() < max_packages) break; #endif fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages, - topology.TopologyString(), pools.PinString()); + env.ctx.topology.TopologyString(), pools.PinString()); - Tristate use_spinning = Tristate::kDefault; - pools.MaybeStartSpinning(use_spinning); - MatMulEnv env(topology, pools); + pools.MaybeStartSpinning(threading_args.spin); for (size_t M = 1; M <= 12; ++M) { for (size_t K = 1; K <= 64; K *= 2) { @@ -336,7 +285,7 @@ void TestTiny() { } } } - pools.MaybeStopSpinning(use_spinning); + pools.MaybeStopSpinning(threading_args.spin); } } @@ -347,12 +296,13 @@ void TestAllMatMul() { return; } - const BoundedTopology topology; - Allocator::Init(topology, /*enable_bind=*/true); - NestedPools pools(topology); - Tristate use_spinning = Tristate::kDefault; - pools.MaybeStartSpinning(use_spinning); - MatMulEnv env(topology, pools); + ThreadingContext2::ThreadHostileInvalidate(); + ThreadingArgs threading_args; + threading_args.bind = Tristate::kTrue; + ThreadingContext2::SetArgs(threading_args); + MatMulEnv env(ThreadingContext2::Get()); + NestedPools& pools = env.ctx.pools; + pools.MaybeStartSpinning(threading_args.spin); // Sizes seen in gemma_test 2B. Too slow for CI, enable on-demand. 
TestMatMul(1, 2048, 512, /*add=*/false, env, __LINE__); @@ -417,6 +367,8 @@ void TestAllMatMul() { TestMatMul(1, 128, 32, /*add=*/true, env, __LINE__); TestMatMul(1, 128, 32, /*add=*/false, env, __LINE__); TestMatMul(1, 128, 32, /*add=*/true, env, __LINE__); + + pools.MaybeStopSpinning(threading_args.spin); } // NOLINTNEXTLINE(google-readability-namespace-comments) diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h index 7ad56e7..728ce41 100644 --- a/ops/matvec-inl.h +++ b/ops/matvec-inl.h @@ -50,8 +50,7 @@ template HWY_INLINE float Dot(const ArrayT& w, size_t w_ofs, const VT* vec_aligned, size_t num) { const hn::ScalableTag d; - return w.scale() * Dot(d, MakeConstSpan(w.data(), w.NumElements()), w_ofs, - vec_aligned, num); + return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num); } // Simple version without tiling nor threading, but two offsets/outputs and diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 52f72bd..6132620 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -27,12 +27,13 @@ #include // std::enable_if_t #include -#include "compression/compress.h" +#include "util/allocator.h" #include "util/basics.h" // TokenAndProb +#include "util/mat.h" +#include "util/threading_context.h" #include "hwy/base.h" #include "hwy/contrib/sort/order.h" #include "hwy/contrib/sort/vqsort.h" -#include "hwy/contrib/thread_pool/thread_pool.h" #include "hwy/detect_targets.h" #include "hwy/profiler.h" #endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_ @@ -807,12 +808,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK( // Each output row is the average of a 4x4 block of input rows template RowVectorBatch AvgPool4x4(RowVectorBatch& input) { - Extents2D extents = input.Extents(); + const Allocator2& allocator = ThreadingContext2::Get().allocator; + const Extents2D extents = input.Extents(); // Input validation HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows // Create output with 256 rows and same number of columns const size_t out_rows = 256; // 16 * 16 = 256 output rows - RowVectorBatch result(Extents2D{out_rows, extents.cols}); + RowVectorBatch result(allocator, Extents2D(out_rows, extents.cols)); const size_t input_dim = 64; // Input is 64×64 const size_t output_dim = 16; // Output is 16×16 for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) { diff --git a/ops/ops.h b/ops/ops.h index 6c243da..0f99963 100644 --- a/ops/ops.h +++ b/ops/ops.h @@ -21,14 +21,16 @@ #include #include "util/allocator.h" +#include "util/mat.h" #include "hwy/base.h" namespace gcpp { static inline HWY_MAYBE_UNUSED RowVectorBatch CreateInvTimescale( - size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) { + const Allocator2& allocator, size_t qkv_dim, bool half_rope, + double base_frequency = 10000.0) { const size_t rope_dim = half_rope ? 
qkv_dim / 2 : qkv_dim; - RowVectorBatch inv_timescale(Extents2D(1, rope_dim / 2)); + RowVectorBatch inv_timescale(allocator, Extents2D(1, rope_dim / 2)); for (size_t dim = 0; dim < rope_dim / 2; ++dim) { const double freq_exponents = static_cast(2 * dim) / static_cast(rope_dim); diff --git a/ops/ops_test.cc b/ops/ops_test.cc index 5414138..b44c3f7 100644 --- a/ops/ops_test.cc +++ b/ops/ops_test.cc @@ -31,14 +31,12 @@ #include #include -#include "compression/compress.h" // BF16 #include "gemma/common.h" -#include "gemma/configs.h" #include "util/allocator.h" -#include "util/app.h" +#include "util/basics.h" // BF16 +#include "util/mat.h" // RowVectorBatch #include "util/test_util.h" -#include "util/threading.h" -#include "hwy/base.h" +#include "util/threading_context.h" #include "hwy/tests/hwy_gtest.h" // clang-format off @@ -388,13 +386,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy( } void TestRopeAndMulBy() { - AppArgs app; - BoundedTopology topology = CreateTopology(app); - NestedPools pools = CreatePools(topology, app); + const Allocator2& allocator = ThreadingContext2::Get().allocator; ModelConfig config = ConfigFromModel(Model::GEMMA2_9B); int dim_qkv = config.layer_configs[0].qkv_dim; - RowVectorBatch x(Extents2D(1, dim_qkv)); + RowVectorBatch x(allocator, Extents2D(1, dim_qkv)); std::mt19937 gen; gen.seed(0x12345678); @@ -412,8 +408,8 @@ void TestRopeAndMulBy() { std::vector qactual(dim_qkv); std::vector kexpected(dim_qkv); std::vector kactual(dim_qkv); - RowVectorBatch inv_timescale = gcpp::CreateInvTimescale( - config.layer_configs[0].qkv_dim, + RowVectorBatch inv_timescale = CreateInvTimescale( + allocator, config.layer_configs[0].qkv_dim, config.layer_configs[0].post_qk == PostQKType::HalfRope); // Assert VectorizedRope computation is same as regular rope at different pos. 
for (int pos = 1; pos < 500; pos++) { diff --git a/paligemma/paligemma_test.cc b/paligemma/paligemma_test.cc index fe7fcc9..398b067 100644 --- a/paligemma/paligemma_test.cc +++ b/paligemma/paligemma_test.cc @@ -49,8 +49,8 @@ class PaliGemmaTest : public ::testing::Test { }; void PaliGemmaTest::InitVit(const std::string& path) { - ASSERT_NE(s_env->GetModel(), nullptr); - Gemma& model = *(s_env->GetModel()); + ASSERT_NE(s_env->GetGemma(), nullptr); + Gemma& model = *(s_env->GetGemma()); image_tokens_ = ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len, model.GetModelConfig().model_dim)); @@ -64,7 +64,7 @@ void PaliGemmaTest::InitVit(const std::string& path) { } std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ - Gemma& model = *(s_env->GetModel()); + Gemma& model = *(s_env->GetGemma()); s_env->MutableGen().seed(0x12345678); RuntimeConfig runtime_config = {.max_generated_tokens = 512, .gen = &s_env->MutableGen(), @@ -92,7 +92,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{ } void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) { - ASSERT_NE(s_env->GetModel(), nullptr); + ASSERT_NE(s_env->GetGemma(), nullptr); std::string path = "paligemma/testdata/image.ppm"; InitVit(path); for (size_t i = 0; i < num_questions; ++i) { @@ -104,7 +104,7 @@ void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) { } TEST_F(PaliGemmaTest, General) { - ASSERT_NE(s_env->GetModel(), nullptr); + ASSERT_NE(s_env->GetGemma(), nullptr); static const char* kQA_3B_mix_224[][2] = { {"describe this image", "A large building with two towers stands tall on the water's edge."}, @@ -124,7 +124,7 @@ TEST_F(PaliGemmaTest, General) { }; const char* (*qa)[2]; size_t num; - switch (s_env->GetModel()->Info().model) { + switch (s_env->GetGemma()->Info().model) { case Model::PALIGEMMA_224: qa = kQA_3B_mix_224; num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]); @@ -135,7 +135,7 @@ TEST_F(PaliGemmaTest, General) { break; default: FAIL() << "Unsupported model: " - << s_env->GetModel()->GetModelConfig().model_name; + << s_env->GetGemma()->GetModelConfig().model_name; break; } TestQuestions(qa, num); diff --git a/python/BUILD.bazel b/python/BUILD.bazel index 29de6bc..1298473 100644 --- a/python/BUILD.bazel +++ b/python/BUILD.bazel @@ -21,12 +21,12 @@ pybind_extension( name = "gemma", srcs = ["gemma_py.cc"], deps = [ - "//:app", + "//:allocator", "//:benchmark_helper", + "//:gemma_args", "//:gemma_lib", "//compression:sfp", "@highway//:hwy", - "@highway//:thread_pool", ], ) diff --git a/python/gemma_py.cc b/python/gemma_py.cc index a7ce022..0791188 100644 --- a/python/gemma_py.cc +++ b/python/gemma_py.cc @@ -32,9 +32,9 @@ #include "compression/shared.h" #include "evals/benchmark_helper.h" #include "gemma/gemma.h" -#include "util/app.h" +#include "gemma/gemma_args.h" +#include "util/allocator.h" #include "hwy/base.h" -#include "hwy/contrib/thread_pool/thread_pool.h" namespace py = pybind11; @@ -48,8 +48,9 @@ static void RemoveTrailingZeros(std::vector &vec) { class GemmaModel { public: GemmaModel(const gcpp::LoaderArgs& loader, - const gcpp::InferenceArgs& inference, const gcpp::AppArgs& app) - : gemma_(loader, inference, app), last_prob_(0.0f) {} + const gcpp::InferenceArgs& inference, + const gcpp::ThreadingArgs& threading) + : gemma_(threading, loader, inference), last_prob_(0.0f) {} // Generates a single example, given a prompt and a callback to stream the // generated tokens. 
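Note on the constructor change above: the binding now builds a
`gcpp::ThreadingArgs` instead of `gcpp::AppArgs`, and `gemma_` receives it as
its first argument. A minimal sketch of the new call order; default-constructing
the argument structs here is an assumption for illustration only:

    gcpp::LoaderArgs loader;
    gcpp::InferenceArgs inference;
    inference.max_generated_tokens = 512;
    gcpp::ThreadingArgs threading;
    GemmaModel model(loader, inference, threading);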
@@ -168,7 +169,8 @@ class GemmaModel { // Generate* will use this image. Throws an error for other models. void SetImage(const py::array_t& image) { - gcpp::Gemma& model = *(gemma_.GetModel()); + const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator; + gcpp::Gemma& model = *(gemma_.GetGemma()); if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) { throw std::invalid_argument("Not a PaliGemma model."); } @@ -183,9 +185,9 @@ class GemmaModel { c_image.Set(height, width, ptr); const size_t image_size = model.GetModelConfig().vit_config.image_size; c_image.Resize(image_size, image_size); - image_tokens_ = gcpp::ImageTokens(gcpp::Extents2D( - model.GetModelConfig().vit_config.seq_len, - model.GetModelConfig().model_dim)); + image_tokens_ = gcpp::ImageTokens( + allocator, gcpp::Extents2D(model.GetModelConfig().vit_config.seq_len, + model.GetModelConfig().model_dim)); gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(), .verbosity = 0}; model.GenerateImageTokens(runtime_config, c_image, image_tokens_); @@ -199,7 +201,7 @@ class GemmaModel { if (image_tokens_.Cols() == 0) { throw std::invalid_argument("No image set."); } - gcpp::Gemma& model = *(gemma_.GetModel()); + gcpp::Gemma& model = *(gemma_.GetGemma()); gemma_.MutableGen().seed(seed); gcpp::RuntimeConfig& config = gemma_.MutableConfig(); config.max_generated_tokens = max_generated_tokens; @@ -247,7 +249,7 @@ class GemmaModel { return gemma_.StringFromTokens(token_ids); } - bool ModelIsLoaded() const { return gemma_.GetModel() != nullptr; } + bool ModelIsLoaded() const { return gemma_.GetGemma() != nullptr; } private: gcpp::GemmaEnv gemma_; @@ -267,7 +269,7 @@ PYBIND11_MODULE(gemma, mod) { loader.weight_type_str = weight_type; gcpp::InferenceArgs inference; inference.max_generated_tokens = 512; - gcpp::AppArgs app; + gcpp::ThreadingArgs app; app.max_threads = max_threads; auto gemma = std::make_unique(loader, inference, app); diff --git a/util/allocator.cc b/util/allocator.cc index 20d65ad..b5b6278 100644 --- a/util/allocator.cc +++ b/util/allocator.cc @@ -130,233 +130,6 @@ size_t DetectTotalMiB(size_t page_bytes) { } // namespace -static size_t line_bytes_; -static size_t vector_bytes_; -static size_t step_bytes_; -static size_t quantum_bytes_; -static size_t quantum_steps_; -static size_t l1_bytes_; -static size_t l2_bytes_; -static size_t l3_bytes_; -static bool should_bind_ = false; - -size_t Allocator::LineBytes() { return line_bytes_; } -size_t Allocator::VectorBytes() { return vector_bytes_; } -size_t Allocator::StepBytes() { return step_bytes_; } -size_t Allocator::QuantumBytes() { return quantum_bytes_; } -size_t Allocator::QuantumSteps() { return quantum_steps_; } -size_t Allocator::L1Bytes() { return l1_bytes_; } -size_t Allocator::L2Bytes() { return l2_bytes_; } -size_t Allocator::L3Bytes() { return l3_bytes_; } -bool Allocator::ShouldBind() { return should_bind_; } - -void Allocator::Init(const BoundedTopology& topology, bool enable_bind) { - line_bytes_ = DetectLineBytes(); - vector_bytes_ = hwy::VectorBytes(); - step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_); - quantum_bytes_ = step_bytes_; // may overwrite below - - const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0); - if (const hwy::Cache* caches = hwy::DataCaches()) { - l1_bytes_ = caches[1].size_kib << 10; - l2_bytes_ = caches[2].size_kib << 10; - l3_bytes_ = (caches[3].size_kib << 10) * caches[3].cores_sharing; - } else { // Unknown, make reasonable assumptions. 
- l1_bytes_ = 32 << 10; - l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10; - } - if (l3_bytes_ == 0) { - l3_bytes_ = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) << 10; - } - - // Prerequisites for binding: - // - supported by the OS (currently Linux only), - // - the page size is known and 'reasonably small', preferably less than - // a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB. - // - we successfully detected topology and there are multiple nodes; - // - there are multiple packages, because we shard by package_idx. - if constexpr (GEMMA_BIND) { - const size_t page_bytes = DetectPageSize(); - if ((page_bytes != 0 && page_bytes <= 16 * 1024) && - topology.NumNodes() > 1 && topology.NumPackages() > 1) { - if (enable_bind) { - // Ensure pages meet the alignment requirements of `AllocBytes`. - HWY_ASSERT(page_bytes >= quantum_bytes_); - quantum_bytes_ = page_bytes; - // Ensure MaxQuantumBytes() is an upper bound. - HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_); - quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes()); - should_bind_ = true; - } else { - HWY_WARN( - "Multiple sockets but binding disabled. This reduces speed; " - "set or remove enable_bind to avoid this warning."); - } - } - } - - HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0); - quantum_steps_ = quantum_bytes_ / step_bytes_; -} - -Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) { - // If we are not binding, the Highway allocator is cheaper than `mmap`, and - // defends against 2K aliasing. - if (!should_bind_) { - // Perf warning if Highway's alignment is less than we want. - if (HWY_ALIGNMENT < QuantumBytes()) { - HWY_WARN( - "HWY_ALIGNMENT %d < QuantumBytes %zu: either vector or cache lines " - "are huge, enable GEMMA_BIND to avoid this warning.", - HWY_ALIGNMENT, QuantumBytes()); - } - auto p = hwy::AllocateAligned(bytes); - // The `hwy::AlignedFreeUniquePtr` deleter is unfortunately specific to the - // alignment scheme in aligned_allocator.cc and does not work for - // already-aligned pointers as returned by `mmap`, hence we wrap the Highway - // pointer in our own deleter. - auto call_free = [](void* ptr, size_t /*bytes*/) { - hwy::FreeAlignedBytes(ptr, nullptr, nullptr); - }; - return PtrAndDeleter{p.release(), DeleterFree(call_free, bytes)}; - } - - // Binding, or large vector/cache line size: use platform-specific allocator. - -#if HWY_OS_LINUX && !defined(__ANDROID_API__) - // `move_pages` is documented to require an anonymous/private mapping or - // `MAP_SHARED`. A normal allocation might not suffice, so we use `mmap`. - // `Init` verified that the page size is a multiple of `QuantumBytes()`. - const int prot = PROT_READ | PROT_WRITE; - const int flags = MAP_ANONYMOUS | MAP_PRIVATE; - const int fd = -1; - // Encourage transparent hugepages by rounding up to a multiple of 2 MiB. 
- bytes = hwy::RoundUpTo(bytes, 2ull << 20); - void* p = mmap(0, bytes, prot, flags, fd, off_t{0}); - if (p == MAP_FAILED) p = nullptr; - const auto call_munmap = [](void* ptr, size_t bytes) { - const int ret = munmap(ptr, bytes); - HWY_ASSERT(ret == 0); - }; - return PtrAndDeleter{p, DeleterFree(call_munmap, bytes)}; -#elif HWY_OS_WIN - const auto call_free = [](void* ptr, size_t) { _aligned_free(ptr); }; - const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_); - return PtrAndDeleter{_aligned_malloc(bytes, alignment), - DeleterFree(call_free, bytes)}; -#else - return PtrAndDeleter{nullptr, DeleterFree(nullptr, 0)}; -#endif -} - -#if GEMMA_BIND && HWY_OS_LINUX - -using Ret = long; // NOLINT(runtime/int) -using UL = unsigned long; // NOLINT(runtime/int) -static constexpr size_t ULBits = sizeof(UL) * 8; - -// Calling via syscall avoids a dependency on libnuma. -struct SyscallWrappers { - static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes, - unsigned flags) { - MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL)); - return syscall(__NR_mbind, ptr, bytes, mode, max_nodes, max_nodes, flags); - }; - - static Ret move_pages(int pid, UL count, void** pages, const int* nodes, - int* status, int flags) { - MaybeCheckInitialized(pages, count * sizeof(void*)); - MaybeCheckInitialized(nodes, count * sizeof(int)); - MaybeCheckInitialized(status, count * sizeof(int)); - return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags); - } - - static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr, - unsigned flags) { - return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags); - } -}; - -// Returns the number of pages that are currently busy (hence not yet moved), -// and warns if there are any other reasons for not moving a page. Note that -// `move_pages` can return 0 regardless of whether all pages were moved. -size_t CountBusyPages(size_t num_pages, size_t node, void** pages, - const int* status) { - size_t num_busy = 0; - for (size_t i = 0; i < num_pages; ++i) { - if (status[i] == -EBUSY) { - ++num_busy; - } else if (status[i] != static_cast(node)) { - static std::atomic_flag first = ATOMIC_FLAG_INIT; - if (!first.test_and_set()) { - HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).", - status[i], i, pages[i], node, errno); - } - } - } - return num_busy; -} - -bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) { - HWY_DASSERT(should_bind_); - constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough" - - if constexpr (HWY_IS_DEBUG_BUILD) { - // Ensure the requested `node` is allowed. - UL nodes[kMaxNodes / 64] = {0}; - const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED - HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes, - nullptr, flags) == 0); - HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64))); - } - - // Avoid mbind because it does not report why it failed, which is most likely - // because pages are busy, in which case we want to know which. - - // `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set. - const unsigned flags = 2; // MPOL_MF_MOVE - HWY_ASSERT(bytes % quantum_bytes_ == 0); - const size_t num_pages = bytes / quantum_bytes_; - std::vector pages; - pages.reserve(num_pages); - for (size_t i = 0; i < num_pages; ++i) { - pages.push_back(static_cast(ptr) + i * quantum_bytes_); - // Ensure the page is faulted in to prevent `move_pages` from failing, - // because freshly allocated pages may be mapped to a shared 'zero page'. 
- hwy::ZeroBytes(pages.back(), 8); - } - std::vector nodes(num_pages, node); - std::vector status(num_pages, static_cast(kMaxNodes)); - - Ret ret = SyscallWrappers::move_pages( - /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); - if (ret < 0) { - HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr, - bytes, node, errno, status[0]); - return false; - } - - const size_t num_busy = - CountBusyPages(num_pages, node, pages.data(), status.data()); - if (HWY_UNLIKELY(num_busy != 0)) { - // Trying again is usually enough to succeed. - hwy::NanoSleep(5000); - (void)SyscallWrappers::move_pages( - /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); - const size_t still_busy = - CountBusyPages(num_pages, node, pages.data(), status.data()); - if (HWY_UNLIKELY(still_busy != 0)) { - HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.", - still_busy, num_busy); - } - } - return true; -} - -#else -bool Allocator::BindMemory(void*, size_t, size_t) { return false; } -#endif // GEMMA_BIND && HWY_OS_LINUX - Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) { line_bytes_ = DetectLineBytes(); vector_bytes_ = hwy::VectorBytes(); @@ -428,7 +201,7 @@ size_t Allocator2::FreeMiB() const { #endif } -Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const { +AlignedPtr2 Allocator2::AllocBytes(size_t bytes) const { // If we are not binding, the Highway allocator is cheaper than `mmap`, and // defends against 2K aliasing. if (!should_bind_) { @@ -444,9 +217,10 @@ Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const { // alignment scheme in aligned_allocator.cc and does not work for // already-aligned pointers as returned by `mmap`, hence we wrap the Highway // pointer in our own deleter. - return PtrAndDeleter{p.release(), DeleterFunc2([](void* ptr) { - hwy::FreeAlignedBytes(ptr, nullptr, nullptr); - })}; + return AlignedPtr2(p.release(), DeleterFunc2([](void* ptr) { + hwy::FreeAlignedBytes(ptr, nullptr, + nullptr); + })); } // Binding, or large vector/cache line size: use platform-specific allocator. @@ -460,20 +234,126 @@ Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const { const int fd = -1; void* p = mmap(0, bytes, prot, flags, fd, off_t{0}); if (p == MAP_FAILED) p = nullptr; - return PtrAndDeleter{p, DeleterFunc2([bytes](void* ptr) { - HWY_ASSERT(munmap(ptr, bytes) == 0); - })}; + return AlignedPtr2(static_cast(p), + DeleterFunc2([bytes](void* ptr) { + HWY_ASSERT(munmap(ptr, bytes) == 0); + })); #elif HWY_OS_WIN const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_); - return PtrAndDeleter{_aligned_malloc(bytes, alignment), - DeleterFunc2([](void* ptr) { _aligned_free(ptr); })}; + return AlignedPtr2( + static_cast(_aligned_malloc(bytes, alignment)), + DeleterFunc2([](void* ptr) { _aligned_free(ptr); })); #else - return PtrAndDeleter{nullptr, DeleterFunc2()}; + return AlignedPtr2(nullptr, DeleterFunc2()); #endif } -bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const { - return Allocator::BindMemory(ptr, bytes, node); +#if GEMMA_BIND && HWY_OS_LINUX + +using Ret = long; // NOLINT(runtime/int) +using UL = unsigned long; // NOLINT(runtime/int) +static constexpr size_t ULBits = sizeof(UL) * 8; + +// Calling via syscall avoids a dependency on libnuma. 
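// (For reference, and as an aside rather than part of this change: the
// libnuma wrappers for the same operations would be move_pages() and mbind()
// from <numaif.h>; issuing the raw syscalls keeps the build free of that
// link-time dependency.)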
+struct SyscallWrappers { + static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes, + unsigned flags) { + MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL)); + return syscall(__NR_mbind, ptr, bytes, mode, max_nodes, max_nodes, flags); + }; + + static Ret move_pages(int pid, UL count, void** pages, const int* nodes, + int* status, int flags) { + MaybeCheckInitialized(pages, count * sizeof(void*)); + MaybeCheckInitialized(nodes, count * sizeof(int)); + MaybeCheckInitialized(status, count * sizeof(int)); + return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags); + } + + static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr, + unsigned flags) { + return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags); + } +}; + +// Returns the number of pages that are currently busy (hence not yet moved), +// and warns if there are any other reasons for not moving a page. Note that +// `move_pages` can return 0 regardless of whether all pages were moved. +size_t CountBusyPages(size_t num_pages, size_t node, void** pages, + const int* status) { + size_t num_busy = 0; + for (size_t i = 0; i < num_pages; ++i) { + if (status[i] == -EBUSY) { + ++num_busy; + } else if (status[i] != static_cast(node)) { + static std::atomic_flag first = ATOMIC_FLAG_INIT; + if (!first.test_and_set()) { + HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).", + status[i], i, pages[i], node, errno); + } + } + } + return num_busy; } +bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const { + HWY_DASSERT(should_bind_); + constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough" + + if constexpr (HWY_IS_DEBUG_BUILD) { + // Ensure the requested `node` is allowed. + UL nodes[kMaxNodes / 64] = {0}; + const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED + HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes, + nullptr, flags) == 0); + HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64))); + } + + // Avoid mbind because it does not report why it failed, which is most likely + // because pages are busy, in which case we want to know which. + + // `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set. + const unsigned flags = 2; // MPOL_MF_MOVE + HWY_ASSERT(bytes % quantum_bytes_ == 0); + const size_t num_pages = bytes / quantum_bytes_; + std::vector pages; + pages.reserve(num_pages); + for (size_t i = 0; i < num_pages; ++i) { + pages.push_back(static_cast(ptr) + i * quantum_bytes_); + // Ensure the page is faulted in to prevent `move_pages` from failing, + // because freshly allocated pages may be mapped to a shared 'zero page'. + hwy::ZeroBytes(pages.back(), 8); + } + std::vector nodes(num_pages, node); + std::vector status(num_pages, static_cast(kMaxNodes)); + + Ret ret = SyscallWrappers::move_pages( + /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); + if (ret < 0) { + HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr, + bytes, node, errno, status[0]); + return false; + } + + const size_t num_busy = + CountBusyPages(num_pages, node, pages.data(), status.data()); + if (HWY_UNLIKELY(num_busy != 0)) { + // Trying again is usually enough to succeed. 
+ hwy::NanoSleep(5000); + (void)SyscallWrappers::move_pages( + /*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags); + const size_t still_busy = + CountBusyPages(num_pages, node, pages.data(), status.data()); + if (HWY_UNLIKELY(still_busy != 0)) { + HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.", + still_busy, num_busy); + } + } + return true; +} + +#else +bool Allocator2::BindMemory(void*, size_t, size_t) const { return false; } +#endif // GEMMA_BIND && HWY_OS_LINUX + } // namespace gcpp diff --git a/util/allocator.h b/util/allocator.h index b5d59bb..a0e726c 100644 --- a/util/allocator.h +++ b/util/allocator.h @@ -30,307 +30,8 @@ #include "hwy/base.h" // IWYU pragma: end_exports -#include "hwy/aligned_allocator.h" - namespace gcpp { -// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The -// `bytes` argument is required for the latter. -using FreeFunc = void (*)(void* mem, size_t bytes); - -// Custom deleter for std::unique_ptr that calls `FreeFunc`. T is POD. -class DeleterFree { - public: - // `MatStorageT` requires this to be default-constructible. - DeleterFree() : free_func_(nullptr), bytes_(0) {} - DeleterFree(FreeFunc free_func, size_t bytes) - : free_func_(free_func), bytes_(bytes) {} - - template - void operator()(T* p) const { - free_func_(p, bytes_); - } - - private: - FreeFunc free_func_; - size_t bytes_; -}; - -// Wrapper that also calls the destructor for non-POD T. -class DeleterDtor { - public: - DeleterDtor() {} - DeleterDtor(size_t num, DeleterFree free) : num_(num), free_(free) {} - - template - void operator()(T* p) const { - for (size_t i = 0; i < num_; ++i) { - p[i].~T(); - } - free_(p); - } - - private: - size_t num_; // not the same as free_.bytes_ / sizeof(T)! - DeleterFree free_; -}; - -// Unique (move-only) pointer to an aligned array of POD T. -template -using AlignedPtr = std::unique_ptr; -// Unique (move-only) pointer to an aligned array of non-POD T. -template -using AlignedClassPtr = std::unique_ptr; - -// Both allocation, binding, and row accessors depend on the sizes of memory -// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we -// use `Monostate` (static members). -class Allocator { - public: - // Must be called at least once before any other function. Not thread-safe, - // hence only call this from the main thread. - // TODO: remove enable_bind once Gemma tensors support binding. - static void Init(const BoundedTopology& topology, bool enable_bind = false); - - // Bytes per cache line, or a reasonable guess if unknown. Used to choose - // ranges such that there will be no false sharing. - static size_t LineBytes(); - // Bytes per full vector. Used to compute loop steps. - static size_t VectorBytes(); - // Work granularity that avoids false sharing and partial vectors. - static size_t StepBytes(); // = HWY_MAX(LineBytes(), VectorBytes()) - // Granularity like `StepBytes()`, but when NUMA may be involved. - static size_t QuantumBytes(); - // Upper bound on `QuantumBytes()`, for stack allocations. - static constexpr size_t MaxQuantumBytes() { return 4096; } - static size_t QuantumSteps(); // = QuantumBytes() / StepBytes() - - // L1 and L2 are typically per core. - static size_t L1Bytes(); - static size_t L2Bytes(); - // Clusters often share an L3. We return the total size per package. - static size_t L3Bytes(); - - // Returns pointer aligned to `QuantumBytes()`. 
- template - static AlignedPtr Alloc(size_t num) { - constexpr size_t kSize = sizeof(T); - constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0; - constexpr size_t kBits = hwy::detail::ShiftCount(kSize); - static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug"); - const size_t bytes = kIsPow2 ? num << kBits : num * kSize; - // Fail if the `bytes = num * kSize` computation overflowed. - const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize; - if (check != num) return AlignedPtr(); - - PtrAndDeleter pd = AllocBytes(bytes); - return AlignedPtr(static_cast(pd.p), pd.deleter); - } - - // Same as Alloc, but calls constructor(s) with `args`. - template - static AlignedClassPtr AllocClasses(size_t num, Args&&... args) { - constexpr size_t kSize = sizeof(T); - constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0; - constexpr size_t kBits = hwy::detail::ShiftCount(kSize); - static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug"); - const size_t bytes = kIsPow2 ? num << kBits : num * kSize; - // Fail if the `bytes = num * kSize` computation overflowed. - const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize; - if (check != num) return AlignedClassPtr(); - - PtrAndDeleter pd = AllocBytes(bytes); - T* p = static_cast(pd.p); - for (size_t i = 0; i < num; ++i) { - new (p + i) T(std::forward(args)...); - } - return AlignedClassPtr(p, DeleterDtor(num, pd.deleter)); - } - - // Returns whether `BindMemory` can/should be called, i.e. we have page-level - // control over memory placement and multiple packages and NUMA nodes. - static bool ShouldBind(); - - // Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is - // typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`. - // Writes zeros to SOME of the memory. Only call if `ShouldBind()`. - // `p` and `bytes` must be multiples of `QuantumBytes()`. - static bool BindMemory(void* p, size_t bytes, size_t node); - - private: - // Type-erased so this can be implemented in allocator.cc. - struct PtrAndDeleter { - void* p; - DeleterFree deleter; - }; - static PtrAndDeleter AllocBytes(size_t bytes); -}; - -// Value of `stride` to pass to `RowVectorBatch` to enable the "cyclic offsets" -// optimization. If `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is -// typically 4KiB. To avoid remote accesses, we would thus pad each row to that, -// which results in 4K aliasing and/or cache conflict misses. `RowPtr` is able -// to prevent that by pulling rows forward by a cyclic offset, which is still a -// multiple of the cache line size. This requires an additional -// `Allocator::QuantumBytes()` of padding after also rounding up to that. -template -constexpr size_t StrideForCyclicOffsets(size_t cols) { - const size_t quantum = Allocator::MaxQuantumBytes() / sizeof(T); - return hwy::RoundUpTo(cols, quantum) + quantum; -} - -// Owns dynamically-allocated aligned memory for a batch of row vectors. -// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns -// the memory. -template -class RowVectorBatch { - public: - // Default ctor for Activations ctor. - RowVectorBatch() = default; - // Main ctor, called from Activations::Allocate. If `stride` = 0, the default, - // we default to tightly packed rows (`stride = cols`). - // WARNING: not all call sites support `stride` != cols. - // TODO: once they do, remove stride and behave like AllocateAlignedRows here. 
- RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) { - if (stride == 0) { - stride_ = extents_.cols; - } else { - HWY_ASSERT(stride >= extents_.cols); - stride_ = stride; - } - // Allow binding the entire matrix. - const size_t padded = hwy::RoundUpTo(extents_.rows * stride_, - Allocator::QuantumBytes() / sizeof(T)); - mem_ = Allocator::Alloc(padded); - } - - // Move-only - RowVectorBatch(RowVectorBatch&) noexcept = delete; - RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; - RowVectorBatch(RowVectorBatch&&) noexcept = default; - RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; - - size_t BatchSize() const { return extents_.rows; } - size_t Cols() const { return extents_.cols; } - size_t Stride() const { return stride_; } - Extents2D Extents() const { return extents_; } - - // Returns the given row vector of length `Cols()`. - T* Batch(size_t batch_idx) { - HWY_DASSERT(batch_idx < BatchSize()); - return mem_.get() + batch_idx * stride_; - } - const T* Batch(size_t batch_idx) const { - HWY_DASSERT(batch_idx < BatchSize()); - return mem_.get() + batch_idx * stride_; - } - - // For MatMul or other operations that process the entire batch at once. - // TODO: remove once we only use Mat. - T* All() { return mem_.get(); } - const T* Const() const { return mem_.get(); } - size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); } - - private: - AlignedPtr mem_; - Extents2D extents_; - size_t stride_; -}; - -// Returns `num` rounded up to an odd number of cache lines. This is used to -// compute strides. An odd number of cache lines prevents 2K aliasing and is -// coprime with the cache associativity, which reduces conflict misses. -template -static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) { - HWY_DASSERT(line_bytes >= 32); - HWY_DASSERT(line_bytes % sizeof(T) == 0); - const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes); - const size_t padded_num = (lines | 1) * line_bytes / sizeof(T); - HWY_DASSERT(padded_num >= num); - return padded_num; -} - -template -RowVectorBatch AllocateAlignedRows(Extents2D extents) { - return RowVectorBatch(extents, StrideForCyclicOffsets(extents.cols)); -} - -// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because -// it is always float and does not support compressed T, but does support an -// arbitrary stride >= cols. -#pragma pack(push, 1) // power of two size -template -class RowPtr { - public: - RowPtr() = default; // for `MMPtrs`. - RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride) - : row0_(row0), - stride_(stride), - step_(static_cast(Allocator::StepBytes())), - cols_(static_cast(cols)), - row_mask_(Allocator::QuantumSteps() - 1) { - HWY_DASSERT(stride >= cols); - HWY_DASSERT(row_mask_ != ~size_t{0}); - if constexpr (HWY_IS_DEBUG_BUILD) { - if (stride < StrideForCyclicOffsets(cols)) { - static bool once; - if (!once) { - once = true; - HWY_WARN( - "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), " - "T=%zu; this forces us to disable cyclic offsets.", - stride, cols, sizeof(T)); - } - row_mask_ = 0; - } - } - } - RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {} - - T* HWY_RESTRICT Row(size_t r) const { - // How much of the previous row's padding to consume. 
- const size_t pad_bytes = (r & row_mask_) * step_; - HWY_DASSERT(pad_bytes < Allocator::QuantumBytes()); - return row0_ + stride_ * r - pad_bytes; - } - size_t Cols() const { return cols_; } - - size_t Stride() const { return stride_; } - void SetStride(size_t stride) { - HWY_DASSERT(stride >= Cols()); - stride_ = stride; - // The caller might not have padded enough, so disable the padding in Row(). - // Rows will now be exactly `stride` elements apart. This is used when - // writing to the KV cache via MatMul. - row_mask_ = 0; - } - - // Returns 2D subrange whose top-left is `r, c` and width is `cols`. - RowPtr View(size_t r, size_t c, size_t cols) const { - HWY_DASSERT(c < cols_); - HWY_DASSERT(cols <= cols_ - c); - return RowPtr(Row(r) + c, cols, stride_); - } - - private: - T* HWY_RESTRICT row0_; - size_t stride_; - uint32_t step_; // Copy from Allocator::LineBytes() to improve locality. - uint32_t cols_; - size_t row_mask_; -}; -#pragma pack(pop) - -using RowPtrBF = RowPtr; -using RowPtrF = RowPtr; -using RowPtrD = RowPtr; - -// For C argument to MatMul. -template -RowPtr RowPtrFromBatch(RowVectorBatch& row_vectors) { - return RowPtr(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride()); -} - // Custom deleter for types without a dtor, but where the deallocation requires // state, e.g. a lambda with *by-value* capture. class DeleterFunc2 { @@ -420,15 +121,22 @@ class Allocator2 { size_t TotalMiB() const { return total_mib_; } size_t FreeMiB() const; - // Returns pointer aligned to `QuantumBytes()`. + // Returns byte pointer aligned to `QuantumBytes()`, without calling + // constructors nor destructors on deletion. Type-erased so this can be + // implemented in `allocator.cc` and called by `MatOwner`. + AlignedPtr2 AllocBytes(size_t bytes) const; + + // Returns pointer aligned to `QuantumBytes()`, without calling constructors + // nor destructors on deletion. template AlignedPtr2 Alloc(size_t num) const { const size_t bytes = num * sizeof(T); // Fail if the `bytes = num * sizeof(T)` computation overflowed. HWY_ASSERT(bytes / sizeof(T) == num); - PtrAndDeleter pd = AllocBytes(bytes); - return AlignedPtr2(static_cast(pd.p), pd.deleter); + AlignedPtr2 p8 = AllocBytes(bytes); + return AlignedPtr2(HWY_RCAST_ALIGNED(T*, p8.release()), + p8.get_deleter()); } // Same as Alloc, but calls constructor(s) with `args` and the deleter will @@ -439,12 +147,12 @@ class Allocator2 { // Fail if the `bytes = num * sizeof(T)` computation overflowed. HWY_ASSERT(bytes / sizeof(T) == num); - PtrAndDeleter pd = AllocBytes(bytes); - T* p = static_cast(pd.p); + AlignedPtr2 p8 = AllocBytes(bytes); + T* p = HWY_RCAST_ALIGNED(T*, p8.release()); for (size_t i = 0; i < num; ++i) { new (p + i) T(std::forward(args)...); } - return AlignedClassPtr2(p, DeleterDtor2(num, pd.deleter)); + return AlignedClassPtr2(p, DeleterDtor2(num, p8.get_deleter())); } // Returns whether `BindMemory` can/should be called, i.e. we have page-level @@ -458,13 +166,6 @@ class Allocator2 { bool BindMemory(void* p, size_t bytes, size_t node) const; private: - // Type-erased so this can be implemented in allocator.cc. 
- struct PtrAndDeleter { - void* p; - DeleterFunc2 deleter; - }; - PtrAndDeleter AllocBytes(size_t bytes) const; - size_t line_bytes_; size_t vector_bytes_; size_t step_bytes_; diff --git a/util/args.h b/util/args.h index ab496ae..96ac0b9 100644 --- a/util/args.h +++ b/util/args.h @@ -23,7 +23,7 @@ #include // std::transform #include -#include "compression/io.h" +#include "compression/io.h" // Path #include "util/basics.h" // Tristate #include "hwy/base.h" // HWY_ABORT diff --git a/util/mat.cc b/util/mat.cc new file mode 100644 index 0000000..677e928 --- /dev/null +++ b/util/mat.cc @@ -0,0 +1,100 @@ +// Copyright 2024 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "util/mat.h" + +#include +#include + +#include "util/threading_context.h" +#include "hwy/base.h" +#include "hwy/per_target.h" // VectorBytes +#include "hwy/profiler.h" + +namespace gcpp { + +void CopyMat(const MatPtr& from, MatPtr& to) { + PROFILER_FUNC; + HWY_ASSERT(to.Rows() == from.Rows() && to.Cols() == from.Cols()); + HWY_ASSERT(to.GetType() == from.GetType()); + if (to.IsPacked() && from.IsPacked()) { + HWY_ASSERT(to.PackedBytes() == from.PackedBytes()); + hwy::CopyBytes(from.Packed(), to.Packed(), to.PackedBytes()); + return; + } + const size_t row_bytes = to.Cols() * to.ElementBytes(); + for (size_t r = 0; r < to.Rows(); ++r) { + const uint8_t* from_row = from.RowT(r); + uint8_t* to_row = to.RowT(r); + hwy::CopyBytes(from_row, to_row, row_bytes); + } +} + +void ZeroInit(MatPtr& mat) { + PROFILER_FUNC; + HWY_ASSERT_M(mat.HasPtr(), mat.Name()); + if (mat.IsPacked()) { + hwy::ZeroBytes(mat.Packed(), mat.PackedBytes()); + return; + } + const size_t row_bytes = mat.Cols() * mat.ElementBytes(); + for (size_t r = 0; r < mat.Rows(); ++r) { + hwy::ZeroBytes(mat.RowT(r), row_bytes); + } +} + +// Returns `num` rounded up to an odd number of cache lines. This would also +// prevent 4K aliasing and is coprime with the cache associativity, which +// might reduce conflict misses, but we instead use `StrideForCyclicOffsets`. 
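// Worked example (illustrative numbers only): with num = 2048 floats
// (element_bytes = 4) and line_bytes = 64, a row occupies
// 2048 * 4 / 64 = 128 cache lines; (128 | 1) = 129 rounds that up to an odd
// count, so the function below returns 129 * 64 / 4 = 2064 elements, i.e. one
// extra cache line of padding per row.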
+static size_t RoundUpToOddLines(size_t num, size_t line_bytes, + size_t element_bytes) { + HWY_DASSERT(line_bytes >= 32); + HWY_DASSERT(line_bytes % element_bytes == 0); + const size_t lines = hwy::DivCeil(num * element_bytes, line_bytes); + const size_t padded_num = (lines | 1) * line_bytes / element_bytes; + HWY_DASSERT(padded_num >= num); + return padded_num; +} + +static size_t Stride(const Allocator2& allocator, const MatPtr& mat, + MatPadding padding) { + switch (padding) { + case MatPadding::kPacked: + default: + return mat.Cols(); + case MatPadding::kOdd: + return RoundUpToOddLines(mat.Cols(), allocator.LineBytes(), + mat.ElementBytes()); + case MatPadding::kCyclic: + return StrideForCyclicOffsets( + mat.Cols(), allocator.QuantumBytes() / mat.ElementBytes()); + } +} + +void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) { + const Allocator2& allocator = ThreadingContext2::Get().allocator; + const size_t stride = Stride(allocator, mat, padding); + const size_t num = mat.Rows() * stride; + // `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding` + // might not be enough, hence add extra. `MatT` is at least one byte, which + // is half of BF16, hence adding `VectorBytes` *elements* is enough. + const size_t bytes = (num + hwy::VectorBytes()) * mat.ElementBytes(); + // Allow binding the entire matrix. + const size_t padded_bytes = + hwy::RoundUpTo(bytes, allocator.QuantumBytes() / mat.ElementBytes()); + storage_ = allocator.AllocBytes(padded_bytes); + mat.SetPtr(storage_.get(), stride); +} +} // namespace gcpp diff --git a/util/mat.h b/util/mat.h new file mode 100644 index 0000000..3d7057c --- /dev/null +++ b/util/mat.h @@ -0,0 +1,532 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Tensor metadata and in-memory representation. +#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_ +#define THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_ + +#include +#include + +#include +#include + +// IWYU pragma: begin_exports +#include "compression/fields.h" +#include "compression/shared.h" // Type +#include "gemma/tensor_index.h" +#include "util/allocator.h" +#include "util/basics.h" // Extents2D +// IWYU pragma: end_exports +#include "hwy/base.h" + +namespace gcpp { + +// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector +// or matrix). Base class of the non-type-erased `MatPtrT`. Use this class +// to store hetereogeneous tensor references in a vector. +// +// Copyable, (de)serializable via `fields.h` for `model_store.h`. +class MatPtr : public IFields { + public: + MatPtr() = default; + // `name`: see `SetName`. Note that `stride` is initially `cols` and only + // differs after deserializing, or calling `SetPtr`. + MatPtr(const char* name, Type type, Extents2D extents) + : rows_(static_cast(extents.rows)), + cols_(static_cast(extents.cols)) { + SetName(name); + SetType(type); + SetPtr(nullptr, cols_); + } + + // Copying allowed because the metadata is small. 
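// The stride depends on how the tensor was allocated. For example (assuming
// LineBytes() = 64 and QuantumBytes() = 4096), a float tensor with
// Cols() = 2048 gets a stride of 2048 elements from MatPadding::kPacked,
// 2064 (an odd 129 cache lines) from kOdd, and
// StrideForCyclicOffsets(2048, 4096 / 4) = 2048 + 1024 = 3072 from kCyclic;
// see MatOwner::AllocateFor in mat.cc.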
+ MatPtr(const MatPtr& other) = default; + MatPtr& operator=(const MatPtr& other) = default; + + virtual ~MatPtr() = default; + + // Only for use by ctor, `AllocateFor` and 'loading' memory-mapped tensors. + void SetPtr(void* ptr, size_t stride) { + HWY_ASSERT(stride >= Cols()); + ptr_ = ptr; + stride_ = static_cast(stride); + + // NUQ streams must not be padded because that would change the position of + // the group tables. + if (type_ == Type::kNUQ) HWY_ASSERT(IsPacked()); + } + + bool HasPtr() const { return ptr_ != nullptr; } + + bool IsPacked() const { return stride_ == cols_; } + + const void* Packed() const { + HWY_DASSERT_M(IsPacked(), name_.c_str()); + return ptr_; + } + void* Packed() { + HWY_DASSERT_M(IsPacked(), name_.c_str()); + return ptr_; + } + + // Returns size in bytes for purposes of copying/initializing or I/O. Must + // only be called if `IsPacked`. + size_t PackedBytes() const { + HWY_DASSERT_M(IsPacked(), name_.c_str()); + // num_elements_ already includes the NUQ tables. + return num_elements_ * element_bytes_; + } + + // Works for any kind of padding. + template + T* MutableRowT(size_t row) const { + HWY_DASSERT(row < rows_); + return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_; + } + template + T* RowT(size_t row) { + HWY_DASSERT(row < rows_); + return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_; + } + template + const T* RowT(size_t row) const { + HWY_DASSERT(row < rows_); + return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_; + } + + Type GetType() const { return type_; } + void SetType(Type type) { + type_ = type; + element_bytes_ = static_cast(hwy::DivCeil(TypeBits(type), 8)); + num_elements_ = static_cast(ComputeNumElements(type, Extents())); + } + + bool IsEmpty() const { return rows_ == 0 || cols_ == 0; } + size_t Rows() const { return rows_; } + size_t Cols() const { return cols_; } + Extents2D Extents() const { return Extents2D(rows_, cols_); } + + // Offset by which to advance pointers to the next row. + size_t Stride() const { return stride_; } + + // For use by `BlobStore`, `CopyMat` and `ZeroInit`. + size_t ElementBytes() const { return element_bytes_; } + + // Decoded elements should be multiplied by this to restore their original + // range. This is required because `SfpStream` can only encode a limited range + // of magnitudes. + float Scale() const { return scale_; } + void SetScale(float scale) { scale_ = scale; } + + // Name is a terse identifier. `MakeKey` in `blob_store.cc` requires that it + // be <= 16 bytes including prefixes/suffixes. The initial name set by the + // ctor is for the tensor, but `ForEachTensor` in `weights.h` adds a per-layer + // suffix, and when loading, we call `SetName` with that. + const char* Name() const override { return name_.c_str(); } + void SetName(const char* name) { + name_ = name; + HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name); + } + + void VisitFields(IFieldsVisitor& visitor) override { + // Order determines the order of serialization and must not change. + visitor(name_); + visitor(type_); + visitor(element_bytes_); + visitor(num_elements_); + visitor(rows_); + visitor(cols_); + visitor(scale_); + visitor(stride_); + } + + protected: + // For initializing `num_elements_`: "elements" are how many objects we + // actually store in order to represent rows * cols values. For NUQ, this is + // greater because it includes additional per-group tables. This is the only + // place where we compute this fixup. 
Note that elements are independent of + // padding, which is anyway not supported for NUQ because `compress-inl.h` + // assumes a contiguous stream for its group indexing. + static size_t ComputeNumElements(Type type, Extents2D extents) { + const size_t num_elements = extents.Area(); + if (type == Type::kNUQ) { + // `CompressedArrayElements` is a wrapper function that has the same + // effect, but that requires a template argument, not `type`. + return NuqStream::PackedEnd(num_elements); + } + return num_elements; + } + + std::string name_; // See `SetName`. + Type type_; + + // Most members are u32 because that is the preferred type of fields.h. + + // Bytes per element. This is fully determined by `type_`, but stored here + // for convenience and backward compatibility. + uint32_t element_bytes_ = 0; + // Number of elements to store (including NUQ tables but not padding). + // This a function of `type_` and `Extents()` and stored for compatibility. + uint32_t num_elements_ = 0; + uint32_t rows_ = 0; + uint32_t cols_ = 0; + float scale_ = 1.0f; // multiplier for each value, for MatMul. + + // Non-owning pointer, must not be freed. The underlying memory must outlive + // this object. + void* ptr_ = nullptr; // not serialized + + // Offset by which to advance pointers to the next row, >= `cols_`. + uint32_t stride_; +}; + +// Non-type erased version of `MatPtr`. Use this when operating on the values. +template +class MatPtrT : public MatPtr { + public: + // Runtime-specified shape. + MatPtrT(const char* name, Extents2D extents) + : MatPtr(name, TypeEnum(), extents) {} + // Take shape from `TensorInfo` to avoid duplicating it in the caller. + MatPtrT(const char* name, const TensorInfo* tensor) + : MatPtrT(name, ExtentsFromInfo(tensor)) {} + // Find `TensorInfo` by name in `TensorIndex`. + MatPtrT(const char* name, const TensorIndex& tensor_index) + : MatPtrT(name, tensor_index.FindName(name)) {} + + // Copying allowed because the metadata is small. + MatPtrT(const MatPtr& other) : MatPtr(other) {} + MatPtrT& operator=(const MatPtr& other) { + MatPtr::operator=(other); + return *this; + } + MatPtrT(const MatPtrT& other) = default; + MatPtrT& operator=(const MatPtrT& other) = default; + + // Returns the entire tensor for use by `backprop/*`. Verifies layout is + // `kPacked`. Preferably call `Row` instead, which works for either layout. + MatT* Packed() { + HWY_DASSERT_M(IsPacked(), name_.c_str()); + return HWY_RCAST_ALIGNED(MatT*, ptr_); + } + const MatT* Packed() const { + HWY_DASSERT_M(IsPacked(), name_.c_str()); + return HWY_RCAST_ALIGNED(const MatT*, ptr_); + } + // As `Packed()`, plus checks the scale is 1.0 because callers will ignore it. + // This is typically used for `MatMul` bias vectors and norm weights. + const MatT* PackedScale1() const { + HWY_DASSERT(Scale() == 1.0f); + return Packed(); + } + + const MatT* Row(size_t row) const { return this->RowT(row); } + MatT* Row(size_t row) { return this->RowT(row); } + + // For `compress-inl.h` functions, which assume contiguous streams and thus + // require packed layout. + PackedSpan Span() const { + HWY_ASSERT(IsPacked()); + return MakeConstSpan(Row(0), num_elements_); + } + PackedSpan Span() { + HWY_ASSERT(IsPacked()); + return MakeSpan(Row(0), num_elements_); + } +}; + +// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT`, plus the +// optional `args`. +template +decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func, + Args&&... 
args) { + HWY_ASSERT(base != nullptr); + if (type == Type::kF32) { + return func(dynamic_cast*>(base), + std::forward(args)...); + } else if (type == Type::kBF16) { + return func(dynamic_cast*>(base), + std::forward(args)...); + } else if (type == Type::kSFP) { + return func(dynamic_cast*>(base), + std::forward(args)...); + } else if (type == Type::kNUQ) { + return func(dynamic_cast*>(base), + std::forward(args)...); + } else { + HWY_ABORT("Type %d unknown.", static_cast(type)); + } +} + +void CopyMat(const MatPtr& from, MatPtr& to); +void ZeroInit(MatPtr& mat); + +template +void RandInit(MatPtrT& x, T stddev, std::mt19937& gen) { + std::normal_distribution dist(0.0, stddev); + for (size_t r = 0; r < x.Rows(); ++r) { + T* row = x.Row(r); + for (size_t c = 0; c < x.Cols(); ++c) { + row[c] = dist(gen); + } + } +} + +// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If +// `Allocator2::ShouldBind()`, `Allocator2::QuantumBytes()` is typically 4KiB. +// To avoid remote accesses, we would thus pad each row to that, which results +// in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that +// by pulling rows forward by a cyclic offset, which is still a multiple of the +// cache line size. This requires an additional `Allocator2::QuantumBytes()` of +// padding after also rounding up to that, which considerably increases size for +// tall and skinny tensors. +static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) { + return hwy::RoundUpTo(cols, quantum) + quantum; +} +// Constexpr version (upper bound) for allocating storage in MatMul. +template +constexpr size_t MaxStrideForCyclicOffsets(size_t cols) { + constexpr size_t quantum = Allocator2::MaxQuantum(); + return hwy::RoundUpTo(cols, quantum) + quantum; +} + +// Our tensors are always row-major. This enum indicates how much (if any) +// padding comes after each row. +enum class MatPadding { + // None, stride == cols. `compress-inl.h` requires this layout because its + // interface assumes a continuous 1D array, without awareness of rows. Note + // that tensors which were written via `compress-inl.h` (i.e. most in + // `BlobStore`) are not padded, which also extends to memory-mapped tensors. + // However, `BlobStore` is able to insert padding via row-wise I/O when + // reading from disk via `Mode::kRead`. + // + // `backprop/*` also requires this layout because it indexes directly into + // the storage instead of calling `Row()`. + kPacked, + // Enough to round up to an odd number of cache lines, which can reduce + // cache conflict misses or 4K aliasing. + kOdd, + // Enough to enable the "cyclic offsets" optimization for `MatMul`. + kCyclic, +}; + +// Type-erased, allows storing `AlignedPtr2` for various T in the same +// vector. +class MatOwner { + public: + MatOwner() = default; + // Allow move for `MatStorageT`. + MatOwner(MatOwner&&) = default; + MatOwner& operator=(MatOwner&&) = default; + + // Allocates the type/extents indicated by `mat` and sets its pointer. + void AllocateFor(MatPtr& mat, MatPadding padding); + + private: + AlignedPtr2 storage_; +}; + +// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and +// tests to allocate and access tensors of a known type. By contrast, the +// heterogeneous model weights are owned by vectors of `MatOwner`. 
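// Example usage (hypothetical names and shapes):
//
//   // Tests / backprop: element type known at compile time, owning storage.
//   MatStorageT<float> grad("grad", Extents2D(4, 2048), MatPadding::kOdd);
//   for (size_t r = 0; r < grad.Rows(); ++r) {
//     float* row = grad.Row(r);  // rows are Stride() elements apart
//     for (size_t c = 0; c < grad.Cols(); ++c) row[c] = 0.0f;
//   }
//
//   // Weights: type-erased MatPtr metadata plus a MatOwner holding storage.
//   MatPtrT<BF16> w("w", Extents2D(2048, 2048));
//   MatOwner owner;
//   owner.AllocateFor(w, MatPadding::kCyclic);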
+template +class MatStorageT : public MatPtrT { + public: + MatStorageT(const char* name, Extents2D extents, MatPadding padding) + : MatPtrT(name, extents) { + owner_.AllocateFor(*this, padding); + } + ~MatStorageT() = default; + + // Allow move for backprop/activations. + MatStorageT(MatStorageT&&) = default; + MatStorageT& operator=(MatStorageT&&) = default; + + private: + MatOwner owner_; +}; + +// Helper factory function for use by `backprop/` to avoid specifying the +// `MatPadding` argument everywhere. +template +MatStorageT MakePacked(const char* name, size_t rows, size_t cols) { + return MatStorageT(name, Extents2D(rows, cols), MatPadding::kPacked); +} + +// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with +// seekable (non-NUQ) T. This has less metadata, but support for cyclic offsets. +#pragma pack(push, 1) // power of two size +template +class RowPtr { + public: + RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols, + size_t stride) + : row0_(row0), + stride_(stride), + row_mask_( + static_cast(allocator.QuantumStepMask() & 0xFFFFFFFFu)), + cols_(static_cast(cols)), + step_bytes_(static_cast(allocator.StepBytes())), + quantum_bytes_(allocator.QuantumBytes()) { + HWY_DASSERT(stride >= cols); + HWY_DASSERT(row_mask_ != ~uint32_t{0}); + if (stride < StrideForCyclicOffsets(cols, quantum_bytes_ / sizeof(T))) { + row_mask_ = 0; + if constexpr (HWY_IS_DEBUG_BUILD) { + static bool once; + if (stride != cols && !once) { + once = true; + HWY_WARN( + "Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), " + "T=%zu; this forces us to disable cyclic offsets.", + stride, cols, sizeof(T)); + } + } + } + } + + RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols) + : RowPtr(allocator, row0, cols, cols) {} + + T* HWY_RESTRICT Row(size_t r) const { + // How much of the previous row's padding to consume. + const size_t pad_bytes = (r & row_mask_) * step_bytes_; + HWY_DASSERT(pad_bytes < static_cast(quantum_bytes_)); + return row0_ + stride_ * r - pad_bytes; + } + size_t Cols() const { return static_cast(cols_); } + + size_t Stride() const { return stride_; } + void SetStride(size_t stride) { + HWY_DASSERT(stride >= Cols()); + stride_ = stride; + // The caller might not have padded enough, so disable the padding in Row(). + // Rows will now be exactly `stride` elements apart. This is used when + // writing to the KV cache via MatMul. + row_mask_ = 0; + } + + // Returns 2D subrange whose top-left is `r, c` and width is `cols`. + RowPtr View(size_t r, size_t c, size_t cols) const { + HWY_DASSERT(c < Cols()); + HWY_DASSERT(cols <= Cols() - c); + return RowPtr(Row(r) + c, cols, stride_, row_mask_, step_bytes_, + quantum_bytes_); + } + + private: + // For `View()`. + RowPtr(T* new_row0, size_t new_cols, size_t stride, uint32_t row_mask, + uint32_t step_bytes, uint32_t quantum_bytes) + : row0_(new_row0), + stride_(stride), + row_mask_(row_mask), + cols_(new_cols), + step_bytes_(step_bytes), + quantum_bytes_(quantum_bytes) {} + + T* HWY_RESTRICT row0_; + size_t stride_; + uint32_t row_mask_; + uint32_t cols_; + uint32_t step_bytes_; + uint32_t quantum_bytes_; +}; +#pragma pack(pop) + +using RowPtrBF = RowPtr; +using RowPtrF = RowPtr; +using RowPtrD = RowPtr; + +// Owns dynamically-allocated aligned memory for a batch of row vectors. +// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns +// the memory. Unlike `MatPtr`, this lacks metadata. +// TODO: replace with `MatStorageT`. 
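// Example (hypothetical shapes): allocating a batch of activation rows and
// viewing it as the C argument of MatMul. AllocateAlignedRows and
// RowPtrFromBatch are declared further below.
//
//   const Allocator2& allocator = ThreadingContext2::Get().allocator;
//   RowVectorBatch<float> c_batch =
//       AllocateAlignedRows<float>(allocator, Extents2D(16, 2048));
//   float* row0 = c_batch.Batch(0);                   // length c_batch.Cols()
//   RowPtrF c = RowPtrFromBatch(allocator, c_batch);  // for MatMul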
+template +class RowVectorBatch { + public: + // Default ctor for Activations ctor. + RowVectorBatch() = default; + // Main ctor, called from Activations::Allocate. If `stride` = 0, the default, + // we default to tightly packed rows (`stride = cols`). + // WARNING: not all call sites support `stride` != cols. + // TODO: once they do, remove stride and behave like AllocateAlignedRows here. + RowVectorBatch(const Allocator2& allocator, Extents2D extents, + size_t stride = 0) + : extents_(extents) { + if (stride == 0) { + stride_ = extents_.cols; + } else { + HWY_ASSERT(stride >= extents_.cols); + stride_ = stride; + } + // Allow binding the entire matrix. + const size_t padded = hwy::RoundUpTo(extents_.rows * stride_, + allocator.QuantumBytes() / sizeof(T)); + mem_ = allocator.Alloc(padded); + } + + // Move-only + RowVectorBatch(RowVectorBatch&) noexcept = delete; + RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete; + RowVectorBatch(RowVectorBatch&&) noexcept = default; + RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default; + + size_t BatchSize() const { return extents_.rows; } + size_t Cols() const { return extents_.cols; } + size_t Stride() const { return stride_; } + Extents2D Extents() const { return extents_; } + + // Returns the given row vector of length `Cols()`. + T* Batch(size_t batch_idx) { + HWY_DASSERT(batch_idx < BatchSize()); + return mem_.get() + batch_idx * stride_; + } + const T* Batch(size_t batch_idx) const { + HWY_DASSERT(batch_idx < BatchSize()); + return mem_.get() + batch_idx * stride_; + } + + // For MatMul or other operations that process the entire batch at once. + // TODO: remove once we only use Mat. + T* All() { return mem_.get(); } + const T* Const() const { return mem_.get(); } + size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); } + + private: + AlignedPtr2 mem_; + Extents2D extents_; + size_t stride_; +}; + +template +RowPtr RowPtrFromBatch(const Allocator2& allocator, + RowVectorBatch& row_vectors) { + return RowPtr(allocator, row_vectors.All(), row_vectors.Cols(), + row_vectors.Stride()); +} + +template +RowVectorBatch AllocateAlignedRows(const Allocator2& allocator, + Extents2D extents) { + return RowVectorBatch( + allocator, extents, + StrideForCyclicOffsets(extents.cols, + allocator.QuantumBytes() / sizeof(T))); +} + +} // namespace gcpp +#endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_ diff --git a/util/threading.cc b/util/threading.cc index c2f8bb7..0ed3a3d 100644 --- a/util/threading.cc +++ b/util/threading.cc @@ -13,12 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "util/threading.h" +#include "util/threading.h" // NOT threading_context.. +// to ensure there is no deadlock. #include #include // std::sort #include +#include #include #include @@ -69,7 +71,7 @@ class Pinning { const int bytes_written = snprintf(buf, sizeof(buf), "P%zu X%02zu C%03d", pkg_idx, cluster_idx, static_cast(task)); - HWY_ASSERT(bytes_written < sizeof(buf)); + HWY_ASSERT(bytes_written < static_cast(sizeof(buf))); hwy::SetThreadName(buf, 0); // does not support varargs if (HWY_LIKELY(want_pin_)) { @@ -107,16 +109,16 @@ static Pinning& GetPinning() { return pinning; } -static PoolPtr MakePool(size_t num_workers, +static PoolPtr MakePool(const Allocator2& allocator, size_t num_workers, std::optional node = std::nullopt) { // `ThreadPool` expects the number of threads to create, which is one less // than the number of workers, but avoid underflow if zero. 
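  // (For example, num_workers = 8 spawns 7 threads and the calling thread
  // acts as the 8th worker; num_workers = 0 stays 0 to avoid underflow.)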
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1; - PoolPtr ptr = Allocator::AllocClasses(1, num_threads); + PoolPtr ptr = allocator.AllocClasses(1, num_threads); const size_t bytes = - hwy::RoundUpTo(sizeof(hwy::ThreadPool), Allocator::QuantumBytes()); - if (node.has_value() && Allocator::ShouldBind()) { - Allocator::BindMemory(ptr.get(), bytes, node.value()); + hwy::RoundUpTo(sizeof(hwy::ThreadPool), allocator.QuantumBytes()); + if (node.has_value() && allocator.ShouldBind()) { + allocator.BindMemory(ptr.get(), bytes, node.value()); } return ptr; } @@ -133,22 +135,22 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) { return max; } -NestedPools::NestedPools(const BoundedTopology& topology, size_t max_threads, +NestedPools::NestedPools(const BoundedTopology& topology, + const Allocator2& allocator, size_t max_threads, Tristate pin) { GetPinning().SetPolicy(pin); packages_.resize(topology.NumPackages()); - all_packages_ = MakePool(packages_.size()); + all_packages_ = MakePool(allocator, packages_.size()); const size_t max_workers_per_package = DivideMaxAcross(max_threads, packages_.size()); // Each worker in all_packages_, including the main thread, will be the - // calling thread of an all_clusters[0].Run, and hence pinned to one of the + // calling thread of an all_clusters->Run, and hence pinned to one of the // `cluster.lps` if `pin`. - all_packages_[0].Run( - 0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) { - HWY_ASSERT(pkg_idx == thread); // each thread has one task - packages_[pkg_idx] = - Package(topology, pkg_idx, max_workers_per_package); - }); + all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) { + HWY_ASSERT(pkg_idx == thread); // each thread has one task + packages_[pkg_idx] = + Package(topology, allocator, pkg_idx, max_workers_per_package); + }); all_pinned_ = GetPinning().AllPinned(&pin_string_); @@ -172,28 +174,29 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) { return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero); } -NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx, +NestedPools::Package::Package(const BoundedTopology& topology, + const Allocator2& allocator, size_t pkg_idx, size_t max_workers_per_package) { // Pre-allocate because elements are set concurrently. clusters_.resize(topology.NumClusters(pkg_idx)); const size_t max_workers_per_cluster = DivideMaxAcross(max_workers_per_package, clusters_.size()); - all_clusters_ = - MakePool(clusters_.size(), topology.GetCluster(pkg_idx, 0).Node()); + all_clusters_ = MakePool(allocator, clusters_.size(), + topology.GetCluster(pkg_idx, 0).Node()); // Parallel so we also pin the calling worker in `all_clusters` to // `cluster.lps`. - all_clusters_[0].Run( - 0, all_clusters_[0].NumWorkers(), [&](size_t cluster_idx, size_t thread) { + all_clusters_->Run( + 0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) { HWY_ASSERT(cluster_idx == thread); // each thread has one task const BoundedTopology::Cluster& cluster = topology.GetCluster(pkg_idx, cluster_idx); - clusters_[cluster_idx] = - MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster), - cluster.Node()); + clusters_[cluster_idx] = MakePool( + allocator, CapIfNonZero(cluster.Size(), max_workers_per_cluster), + cluster.Node()); // Pin workers AND the calling thread from `all_clusters`. 
GetPinning().MaybePin(pkg_idx, cluster_idx, cluster, - clusters_[cluster_idx][0]); + *clusters_[cluster_idx]); }); } diff --git a/util/threading.h b/util/threading.h index d7de410..d7def57 100644 --- a/util/threading.h +++ b/util/threading.h @@ -23,6 +23,7 @@ // IWYU pragma: begin_exports #include "util/allocator.h" +#include "util/args.h" #include "util/basics.h" // Tristate #include "util/topology.h" #include "hwy/base.h" // HWY_ASSERT @@ -37,7 +38,7 @@ namespace gcpp { // Page-aligned on NUMA systems so we can bind to a NUMA node. This also allows // moving because it is a typedef to `std::unique_ptr`. -using PoolPtr = AlignedClassPtr; +using PoolPtr = AlignedClassPtr2; // Creates a hierarchy of thread pools according to `BoundedTopology`: one with // a thread per enabled package; for each of those, one with a thread per @@ -73,10 +74,8 @@ class NestedPools { // would cause huge slowdowns when spinning, the `BoundedSlice` arguments // only impose upper bounds on the number of detected packages and clusters // rather than defining the actual number of threads. - // - // Caller must have called `Allocator::Init` before this. - NestedPools(const BoundedTopology& topology, size_t max_threads = 0, - Tristate pin = Tristate::kDefault); + NestedPools(const BoundedTopology& topology, const Allocator2& allocator, + size_t max_threads = 0, Tristate pin = Tristate::kDefault); bool AllPinned() const { return all_pinned_; } @@ -103,7 +102,7 @@ class NestedPools { } size_t NumPackages() const { return packages_.size(); } - hwy::ThreadPool& AllPackages() { return all_packages_[0]; } + hwy::ThreadPool& AllPackages() { return *all_packages_; } hwy::ThreadPool& AllClusters(size_t pkg_idx) { HWY_DASSERT(pkg_idx < NumPackages()); return packages_[pkg_idx].AllClusters(); @@ -149,36 +148,36 @@ class NestedPools { class Package { public: Package() = default; // for vector - Package(const BoundedTopology& topology, size_t pkg_idx, - size_t max_workers_per_package); + Package(const BoundedTopology& topology, const Allocator2& allocator, + size_t pkg_idx, size_t max_workers_per_package); size_t NumClusters() const { return clusters_.size(); } size_t MaxWorkersPerCluster() const { size_t max_workers_per_cluster = 0; for (const PoolPtr& cluster : clusters_) { max_workers_per_cluster = - HWY_MAX(max_workers_per_cluster, cluster[0].NumWorkers()); + HWY_MAX(max_workers_per_cluster, cluster->NumWorkers()); } return max_workers_per_cluster; } size_t TotalWorkers() const { size_t total_workers = 0; for (const PoolPtr& cluster : clusters_) { - total_workers += cluster[0].NumWorkers(); + total_workers += cluster->NumWorkers(); } return total_workers; } - hwy::ThreadPool& AllClusters() { return all_clusters_[0]; } + hwy::ThreadPool& AllClusters() { return *all_clusters_; } hwy::ThreadPool& Cluster(size_t cluster_idx) { HWY_DASSERT(cluster_idx < clusters_.size()); - return clusters_[cluster_idx][0]; + return *clusters_[cluster_idx]; } void SetWaitMode(hwy::PoolWaitMode wait_mode) { - all_clusters_[0].SetWaitMode(wait_mode); + all_clusters_->SetWaitMode(wait_mode); for (PoolPtr& cluster : clusters_) { - cluster[0].SetWaitMode(wait_mode); + cluster->SetWaitMode(wait_mode); } } @@ -188,7 +187,7 @@ class NestedPools { }; // Package void SetWaitMode(hwy::PoolWaitMode wait_mode) { - all_packages_[0].SetWaitMode(wait_mode); + all_packages_->SetWaitMode(wait_mode); for (Package& package : packages_) { package.SetWaitMode(wait_mode); } diff --git a/util/threading_context.cc b/util/threading_context.cc index 9065335..c15e194 100644 
--- a/util/threading_context.cc +++ b/util/threading_context.cc @@ -33,6 +33,13 @@ static std::mutex s_ctx_mutex; s_ctx_mutex.unlock(); } +/*static*/ bool ThreadingContext2::IsInitialized() { + s_ctx_mutex.lock(); + const bool initialized = !!s_ctx; + s_ctx_mutex.unlock(); + return initialized; +} + /*static*/ ThreadingContext2& ThreadingContext2::Get() { // We do not bother with double-checked locking because it requires an // atomic pointer, but we prefer to use unique_ptr for simplicity. Also, diff --git a/util/threading_context.h b/util/threading_context.h index 7430f16..a59dcdd 100644 --- a/util/threading_context.h +++ b/util/threading_context.h @@ -98,6 +98,10 @@ class ThreadingContext2 { // is expected to be called early in the program, before threading starts. static void SetArgs(const ThreadingArgs& args); + // Returns whether `Get()` has already been called, typically used to avoid + // calling `SetArgs` after that, because it would assert. + static bool IsInitialized(); + // Returns a reference to the singleton after initializing it if necessary. // When initializing, uses the args passed to `SetArgs`, or defaults. // diff --git a/util/threading_test.cc b/util/threading_test.cc index e7fe021..d99e53b 100644 --- a/util/threading_test.cc +++ b/util/threading_test.cc @@ -13,8 +13,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "util/threading.h" - #include #include @@ -22,9 +20,9 @@ #include "gmock/gmock.h" #include "gtest/gtest.h" -#include "util/allocator.h" #include "util/basics.h" -#include "hwy/aligned_allocator.h" +#include "util/threading_context.h" +#include "hwy/aligned_allocator.h" // Span #include "hwy/auto_tune.h" #include "hwy/base.h" // HWY_ASSERT #include "hwy/contrib/thread_pool/thread_pool.h" @@ -385,9 +383,7 @@ TEST(ThreadingTest, BenchJoin) { } }; - BoundedTopology topology; - Allocator::Init(topology, true); - NestedPools pools(topology); + NestedPools& pools = ThreadingContext2::Get().pools; // Use last package because the main thread has been pinned to it. const size_t pkg_idx = pools.NumPackages() - 1;
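As a rough illustration of the call order implied by the new `ThreadingContext2` API, a minimal sketch follows; the helper name `ConfigureThreading` and its placement in a separate translation unit are assumptions, not part of the change above.

#include "util/threading_context.h"

namespace gcpp {

// Hypothetical helper: apply args once, before any thread pools are used.
void ConfigureThreading(const ThreadingArgs& args) {
  // SetArgs asserts if the singleton already exists, hence the guard.
  if (!ThreadingContext2::IsInitialized()) {
    ThreadingContext2::SetArgs(args);
  }
  // The first Get() creates the singleton using the args passed above (or
  // defaults); later calls return the same instance.
  ThreadingContext2& ctx = ThreadingContext2::Get();
  NestedPools& pools = ctx.pools;
  pools.AllPackages().Run(0, pools.NumPackages(),
                          [&](uint64_t pkg_idx, size_t /*thread*/) {
                            // E.g. warm up per-package state here.
                            (void)pkg_idx;
                          });
}

}  // namespace gcpp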