Major refactor of allocator/args:

Use the new ThreadingContext2 instead of monostate objects initialized separately in each frontend.
Add ThreadingArgs (replaces AppArgs).
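A minimal sketch of the new setup, assembled from call sites later in this diff (e.g. backward_test.cc); the overrides shown are illustrative, not required:

  gcpp::ThreadingArgs threading_args;
  threading_args.max_packages = 1;   // optional overrides
  threading_args.max_clusters = 8;
  threading_args.pin = Tristate::kFalse;
  if (!ThreadingContext2::IsInitialized()) {
    ThreadingContext2::SetArgs(threading_args);  // asserts if called a second time
  }
  hwy::ThreadPool& pool = ThreadingContext2::Get().pools.Pool();
  auto& allocator = ThreadingContext2::Get().allocator;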

backprop: use the Packed() accessor, the MakePacked factory, and row-based access to allow for stride.
compress_weights: remove; moving to a Python-only exporter instead.

Move MatPtr to mat.h and revise its interface (see the sketch below):
- generic MatOwner
- rename accessors to Packed*
- support stride/row accessors; fix RowPtr stride
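A rough usage sketch of the revised interface, pieced together from call sites in this diff; exact signatures live in util/mat.h, and kTokens/kCols are placeholder dimensions:

  auto x = MakePacked<float>("x", kTokens, kCols);  // was: MatStorageT<float> x("x", kTokens, kCols);
  auto y = MakePacked<float>("y", kTokens, kCols);
  ZeroInit(x);                 // free function; was x.ZeroInit()
  CopyMat(x, y);               // copy x into y; was memcpy of SizeBytes()
  float* all = x.Packed();     // renamed from x.data(); contiguous (packed) view
  for (size_t r = 0; r < x.Rows(); ++r) {
    float* row = x.Row(r);     // row-based access, also valid when rows are strided
    for (size_t c = 0; c < x.Cols(); ++c) row[c] *= 2.0f;
  }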

Add TypeBits(Type)
Move GenerateMat to test_util-inl for sharing between the matmul test and benchmark.
Move internal init to gemma.cc to avoid duplication
Rename GemmaEnv's model_ to gemma_ to disambiguate it from the upcoming ModelStorage.
Remove --compressed_weights, use --weights instead.
tensor_index: add ExtentsFromInfo and TensorIndexLLM/Img (call-site example below).
Allocator: AllocBytes now returns a normal unique_ptr so callers can use it directly.
threading: use -> because AlignedPtr no longer assumes arrays
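The tensor_index change in practice, as it appears in BackPropTest.LayerVJP further down in this diff:

  // Before:
  //   TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
  //                            /*reshape_att=*/false);
  // After:
  const TensorIndex tensor_index = TensorIndexLLM(config, size_t{0});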
PiperOrigin-RevId: 745918637
Authored by Jan Wassenberg on 2025-04-10 01:28:16 -07:00; committed by Copybara-Service
Parent: bef91a3f03
Commit: 8532da47f7
75 changed files with 2387 additions and 2768 deletions


@ -19,7 +19,10 @@ license(
# Dual-licensed Apache 2 and 3-clause BSD.
licenses(["notice"])
exports_files(["LICENSE"])
exports_files([
"LICENSE",
".github/workflows/build.yml",
])
cc_library(
name = "basics",
@ -29,6 +32,16 @@ cc_library(
],
)
cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
":basics",
"//compression:io", # Path
"@highway//:hwy",
],
)
# Split from :threading to break a circular dependency with :allocator.
cc_library(
name = "topology",
@ -59,6 +72,7 @@ cc_library(
hdrs = ["util/threading.h"],
deps = [
":allocator",
":args",
":basics",
":topology",
# Placeholder for container detection, do not remove
@ -68,14 +82,26 @@ cc_library(
],
)
cc_library(
name = "threading_context",
srcs = ["util/threading_context.cc"],
hdrs = ["util/threading_context.h"],
deps = [
":allocator",
":args",
":basics",
":threading",
":topology",
],
)
cc_test(
name = "threading_test",
srcs = ["util/threading_test.cc"],
deps = [
":allocator",
":basics",
":threading",
"@googletest//:gtest_main",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:auto_tune",
"@highway//:hwy",
"@highway//:hwy_test_util",
@ -97,6 +123,65 @@ cc_library(
],
)
cc_library(
name = "common",
srcs = [
"gemma/common.cc",
"gemma/configs.cc",
"gemma/tensor_index.cc",
],
hdrs = [
"gemma/common.h",
"gemma/configs.h",
"gemma/tensor_index.h",
],
deps = [
":basics",
"//compression:fields",
"//compression:sfp",
"@highway//:hwy", # base.h
],
)
cc_test(
name = "configs_test",
srcs = ["gemma/configs_test.cc"],
deps = [
":common",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
],
)
cc_test(
name = "tensor_index_test",
srcs = ["gemma/tensor_index_test.cc"],
deps = [
":basics",
":common",
":weights",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@highway//:hwy",
],
)
cc_library(
name = "mat",
srcs = ["util/mat.cc"],
hdrs = ["util/mat.h"],
deps = [
":allocator",
":basics",
":common",
":threading_context",
"//compression:fields",
"//compression:sfp",
"@highway//:hwy",
"@highway//:profiler",
],
)
# For building all tests in one command, so we can test several.
test_suite(
name = "ops_tests",
@ -123,8 +208,9 @@ cc_library(
deps = [
":allocator",
":basics",
":mat",
":threading",
":topology",
":threading_context",
"//compression:compress",
"@highway//:algo",
"@highway//:bit_set",
@ -148,10 +234,9 @@ cc_test(
tags = ["ops_tests"],
deps = [
":allocator",
":app",
":ops",
":test_util",
":threading",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:test_util",
@ -174,13 +259,13 @@ cc_test(
tags = ["ops_tests"],
deps = [
":allocator",
":app",
":basics",
":common",
":mat",
":ops",
":test_util",
":threading",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark", #buildcleaner: keep
@ -196,6 +281,7 @@ cc_test(
# for test_suite.
tags = ["ops_tests"],
deps = [
":mat",
":ops",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
@ -214,12 +300,13 @@ cc_test(
# for test_suite.
tags = ["ops_tests"],
deps = [
":allocator",
":basics",
":mat",
":ops",
":threading",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
@ -238,12 +325,12 @@ cc_test(
"ops_tests", # for test_suite.
],
deps = [
":allocator",
":basics",
":ops",
":threading",
":threading_context",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:compress",
"//compression:test_util",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
@ -252,55 +339,13 @@ cc_test(
],
)
cc_library(
name = "common",
srcs = [
"gemma/common.cc",
"gemma/configs.cc",
"gemma/tensor_index.cc",
],
hdrs = [
"gemma/common.h",
"gemma/configs.h",
"gemma/tensor_index.h",
],
deps = [
":basics",
"//compression:fields",
"//compression:sfp",
"@highway//:hwy", # base.h
],
)
cc_test(
name = "configs_test",
srcs = ["gemma/configs_test.cc"],
deps = [
":common",
"@googletest//:gtest_main",
"@highway//:hwy",
],
)
cc_test(
name = "tensor_index_test",
srcs = ["gemma/tensor_index_test.cc"],
deps = [
":basics",
":common",
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@highway//:hwy",
],
)
cc_library(
name = "weights",
srcs = ["gemma/weights.cc"],
hdrs = ["gemma/weights.h"],
deps = [
":common",
":mat",
"//compression:blob_store",
"//compression:compress",
"//compression:io",
@ -361,16 +406,17 @@ cc_library(
":basics",
":common",
":ops",
":mat",
":tokenizer",
":kv_cache",
":weights",
":threading",
"//compression:compress",
":threading_context",
# Placeholder for internal dep, do not remove.,
"//compression:io",
"//compression:sfp",
"//paligemma:image",
"@highway//:hwy",
"@highway//:bit_set",
"@highway//:nanobenchmark", # timer
"@highway//:profiler",
"@highway//:thread_pool",
@ -390,25 +436,14 @@ cc_library(
)
cc_library(
name = "args",
hdrs = ["util/args.h"],
deps = [
":basics",
"//compression:io",
"@highway//:hwy",
],
)
cc_library(
name = "app",
hdrs = ["util/app.h"],
name = "gemma_args",
hdrs = ["gemma/gemma_args.h"],
deps = [
":args",
":basics",
":common",
":gemma_lib",
":ops",
":threading",
"//compression:io",
"//compression:sfp",
"@highway//:hwy",
@ -420,20 +455,15 @@ cc_library(
srcs = ["evals/benchmark_helper.cc"],
hdrs = ["evals/benchmark_helper.h"],
deps = [
":app",
":args",
":common",
":cross_entropy",
":gemma_args",
":gemma_lib",
":kv_cache",
":ops",
":threading",
# Placeholder for internal dep, do not remove.,
":threading_context",
"@google_benchmark//:benchmark",
"//compression:compress",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:topology",
],
)
@ -451,7 +481,7 @@ cc_test(
":benchmark_helper",
":common",
":gemma_lib",
"@googletest//:gtest_main",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
],
@ -470,8 +500,7 @@ cc_test(
":benchmark_helper",
":common",
":gemma_lib",
":tokenizer",
"@googletest//:gtest_main",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
],
@ -481,14 +510,13 @@ cc_binary(
name = "gemma",
srcs = ["gemma/run.cc"],
deps = [
":app",
":args",
":benchmark_helper",
":common",
":gemma_args",
":gemma_lib",
":ops",
":threading",
# Placeholder for internal dep, do not remove.,
":threading_context",
"//compression:sfp",
"//paligemma:image",
"@highway//:hwy",
@ -594,10 +622,10 @@ cc_library(
deps = [
":allocator",
":common",
":mat",
":ops",
":prompt",
":weights",
"//compression:compress",
"@highway//:dot",
"@highway//:hwy", # base.h
"@highway//:thread_pool",
@ -614,9 +642,9 @@ cc_library(
],
deps = [
":common",
":mat",
":prompt",
":weights",
"//compression:compress",
"@highway//:hwy",
],
)
@ -631,11 +659,11 @@ cc_test(
deps = [
":backprop_scalar",
":common",
":mat",
":prompt",
":sampler",
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:thread_pool",
],
)
@ -652,17 +680,16 @@ cc_test(
"mem": "28g",
},
deps = [
":allocator",
":backprop",
":backprop_scalar",
":common",
":mat",
":ops",
":prompt",
":sampler",
":threading",
":threading_context",
":weights",
"@googletest//:gtest_main",
"//compression:compress",
"@googletest//:gtest_main", # buildcleaner: keep
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
@ -676,6 +703,7 @@ cc_library(
deps = [
":allocator",
":common",
":mat",
":weights",
"//compression:compress",
"@highway//:hwy",
@ -685,9 +713,7 @@ cc_library(
cc_test(
name = "optimize_test",
srcs = [
"backprop/optimize_test.cc",
],
srcs = ["backprop/optimize_test.cc"],
exec_properties = {
# Avoid linker OOMs when building with sanitizer instrumentation.
"mem": "28g",
@ -704,7 +730,7 @@ cc_test(
":sampler",
":threading",
":weights",
"@googletest//:gtest_main",
"@googletest//:gtest_main", # buildcleaner: keep
"//compression:sfp",
"@highway//:thread_pool",
],


@ -75,6 +75,7 @@ set(SOURCES
gemma/common.h
gemma/configs.cc
gemma/configs.h
gemma/gemma_args.h
gemma/gemma-inl.h
gemma/gemma.cc
gemma/gemma.h
@ -102,15 +103,17 @@ set(SOURCES
paligemma/image.h
util/allocator.cc
util/allocator.h
util/app.h
util/args.h
util/basics.h
util/mat.cc
util/mat.h
util/test_util.h
util/threading.cc
util/threading.h
util/threading_context.cc
util/threading_context.h
util/topology.cc
util/topology.h
)
if(NOT CMAKE_BUILD_TYPE)
set(CMAKE_BUILD_TYPE "Release")
@ -197,8 +200,5 @@ endif() # GEMMA_ENABLE_TESTS
## Tools
add_executable(compress_weights compression/compress_weights.cc)
target_link_libraries(compress_weights libgemma hwy hwy_contrib)
add_executable(migrate_weights compression/migrate_weights.cc)
target_link_libraries(migrate_weights libgemma hwy hwy_contrib)


@ -20,24 +20,30 @@
#include <vector>
#include "compression/compress.h" // MatStorageT
#include "gemma/configs.h" // ModelConfig
#include "util/mat.h" // MatStorageT
namespace gcpp {
template <typename T>
struct ForwardLayer {
ForwardLayer(const LayerConfig& config, size_t seq_len)
: input("input", seq_len, config.model_dim),
pre_att_rms_out("pre_att_rms_out", seq_len, config.model_dim),
qkv("qkv", seq_len * (config.heads + 2), config.qkv_dim),
att("att", seq_len * config.heads, seq_len),
att_out("att_out", seq_len * config.heads, config.qkv_dim),
att_post1("att_post1", seq_len, config.model_dim),
attention_out("attention_out", seq_len, config.model_dim),
bf_pre_ffw_rms_out("bf_pre_ffw_rms_out", seq_len, config.model_dim),
ffw_hidden("ffw_hidden", seq_len, config.ff_hidden_dim * 2),
ffw_hidden_gated("ffw_hidden_gated", seq_len, config.ff_hidden_dim),
: input(MakePacked<T>("input", seq_len, config.model_dim)),
pre_att_rms_out(
MakePacked<T>("pre_att_rms_out", seq_len, config.model_dim)),
qkv(MakePacked<T>("qkv", seq_len * (config.heads + 2), config.qkv_dim)),
att(MakePacked<T>("att", seq_len * config.heads, seq_len)),
att_out(
MakePacked<T>("att_out", seq_len * config.heads, config.qkv_dim)),
att_post1(MakePacked<T>("att_post1", seq_len, config.model_dim)),
attention_out(
MakePacked<T>("attention_out", seq_len, config.model_dim)),
bf_pre_ffw_rms_out(
MakePacked<T>("bf_preFF_rms_out", seq_len, config.model_dim)),
ffw_hidden(
MakePacked<T>("ffw_hidden", seq_len, config.ff_hidden_dim * 2)),
ffw_hidden_gated(
MakePacked<T>("ffw_hidden_gated", seq_len, config.ff_hidden_dim)),
layer_config(config) {}
MatStorageT<T> input;
@ -56,12 +62,12 @@ struct ForwardLayer {
template <typename T>
struct ForwardPass {
ForwardPass(const ModelConfig& config)
: final_layer_output("final_layer_output", config.seq_len,
config.model_dim),
final_norm_output("final_norm_output", config.seq_len,
config.model_dim),
logits("logits", config.seq_len, config.vocab_size),
probs("probs", config.seq_len, config.vocab_size),
: final_layer_output(
MakePacked<T>("fin_layer_out", config.seq_len, config.model_dim)),
final_norm_output(
MakePacked<T>("fin_norm_out", config.seq_len, config.model_dim)),
logits(MakePacked<T>("logits", config.seq_len, config.vocab_size)),
probs(MakePacked<T>("probs", config.seq_len, config.vocab_size)),
weights_config(config) {
for (const auto& layer_config : config.layer_configs) {
layers.emplace_back(layer_config, config.seq_len);


@ -128,7 +128,7 @@ static HWY_NOINLINE void SoftmaxVJP(const float* HWY_RESTRICT forward,
HWY_ATTR { return hn::Mul(y, hn::Sub(v, offset)); });
}
static HWY_NOINLINE void RMSNormVJP(
static HWY_NOINLINE HWY_MAYBE_UNUSED void RMSNormVJP(
const float* HWY_RESTRICT weights, const float* HWY_RESTRICT x,
const float* HWY_RESTRICT v, size_t model_dim, size_t num_tokens,
float* HWY_RESTRICT grad_w, float* HWY_RESTRICT grad_x,
@ -153,10 +153,9 @@ static HWY_NOINLINE void RMSNormVJP(
}
}
static HWY_NOINLINE void InputEmbeddingVJP(
const float* weights, const std::vector<int>& prompt,
const float scaling, const float* HWY_RESTRICT v,
float* HWY_RESTRICT grad, size_t model_dim) {
static HWY_NOINLINE HWY_MAYBE_UNUSED void InputEmbeddingVJP(
const float* weights, const std::vector<int>& prompt, const float scaling,
const float* HWY_RESTRICT v, float* HWY_RESTRICT grad, size_t model_dim) {
HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos];
@ -182,17 +181,18 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
static_cast<float>(1.0 / sqrt(static_cast<double>(qkv_dim)));
HWY_ASSERT(num_tokens <= seq_len);
MatMulVJP(weights.linear_w.data(), forward.ffw_hidden_gated.data(),
MatMulVJP(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(),
next_layer_grad, ff_hidden_dim, model_dim, num_tokens,
grad.linear_w.data(), backward.ffw_hidden_gated.data(), pool);
grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * ff_hidden_dim * 2;
const float* HWY_RESTRICT f_out = forward.ffw_hidden.data() + hidden_offset;
const float* HWY_RESTRICT f_out =
forward.ffw_hidden.Packed() + hidden_offset;
const float* HWY_RESTRICT f_out_mul = f_out + ff_hidden_dim;
const float* HWY_RESTRICT b_out_gated =
backward.ffw_hidden_gated.data() + pos * ff_hidden_dim;
float* HWY_RESTRICT b_out = backward.ffw_hidden.data() + hidden_offset;
backward.ffw_hidden_gated.Packed() + pos * ff_hidden_dim;
float* HWY_RESTRICT b_out = backward.ffw_hidden.Packed() + hidden_offset;
float* HWY_RESTRICT b_out_mul = b_out + ff_hidden_dim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
@ -206,38 +206,39 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
}
}
MatMulVJP(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
backward.ffw_hidden.data(), model_dim, ff_hidden_dim * 2,
num_tokens, grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), pool);
RMSNormVJP(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(), model_dim, num_tokens,
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
pool);
MatMulVJP(weights.gating_einsum_w.Packed(),
forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
model_dim, ff_hidden_dim * 2, num_tokens,
grad.gating_einsum_w.Packed(), backward.bf_pre_ffw_rms_out.Packed(),
pool);
RMSNormVJP(
weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
backward.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens,
grad.pre_ffw_norm_scale.Packed(), backward.attention_out.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(next_layer_grad + pos * model_dim,
backward.attention_out.data() + pos * model_dim, model_dim);
backward.attention_out.Packed() + pos * model_dim, model_dim);
}
backward.qkv.ZeroInit();
ZeroInit(backward.qkv);
MultiHeadMatMulVJP(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(), heads, qkv_dim, model_dim,
num_tokens, grad.attn_vec_einsum_w.data(),
backward.att_out.data(), pool);
MultiHeadMatMulVJP(
weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
backward.attention_out.Packed(), heads, qkv_dim, model_dim, num_tokens,
grad.attn_vec_einsum_w.Packed(), backward.att_out.Packed(), pool);
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
const float* HWY_RESTRICT b_att_out =
backward.att_out.data() + (pos * heads + head) * qkv_dim;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
backward.att_out.Packed() + (pos * heads + head) * qkv_dim;
float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t v2offs = (pos2 * (heads + 2) + heads + 1) * qkv_dim;
const float* HWY_RESTRICT f_v2 = forward.qkv.data() + v2offs;
float* HWY_RESTRICT b_v2 = backward.qkv.data() + v2offs;
const float* HWY_RESTRICT f_v2 = forward.qkv.Packed() + v2offs;
float* HWY_RESTRICT b_v2 = backward.qkv.Packed() + v2offs;
b_head_att[pos2] = Dot(b_att_out, f_v2, qkv_dim);
MulByConstAndAdd(f_head_att[pos2], b_att_out, b_v2, qkv_dim);
}
@ -247,8 +248,8 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t aoffset = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_head_att = forward.att.data() + aoffset;
float* HWY_RESTRICT b_head_att = backward.att.data() + aoffset;
const float* HWY_RESTRICT f_head_att = forward.att.Packed() + aoffset;
float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffset;
SoftmaxVJP(f_head_att, b_head_att, pos + 1);
}
}
@ -257,13 +258,13 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t qoffs = (pos * (heads + 2) + head) * qkv_dim;
const size_t aoffs = head * seq_len + pos * heads * seq_len;
const float* HWY_RESTRICT f_q = forward.qkv.data() + qoffs;
const float* HWY_RESTRICT b_head_att = backward.att.data() + aoffs;
float* HWY_RESTRICT b_q = backward.qkv.data() + qoffs;
const float* HWY_RESTRICT f_q = forward.qkv.Packed() + qoffs;
const float* HWY_RESTRICT b_head_att = backward.att.Packed() + aoffs;
float* HWY_RESTRICT b_q = backward.qkv.Packed() + qoffs;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t k2offs = (pos2 * (heads + 2) + heads) * qkv_dim;
const float* HWY_RESTRICT f_k2 = forward.qkv.data() + k2offs;
float* HWY_RESTRICT b_k2 = backward.qkv.data() + k2offs;
const float* HWY_RESTRICT f_k2 = forward.qkv.Packed() + k2offs;
float* HWY_RESTRICT b_k2 = backward.qkv.Packed() + k2offs;
MulByConstAndAdd(b_head_att[pos2], f_k2, b_q, qkv_dim);
MulByConstAndAdd(b_head_att[pos2], f_q, b_k2, qkv_dim);
}
@ -272,28 +273,30 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
for (int pos = 0; pos < static_cast<int>(num_tokens); ++pos) {
float* HWY_RESTRICT b_kv =
backward.qkv.data() + (pos * (heads + 2) + heads) * qkv_dim;
backward.qkv.Packed() + (pos * (heads + 2) + heads) * qkv_dim;
Rope(b_kv, qkv_dim, inv_timescale.Const(), -pos);
}
for (size_t head = 0; head < heads; ++head) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT b_q =
backward.qkv.data() + (pos * (heads + 2) + head) * qkv_dim;
backward.qkv.Packed() + (pos * (heads + 2) + head) * qkv_dim;
MulByConst(query_scale, b_q, qkv_dim);
Rope(b_q, qkv_dim, inv_timescale.Const(), -pos);
}
}
MatMulVJP(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), model_dim, (heads + 2) * qkv_dim, num_tokens,
grad.qkv_einsum_w.data(), backward.pre_att_rms_out.data(), pool);
RMSNormVJP(weights.pre_attention_norm_scale.data(), forward.input.data(),
backward.pre_att_rms_out.data(), model_dim, num_tokens,
grad.pre_attention_norm_scale.data(), backward.input.data(), pool);
MatMulVJP(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
backward.qkv.Packed(), model_dim, (heads + 2) * qkv_dim, num_tokens,
grad.qkv_einsum_w.Packed(), backward.pre_att_rms_out.Packed(),
pool);
RMSNormVJP(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
backward.pre_att_rms_out.Packed(), model_dim, num_tokens,
grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(backward.attention_out.data() + pos * model_dim,
backward.input.data() + pos * model_dim, model_dim);
AddFrom(backward.attention_out.Packed() + pos * model_dim,
backward.input.Packed() + pos * model_dim, model_dim);
}
}
@ -353,47 +356,48 @@ void CrossEntropyLossBackwardPassInl(const Prompt& prompt,
HWY_DASSERT(prompt.context_size < prompt.tokens.size());
const size_t num_tokens = prompt.tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
kVocabSize);
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftmaxVJP(forward.probs.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize,
kVocabSize);
SoftmaxVJP(forward.probs.Packed() + pos * kVocabSize,
backward.logits.Packed() + pos * kVocabSize, kVocabSize);
}
if (config.final_cap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
SoftcapVJP(config.final_cap, forward.logits.data() + pos * kVocabSize,
backward.logits.data() + pos * kVocabSize, kVocabSize);
SoftcapVJP(config.final_cap, forward.logits.Packed() + pos * kVocabSize,
backward.logits.Packed() + pos * kVocabSize, kVocabSize);
}
}
MatMulVJP(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(), backward.logits.data(), model_dim,
kVocabSize, num_tokens, grad.embedder_input_embedding.data(),
backward.final_norm_output.data(), pool);
MatMulVJP(weights.embedder_input_embedding.Packed(),
forward.final_norm_output.Packed(), backward.logits.Packed(),
model_dim, kVocabSize, num_tokens,
grad.embedder_input_embedding.Packed(),
backward.final_norm_output.Packed(), pool);
RMSNormVJP(weights.final_norm_scale.data(), forward.final_layer_output.data(),
backward.final_norm_output.data(), model_dim, num_tokens,
grad.final_norm_scale.data(), backward.final_layer_output.data(),
pool);
RMSNormVJP(weights.final_norm_scale.Packed(),
forward.final_layer_output.Packed(),
backward.final_norm_output.Packed(), model_dim, num_tokens,
grad.final_norm_scale.Packed(),
backward.final_layer_output.Packed(), pool);
for (int layer = static_cast<int>(kLayers) - 1; layer >= 0; --layer) {
auto layer_config = config.layer_configs[layer];
// TODO(szabadka) Implement Griffin layer vjp.
HWY_ASSERT(layer_config.type == LayerAttentionType::kGemma);
float* next_layer_grad = layer + 1 < kLayers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
? backward.layers[layer + 1].input.Packed()
: backward.final_layer_output.Packed();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
num_tokens, *grad.GetLayer(layer), backward.layers[layer],
inv_timescale, pool);
}
InputEmbeddingVJP(weights.embedder_input_embedding.data(), prompt.tokens,
kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), model_dim);
InputEmbeddingVJP(weights.embedder_input_embedding.Packed(), prompt.tokens,
kEmbScaling, backward.layers[0].input.Packed(),
grad.embedder_input_embedding.Packed(), model_dim);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)


@ -17,9 +17,8 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to


@ -19,7 +19,7 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {


@ -211,62 +211,65 @@ void LayerVJP(const LayerWeightsPtrs<T>& weights,
const size_t kFFHiddenDim = layer_config.ff_hidden_dim;
const T kQueryScale = 1.0 / std::sqrt(T(qkv_dim));
MatMulVJPT(weights.linear_w.data(), forward.ffw_hidden_gated.data(), dy,
grad.linear_w.data(), backward.ffw_hidden_gated.data(), model_dim,
kFFHiddenDim, num_tokens);
MatMulVJPT(weights.linear_w.Packed(), forward.ffw_hidden_gated.Packed(), dy,
grad.linear_w.Packed(), backward.ffw_hidden_gated.Packed(),
model_dim, kFFHiddenDim, num_tokens);
GatedGeluVJP(forward.ffw_hidden.data(), backward.ffw_hidden_gated.data(),
backward.ffw_hidden.data(), kFFHiddenDim, num_tokens);
GatedGeluVJP(forward.ffw_hidden.Packed(), backward.ffw_hidden_gated.Packed(),
backward.ffw_hidden.Packed(), kFFHiddenDim, num_tokens);
MatMulVJPT(weights.gating_einsum_w.data(), forward.bf_pre_ffw_rms_out.data(),
backward.ffw_hidden.data(), grad.gating_einsum_w.data(),
backward.bf_pre_ffw_rms_out.data(), kFFHiddenDim * 2, model_dim,
MatMulVJPT(weights.gating_einsum_w.Packed(),
forward.bf_pre_ffw_rms_out.Packed(), backward.ffw_hidden.Packed(),
grad.gating_einsum_w.Packed(),
backward.bf_pre_ffw_rms_out.Packed(), kFFHiddenDim * 2, model_dim,
num_tokens);
RMSNormVJPT(weights.pre_ffw_norm_scale.data(), forward.attention_out.data(),
backward.bf_pre_ffw_rms_out.data(),
grad.pre_ffw_norm_scale.data(), backward.attention_out.data(),
model_dim, num_tokens);
RMSNormVJPT(
weights.pre_ffw_norm_scale.Packed(), forward.attention_out.Packed(),
backward.bf_pre_ffw_rms_out.Packed(), grad.pre_ffw_norm_scale.Packed(),
backward.attention_out.Packed(), model_dim, num_tokens);
AddFromT(dy, backward.attention_out.data(), num_tokens * model_dim);
AddFromT(dy, backward.attention_out.Packed(), num_tokens * model_dim);
MultiHeadMatMulVJPT(weights.attn_vec_einsum_w.data(), forward.att_out.data(),
backward.attention_out.data(),
grad.attn_vec_einsum_w.data(), backward.att_out.data(),
kHeads, model_dim, qkv_dim, num_tokens);
MultiHeadMatMulVJPT(
weights.attn_vec_einsum_w.Packed(), forward.att_out.Packed(),
backward.attention_out.Packed(), grad.attn_vec_einsum_w.Packed(),
backward.att_out.Packed(), kHeads, model_dim, qkv_dim, num_tokens);
MixByAttentionVJP(forward.qkv.data(), forward.att.data(),
backward.att_out.data(), backward.qkv.data(),
backward.att.data(), num_tokens, kHeads, qkv_dim, seq_len);
MaskedSoftmaxVJPT(forward.att.data(), backward.att.data(), num_tokens, kHeads,
MixByAttentionVJP(forward.qkv.Packed(), forward.att.Packed(),
backward.att_out.Packed(), backward.qkv.Packed(),
backward.att.Packed(), num_tokens, kHeads, qkv_dim,
seq_len);
MaskedAttentionVJP(forward.qkv.data(), backward.att.data(),
backward.qkv.data(), num_tokens, kHeads, qkv_dim, seq_len);
MaskedSoftmaxVJPT(forward.att.Packed(), backward.att.Packed(), num_tokens,
kHeads, seq_len);
MaskedAttentionVJP(forward.qkv.Packed(), backward.att.Packed(),
backward.qkv.Packed(), num_tokens, kHeads, qkv_dim,
seq_len);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
MulByConstT(kQueryScale, qkv, kHeads * qkv_dim);
}
for (int pos = 0; pos < num_tokens; ++pos) {
T* qkv = backward.qkv.data() + pos * (kHeads + 2) * qkv_dim;
T* qkv = backward.qkv.Packed() + pos * (kHeads + 2) * qkv_dim;
for (size_t h = 0; h <= kHeads; ++h) {
Rope(qkv + h * qkv_dim, qkv_dim, -pos);
}
}
MatMulVJPT(weights.qkv_einsum_w.data(), forward.pre_att_rms_out.data(),
backward.qkv.data(), grad.qkv_einsum_w.data(),
backward.pre_att_rms_out.data(), (kHeads + 2) * qkv_dim, model_dim,
num_tokens);
RMSNormVJPT(weights.pre_attention_norm_scale.data(), forward.input.data(),
backward.pre_att_rms_out.data(),
grad.pre_attention_norm_scale.data(), backward.input.data(),
MatMulVJPT(weights.qkv_einsum_w.Packed(), forward.pre_att_rms_out.Packed(),
backward.qkv.Packed(), grad.qkv_einsum_w.Packed(),
backward.pre_att_rms_out.Packed(), (kHeads + 2) * qkv_dim,
model_dim, num_tokens);
RMSNormVJPT(weights.pre_attention_norm_scale.Packed(), forward.input.Packed(),
backward.pre_att_rms_out.Packed(),
grad.pre_attention_norm_scale.Packed(), backward.input.Packed(),
model_dim, num_tokens);
AddFromT(backward.attention_out.data(), backward.input.data(),
AddFromT(backward.attention_out.Packed(), backward.input.Packed(),
num_tokens * model_dim);
}
@ -307,41 +310,42 @@ void CrossEntropyLossBackwardPass(const Prompt& prompt,
const std::vector<int> tokens = prompt.tokens;
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
CrossEntropyLossGrad(forward.probs.data(), backward.logits.data(), prompt,
CrossEntropyLossGrad(forward.probs.Packed(), backward.logits.Packed(), prompt,
vocab_size);
SoftmaxVJPT(forward.probs.data(), backward.logits.data(), vocab_size,
SoftmaxVJPT(forward.probs.Packed(), backward.logits.Packed(), vocab_size,
num_tokens);
if (config.final_cap > 0.0f) {
for (size_t i = 0; i < num_tokens; ++i) {
SoftcapVJPT(config.final_cap, forward.logits.data() + i * vocab_size,
backward.logits.data() + i * vocab_size, vocab_size);
SoftcapVJPT(config.final_cap, forward.logits.Packed() + i * vocab_size,
backward.logits.Packed() + i * vocab_size, vocab_size);
}
}
MatMulVJPT(
weights.embedder_input_embedding.data(), forward.final_norm_output.data(),
backward.logits.data(), grad.embedder_input_embedding.data(),
backward.final_norm_output.data(), vocab_size, model_dim, num_tokens);
MatMulVJPT(weights.embedder_input_embedding.Packed(),
forward.final_norm_output.Packed(), backward.logits.Packed(),
grad.embedder_input_embedding.Packed(),
backward.final_norm_output.Packed(), vocab_size, model_dim,
num_tokens);
RMSNormVJPT(weights.final_norm_scale.data(),
forward.final_layer_output.data(),
backward.final_norm_output.data(), grad.final_norm_scale.data(),
backward.final_layer_output.data(), model_dim, num_tokens);
RMSNormVJPT(
weights.final_norm_scale.Packed(), forward.final_layer_output.Packed(),
backward.final_norm_output.Packed(), grad.final_norm_scale.Packed(),
backward.final_layer_output.Packed(), model_dim, num_tokens);
for (int layer = static_cast<int>(layers) - 1; layer >= 0; --layer) {
T* next_layer_grad = layer + 1 < layers
? backward.layers[layer + 1].input.data()
: backward.final_layer_output.data();
? backward.layers[layer + 1].input.Packed()
: backward.final_layer_output.Packed();
LayerVJP(*weights.GetLayer(layer), forward.layers[layer], next_layer_grad,
*grad.GetLayer(layer), backward.layers[layer], num_tokens);
}
const T kEmbScaling = EmbeddingScaling(model_dim);
InputEmbeddingVJPT(weights.embedder_input_embedding.data(), tokens,
kEmbScaling, backward.layers[0].input.data(),
grad.embedder_input_embedding.data(), model_dim);
InputEmbeddingVJPT(weights.embedder_input_embedding.Packed(), tokens,
kEmbScaling, backward.layers[0].input.Packed(),
grad.embedder_input_embedding.Packed(), model_dim);
}
} // namespace gcpp


@ -31,9 +31,9 @@
#include "backprop/prompt.h"
#include "backprop/sampler.h"
#include "backprop/test_util.h"
#include "compression/compress.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/mat.h"
namespace gcpp {
@ -44,14 +44,14 @@ TEST(BackPropTest, MatMulVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", kRows, kCols);
MatStorageT<T> x("x", kTokens, kCols);
MatStorageT<T> grad("grad", kRows, kCols);
MatStorageT<T> dx("dx", kTokens, kCols);
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
MatStorageT<TC> c_x("c_x", kTokens, kCols);
MatStorageT<TC> c_y("c_y", kTokens, kRows);
MatStorageT<T> dy("dy", kTokens, kRows);
auto weights = MakePacked<T>("weights", kRows, kCols);
auto x = MakePacked<T>("x", kTokens, kCols);
auto grad = MakePacked<T>("grad", kRows, kCols);
auto dx = MakePacked<T>("dx", kTokens, kCols);
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
auto dy = MakePacked<T>("dy", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
@ -60,12 +60,13 @@ TEST(BackPropTest, MatMulVJP) {
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
kTokens);
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
};
grad.ZeroInit();
MatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
kRows, kCols, kTokens);
ZeroInit(grad);
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
dx.Packed(), kRows, kCols, kTokens);
TestGradient(dx, c_x, func, 1e-11, 1e-12, __LINE__);
TestGradient(grad, c_weights, func, 1e-14, 1e-12, __LINE__);
}
@ -79,14 +80,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", kRows, kCols * kHeads);
MatStorageT<T> x("x", kTokens, kCols * kHeads);
MatStorageT<T> grad("grad", kRows, kCols * kHeads);
MatStorageT<T> dx("dx", kTokens, kCols * kHeads);
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
MatStorageT<TC> c_y("c_y", kTokens, kRows);
MatStorageT<T> dy("dy", kTokens, kRows);
auto weights = MakePacked<T>("weights", kRows, kCols * kHeads);
auto x = MakePacked<T>("x", kTokens, kCols * kHeads);
auto grad = MakePacked<T>("grad", kRows, kCols * kHeads);
auto dx = MakePacked<T>("dx", kTokens, kCols * kHeads);
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
auto dy = MakePacked<T>("dy", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
@ -95,13 +96,14 @@ TEST(BackPropTest, MultiHeadMatMulVJP) {
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
kRows, kCols, kTokens);
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
};
grad.ZeroInit();
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad.data(),
dx.data(), kHeads, kRows, kCols, kTokens);
ZeroInit(grad);
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
grad.Packed(), dx.Packed(), kHeads, kRows, kCols,
kTokens);
TestGradient(dx, c_x, func, 1e-15, 1e-13, __LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-13, __LINE__);
}
@ -113,14 +115,14 @@ TEST(BackPropTest, RMSNormVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", N, 1);
MatStorageT<T> grad("grad", N, 1);
MatStorageT<T> x("x", K, N);
MatStorageT<T> dx("dx", K, N);
MatStorageT<T> dy("dy", K, N);
MatStorageT<TC> c_weights("c_weights", N, 1);
MatStorageT<TC> c_x("c_x", K, N);
MatStorageT<TC> c_y("c_y", K, N);
auto weights = MakePacked<T>("weights", N, 1);
auto grad = MakePacked<T>("grad", N, 1);
auto x = MakePacked<T>("x", K, N);
auto dx = MakePacked<T>("dx", K, N);
auto dy = MakePacked<T>("dy", K, N);
auto c_weights = MakePacked<TC>("c_weights", N, 1);
auto c_x = MakePacked<TC>("c_x", K, N);
auto c_y = MakePacked<TC>("c_y", K, N);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0 * (1 << iter), gen);
@ -129,12 +131,12 @@ TEST(BackPropTest, RMSNormVJP) {
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), K * N);
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
return DotT(dy.Packed(), c_y.Packed(), K * N);
};
grad.ZeroInit();
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad.data(), dx.data(),
N, K);
ZeroInit(grad);
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad.Packed(),
dx.Packed(), N, K);
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
TestGradient(grad, c_weights, func, 1e-15, 1e-14, __LINE__);
}
@ -145,24 +147,24 @@ TEST(BackPropTest, SoftmaxVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", N, 1);
MatStorageT<T> dx("dx", N, 1);
MatStorageT<T> dy("dy", N, 1);
MatStorageT<TC> c_x("c_x", N, 1);
MatStorageT<TC> c_y("c_y", N, 1);
auto x = MakePacked<T>("x", N, 1);
auto dx = MakePacked<T>("dx", N, 1);
auto dy = MakePacked<T>("dy", N, 1);
auto c_x = MakePacked<TC>("c_x", N, 1);
auto c_y = MakePacked<TC>("c_y", N, 1);
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), c_x.SizeBytes());
Softmax(c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
CopyMat(c_x, c_y);
Softmax(c_y.Packed(), N);
return DotT(dy.Packed(), c_y.Packed(), N);
};
Softmax(x.data(), N);
memcpy(dx.data(), dy.data(), dx.SizeBytes());
SoftmaxVJPT(x.data(), dx.data(), N);
Softmax(x.Packed(), N);
CopyMat(dy, dx);
SoftmaxVJPT(x.Packed(), dx.Packed(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
@ -175,26 +177,25 @@ TEST(BackPropTest, MaskedSoftmaxVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", N, 1);
MatStorageT<T> dy("dy", N, 1);
MatStorageT<T> dx("dx", N, 1);
MatStorageT<TC> c_x("c_x", N, 1);
MatStorageT<TC> c_y("c_y", N, 1);
dx.ZeroInit();
auto x = MakePacked<T>("x", N, 1);
auto dy = MakePacked<T>("dy", N, 1);
auto dx = MakePacked<T>("dx", N, 1);
auto c_x = MakePacked<TC>("c_x", N, 1);
auto c_y = MakePacked<TC>("c_y", N, 1);
ZeroInit(dx);
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0 * (1 << iter), gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(),
kTokens * kHeads * kSeqLen * sizeof(c_x.At(0)));
MaskedSoftmax(c_y.data(), kTokens, kHeads, kSeqLen);
return DotT(dy.data(), c_y.data(), N);
CopyMat(c_x, c_y);
MaskedSoftmax(c_y.Packed(), kTokens, kHeads, kSeqLen);
return DotT(dy.Packed(), c_y.Packed(), N);
};
MaskedSoftmax(x.data(), kTokens, kHeads, kSeqLen);
memcpy(dx.data(), dy.data(), kTokens * kHeads * kSeqLen * sizeof(dx.At(0)));
MaskedSoftmaxVJPT(x.data(), dx.data(), kTokens, kHeads, kSeqLen);
MaskedSoftmax(x.Packed(), kTokens, kHeads, kSeqLen);
CopyMat(dy, dx);
MaskedSoftmaxVJPT(x.Packed(), dx.Packed(), kTokens, kHeads, kSeqLen);
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
}
}
@ -204,11 +205,11 @@ TEST(BackPropTest, SoftcapVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", N, 1);
MatStorageT<T> dx("dx", N, 1);
MatStorageT<T> dy("dy", N, 1);
MatStorageT<TC> c_x("c_x", N, 1);
MatStorageT<TC> c_y("c_y", N, 1);
auto x = MakePacked<T>("x", N, 1);
auto dx = MakePacked<T>("dx", N, 1);
auto dy = MakePacked<T>("dy", N, 1);
auto c_x = MakePacked<TC>("c_x", N, 1);
auto c_y = MakePacked<TC>("c_y", N, 1);
constexpr float kCap = 30.0f;
for (int iter = 0; iter < 10; ++iter) {
@ -216,13 +217,13 @@ TEST(BackPropTest, SoftcapVJP) {
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
memcpy(c_y.data(), c_x.data(), N * sizeof(c_x.At(0)));
Softcap(kCap, c_y.data(), N);
return DotT(dy.data(), c_y.data(), N);
CopyMat(c_x, c_y);
Softcap(kCap, c_y.Packed(), N);
return DotT(dy.Packed(), c_y.Packed(), N);
};
Softcap(kCap, x.data(), N);
memcpy(dx.data(), dy.data(), dx.SizeBytes());
SoftcapVJPT(kCap, x.data(), dx.data(), N);
Softcap(kCap, x.Packed(), N);
CopyMat(dy, dx);
SoftcapVJPT(kCap, x.Packed(), dx.Packed(), N);
TestGradient(dx, c_x, func, 1e-15, 1e-14, __LINE__);
}
}
@ -233,9 +234,9 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", K, V);
MatStorageT<T> dx("dx", K, V);
MatStorageT<TC> c_x("c_x", K, V);
auto x = MakePacked<T>("x", K, V);
auto dx = MakePacked<T>("dx", K, V);
auto c_x = MakePacked<TC>("c_x", K, V);
Prompt prompt;
prompt.tokens = { 0, 1, 2, 3, 0, 3, 2, 1, 0 };
@ -243,13 +244,11 @@ TEST(BackPropTest, CrossEntropyLossGrad) {
for (int iter = 0; iter < 10; ++iter) {
prompt.context_size = 1 + (iter % 6);
RandInit(x, 1.0 * (1 << iter), gen);
Softcap(kCap, x.data(), V * K);
Softmax(x.data(), V, K);
CrossEntropyLossGrad(x.data(), dx.data(), prompt, V);
Softcap(kCap, x.Packed(), V * K);
Softmax(x.Packed(), V, K);
CrossEntropyLossGrad(x.Packed(), dx.Packed(), prompt, V);
Complexify(x, c_x);
auto func = [&]() {
return CrossEntropyLoss(c_x.data(), prompt, V);
};
auto func = [&]() { return CrossEntropyLoss(c_x.Packed(), prompt, V); };
TestGradient(dx, c_x, func, 1e-100, 1e-15, __LINE__);
}
}
@ -260,21 +259,21 @@ TEST(BackPropTest, GatedGeluVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", K, 2 * N);
MatStorageT<T> dx("dx", K, 2 * N);
MatStorageT<T> dy("dy", K, N);
MatStorageT<TC> c_x("c_x", K, 2 * N);
MatStorageT<TC> c_y("c_y", K, N);
auto x = MakePacked<T>("x", K, 2 * N);
auto dx = MakePacked<T>("dx", K, 2 * N);
auto dy = MakePacked<T>("dy", K, N);
auto c_x = MakePacked<TC>("c_x", K, 2 * N);
auto c_y = MakePacked<TC>("c_y", K, N);
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0, gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
GatedGelu(c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), N * K);
GatedGelu(c_x.Packed(), c_y.Packed(), N, K);
return DotT(dy.Packed(), c_y.Packed(), N * K);
};
GatedGeluVJP(x.data(), dy.data(), dx.data(), N, K);
GatedGeluVJP(x.Packed(), dy.Packed(), dx.Packed(), N, K);
TestGradient(dx, c_x, func, 1e-15, 1e-15, __LINE__);
}
}
@ -289,25 +288,25 @@ TEST(BackPropTest, MaskedAttentionVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> x("x", kQKVSize, 1);
MatStorageT<T> dx("dx", kQKVSize, 1);
MatStorageT<T> dy("dy", kOutSize, 1);
MatStorageT<TC> c_x("c_x", kQKVSize, 1);
MatStorageT<TC> c_y("c_y", kOutSize, 1);
dx.ZeroInit();
c_y.ZeroInit();
auto x = MakePacked<T>("x", kQKVSize, 1);
auto dx = MakePacked<T>("dx", kQKVSize, 1);
auto dy = MakePacked<T>("dy", kOutSize, 1);
auto c_x = MakePacked<TC>("c_x", kQKVSize, 1);
auto c_y = MakePacked<TC>("c_y", kOutSize, 1);
ZeroInit(dx);
ZeroInit(c_y);
for (int iter = 0; iter < 10; ++iter) {
RandInit(x, 1.0, gen);
Complexify(x, c_x);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MaskedAttention(c_x.data(), c_y.data(), kTokens, kHeads, kQKVDim,
MaskedAttention(c_x.Packed(), c_y.Packed(), kTokens, kHeads, kQKVDim,
kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
return DotT(dy.Packed(), c_y.Packed(), kOutSize);
};
MaskedAttentionVJP(x.data(), dy.data(), dx.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
MaskedAttentionVJP(x.Packed(), dy.Packed(), dx.Packed(), kTokens, kHeads,
kQKVDim, kSeqLen);
TestGradient(dx, c_x, func, 1e-14, 1e-15, __LINE__);
}
}
@ -323,17 +322,17 @@ TEST(BackPropTest, MixByAttentionVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> qkv("qkv", kQKVSize, 1);
MatStorageT<T> dqkv("dqkv", kQKVSize, 1);
MatStorageT<T> attn("attn", kAttnSize, 1);
MatStorageT<T> dattn("dattn", kAttnSize, 1);
MatStorageT<T> dy("dy", kOutSize, 1);
MatStorageT<TC> c_qkv("c_qkv", kQKVSize, 1);
MatStorageT<TC> c_attn("c_attn", kAttnSize, 1);
MatStorageT<TC> c_y("c_y", kOutSize, 1);
dqkv.ZeroInit();
dattn.ZeroInit();
c_y.ZeroInit();
auto qkv = MakePacked<T>("qkv", kQKVSize, 1);
auto dqkv = MakePacked<T>("dqkv", kQKVSize, 1);
auto attn = MakePacked<T>("attn", kAttnSize, 1);
auto dattn = MakePacked<T>("dattn", kAttnSize, 1);
auto dy = MakePacked<T>("dy", kOutSize, 1);
auto c_qkv = MakePacked<TC>("c_qkv", kQKVSize, 1);
auto c_attn = MakePacked<TC>("c_attn", kAttnSize, 1);
auto c_y = MakePacked<TC>("c_y", kOutSize, 1);
ZeroInit(dqkv);
ZeroInit(dattn);
ZeroInit(c_y);
for (int iter = 0; iter < 10; ++iter) {
RandInit(qkv, 1.0, gen);
@ -342,12 +341,12 @@ TEST(BackPropTest, MixByAttentionVJP) {
Complexify(attn, c_attn);
RandInit(dy, 1.0, gen);
auto func = [&]() {
MixByAttention(c_qkv.data(), c_attn.data(), c_y.data(),
kTokens, kHeads, kQKVDim, kSeqLen);
return DotT(dy.data(), c_y.data(), kOutSize);
MixByAttention(c_qkv.Packed(), c_attn.Packed(), c_y.Packed(), kTokens,
kHeads, kQKVDim, kSeqLen);
return DotT(dy.Packed(), c_y.Packed(), kOutSize);
};
MixByAttentionVJP(qkv.data(), attn.data(), dy.data(), dqkv.data(),
dattn.data(), kTokens, kHeads, kQKVDim, kSeqLen);
MixByAttentionVJP(qkv.Packed(), attn.Packed(), dy.Packed(), dqkv.Packed(),
dattn.Packed(), kTokens, kHeads, kQKVDim, kSeqLen);
TestGradient(dqkv, c_qkv, func, 1e-14, 1e-15, __LINE__);
TestGradient(dattn, c_attn, func, 1e-14, 1e-15, __LINE__);
}
@ -360,11 +359,11 @@ TEST(BackPropTest, InputEmbeddingVJP) {
std::mt19937 gen(42);
using T = double;
using TC = std::complex<T>;
MatStorageT<T> weights("weights", kVocabSize, kModelDim);
MatStorageT<T> grad("grad", kVocabSize, kModelDim);
MatStorageT<T> dy("dy", kSeqLen, kModelDim);
MatStorageT<TC> c_weights("c_weights", kVocabSize, kModelDim);
MatStorageT<TC> c_y("c_y", kSeqLen, kModelDim);
auto weights = MakePacked<T>("weights", kVocabSize, kModelDim);
auto grad = MakePacked<T>("grad", kVocabSize, kModelDim);
auto dy = MakePacked<T>("dy", kSeqLen, kModelDim);
auto c_weights = MakePacked<TC>("c_weights", kVocabSize, kModelDim);
auto c_y = MakePacked<TC>("c_y", kSeqLen, kModelDim);
std::vector<int> tokens = { 0, 1, 2, 3, 0, 1, 2 };
size_t num_tokens = tokens.size() - 1;
@ -373,12 +372,13 @@ TEST(BackPropTest, InputEmbeddingVJP) {
RandInit(dy, 1.0, gen);
Complexify(weights, c_weights);
auto func = [&]() {
InputEmbedding(c_weights.data(), tokens, TC(3.0), c_y.data(), kModelDim);
return DotT(dy.data(), c_y.data(), num_tokens * kModelDim);
InputEmbedding(c_weights.Packed(), tokens, TC(3.0), c_y.Packed(),
kModelDim);
return DotT(dy.Packed(), c_y.Packed(), num_tokens * kModelDim);
};
grad.ZeroInit();
InputEmbeddingVJPT(weights.data(), tokens, 3.0, dy.data(), grad.data(),
kModelDim);
ZeroInit(grad);
InputEmbeddingVJPT(weights.Packed(), tokens, 3.0, dy.Packed(),
grad.Packed(), kModelDim);
TestGradient(grad, c_weights, func, 1e-16, 1e-14, __LINE__);
}
}
@ -410,8 +410,7 @@ TEST(BackPropTest, LayerVJP) {
using T = double;
using TC = std::complex<T>;
ModelConfig config = TestConfig();
TensorIndex tensor_index(config, /*llm_layer_idx=*/0, /*img_layer_idx=*/-1,
/*reshape_att=*/false);
const TensorIndex tensor_index = TensorIndexLLM(config, size_t{0});
const size_t kOutputSize = config.seq_len * config.model_dim;
LayerWeightsPtrs<T> weights(config.layer_configs[0], tensor_index);
LayerWeightsPtrs<T> grad(config.layer_configs[0], tensor_index);
@ -419,15 +418,15 @@ TEST(BackPropTest, LayerVJP) {
ForwardLayer<T> backward(config.layer_configs[0], config.seq_len);
LayerWeightsPtrs<TC> c_weights(config.layer_configs[0], tensor_index);
ForwardLayer<TC> c_forward(config.layer_configs[0], config.seq_len);
MatStorageT<T> y("y", kOutputSize, 1);
MatStorageT<T> dy("dy", kOutputSize, 1);
MatStorageT<TC> c_y("c_y", kOutputSize, 1);
auto y = MakePacked<T>("y", kOutputSize, 1);
auto dy = MakePacked<T>("dy", kOutputSize, 1);
auto c_y = MakePacked<TC>("c_y", kOutputSize, 1);
const size_t num_tokens = 3;
std::vector<MatStorage> layer_storage;
std::vector<MatOwner> layer_storage;
weights.Allocate(layer_storage);
grad.Allocate(layer_storage);
c_weights.Allocate(layer_storage);
backward.input.ZeroInit();
ZeroInit(backward.input);
for (size_t iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0, gen);
@ -436,12 +435,12 @@ TEST(BackPropTest, LayerVJP) {
Complexify(weights, c_weights);
Complexify(forward.input, c_forward.input);
auto func = [&]() {
ApplyLayer(c_weights, c_forward, num_tokens, c_y.data());
return DotT(dy.data(), c_y.data(), num_tokens * config.model_dim);
ApplyLayer(c_weights, c_forward, num_tokens, c_y.Packed());
return DotT(dy.Packed(), c_y.Packed(), num_tokens * config.model_dim);
};
grad.ZeroInit(/*layer_idx=*/0);
ApplyLayer(weights, forward, num_tokens, y.data());
LayerVJP(weights, forward, dy.data(), grad, backward, num_tokens);
ApplyLayer(weights, forward, num_tokens, y.Packed());
LayerVJP(weights, forward, dy.Packed(), grad, backward, num_tokens);
TestGradient(backward.input, c_forward.input, func, 1e-11, 5e-11,
__LINE__);
TestGradient(grad, c_weights, func, 1e-11);


@ -33,8 +33,10 @@
#include "backprop/test_util.h"
#include "gemma/configs.h"
#include "ops/ops.h"
#include "util/threading.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
@ -46,33 +48,45 @@
// After highway.h
#include "backprop/backward-inl.h"
#include "backprop/forward-inl.h"
#include "compression/compress.h"
#include "ops/ops-inl.h"
#include "util/allocator.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
hwy::ThreadPool& ThreadHostileGetPool() {
// Assume this is only called at the top level, i.e. not in a thread. Then we
// can safely call `SetArgs` only once, because it would assert otherwise.
// This is preferable to calling `ThreadHostileInvalidate`, because we would
// repeat the topology initialization for every test.
if (!ThreadingContext2::IsInitialized()) {
gcpp::ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 8;
threading_args.pin = Tristate::kFalse;
ThreadingContext2::SetArgs(threading_args);
}
return ThreadingContext2::Get().pools.Pool();
}
void TestMatMulVJP() {
static const size_t kRows = 8;
static const size_t kCols = 64;
static const size_t kTokens = 5;
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
hwy::ThreadPool& pool = ThreadHostileGetPool();
std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols);
MatStorageT<float> x("x", kTokens, kCols);
MatStorageT<float> dy("dy", kTokens, kRows);
MatStorageT<float> grad("grad", kRows, kCols);
MatStorageT<float> dx("dx", kTokens, kCols);
MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols);
MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols);
auto weights = MakePacked<float>("weights", kRows, kCols);
auto x = MakePacked<float>("x", kTokens, kCols);
auto dy = MakePacked<float>("dy", kTokens, kRows);
auto grad = MakePacked<float>("grad", kRows, kCols);
auto dx = MakePacked<float>("dx", kTokens, kCols);
auto grad_scalar = MakePacked<float>("grad_scalar", kRows, kCols);
auto dx_scalar = MakePacked<float>("dx_scalar", kTokens, kCols);
using TC = std::complex<double>;
MatStorageT<TC> c_weights("c_weights", kRows, kCols);
MatStorageT<TC> c_x("c_x", kTokens, kCols);
MatStorageT<TC> c_y("c_y", kTokens, kRows);
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols);
auto c_x = MakePacked<TC>("c_x", kTokens, kCols);
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
@ -81,19 +95,20 @@ void TestMatMulVJP() {
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MatMulT(c_weights.data(), c_x.data(), c_y.data(), kRows, kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
MatMulT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kRows, kCols,
kTokens);
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
};
grad.ZeroInit();
MatMulVJP(weights.data(), x.data(), dy.data(), kCols, kRows, kTokens,
grad.data(), dx.data(), pools.Pool());
ZeroInit(grad);
MatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kCols, kRows, kTokens,
grad.Packed(), dx.Packed(), pool);
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
grad_scalar.ZeroInit();
MatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kRows, kCols, kTokens);
ZeroInit(grad_scalar);
MatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
dx_scalar.Packed(), kRows, kCols, kTokens);
TestNear(dx, dx_scalar, 5e-5, 1e-4, __LINE__);
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
}
@ -104,21 +119,19 @@ void TestMultiHeadMatMulVJP() {
static const size_t kCols = 16;
static const size_t kHeads = 4;
static const size_t kTokens = 3;
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
hwy::ThreadPool& pool = ThreadHostileGetPool();
std::mt19937 gen(42);
MatStorageT<float> weights("weights", kRows, kCols * kHeads);
MatStorageT<float> x("x", kTokens, kCols * kHeads);
MatStorageT<float> grad("grad", kRows, kCols * kHeads);
MatStorageT<float> dx("dx", kTokens, kCols * kHeads);
MatStorageT<float> dy("dy", kTokens, kRows);
MatStorageT<float> grad_scalar("grad_scalar", kRows, kCols * kHeads);
MatStorageT<float> dx_scalar("dx_scalar", kTokens, kCols * kHeads);
auto weights = MakePacked<float>("weights", kRows, kCols * kHeads);
auto x = MakePacked<float>("x", kTokens, kCols * kHeads);
auto grad = MakePacked<float>("grad", kRows, kCols * kHeads);
auto dx = MakePacked<float>("dx", kTokens, kCols * kHeads);
auto dy = MakePacked<float>("dy", kTokens, kRows);
auto grad_scalar = MakePacked<float>("grad_scalar", kRows, kCols * kHeads);
auto dx_scalar = MakePacked<float>("dx_scalar", kTokens, kCols * kHeads);
using TC = std::complex<double>;
MatStorageT<TC> c_weights("c_weights", kRows, kCols * kHeads);
MatStorageT<TC> c_x("c_x", kTokens, kCols * kHeads);
MatStorageT<TC> c_y("c_y", kTokens, kRows);
auto c_weights = MakePacked<TC>("c_weights", kRows, kCols * kHeads);
auto c_x = MakePacked<TC>("c_x", kTokens, kCols * kHeads);
auto c_y = MakePacked<TC>("c_y", kTokens, kRows);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
@ -127,20 +140,21 @@ void TestMultiHeadMatMulVJP() {
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
MultiHeadMatMul(c_weights.data(), c_x.data(), c_y.data(), kHeads, kRows,
kCols, kTokens);
return DotT(dy.data(), c_y.data(), kTokens * kRows);
MultiHeadMatMul(c_weights.Packed(), c_x.Packed(), c_y.Packed(), kHeads,
kRows, kCols, kTokens);
return DotT(dy.Packed(), c_y.Packed(), kTokens * kRows);
};
grad.ZeroInit();
MultiHeadMatMulVJP(weights.data(), x.data(), dy.data(), kHeads, kCols,
kRows, kTokens, grad.data(), dx.data(), pools.Pool());
ZeroInit(grad);
MultiHeadMatMulVJP(weights.Packed(), x.Packed(), dy.Packed(), kHeads, kCols,
kRows, kTokens, grad.Packed(), dx.Packed(), pool);
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
grad_scalar.ZeroInit();
MultiHeadMatMulVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), kHeads, kRows, kCols, kTokens);
ZeroInit(grad_scalar);
MultiHeadMatMulVJPT(weights.Packed(), x.Packed(), dy.Packed(),
grad_scalar.Packed(), dx_scalar.Packed(), kHeads, kRows,
kCols, kTokens);
TestNear(dx, dx_scalar, 5e-5, 5e-5, __LINE__);
TestNear(grad, grad_scalar, 5e-5, 5e-5, __LINE__);
}
@ -149,21 +163,19 @@ void TestMultiHeadMatMulVJP() {
void TestRMSNormVJP() {
static const size_t K = 2;
static const size_t N = 64;
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 8));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
hwy::ThreadPool& pool = ThreadHostileGetPool();
std::mt19937 gen(42);
MatStorageT<float> weights("weights", N, 1);
MatStorageT<float> x("x", K, N);
MatStorageT<float> grad("grad", N, 1);
MatStorageT<float> dx("dx", K, N);
MatStorageT<float> dy("dy", K, N);
MatStorageT<float> grad_scalar("grad_scalar", N, 1);
MatStorageT<float> dx_scalar("dx_scalar", K, N);
auto weights = MakePacked<float>("weights", N, 1);
auto x = MakePacked<float>("x", K, N);
auto grad = MakePacked<float>("grad", N, 1);
auto dx = MakePacked<float>("dx", K, N);
auto dy = MakePacked<float>("dy", K, N);
auto grad_scalar = MakePacked<float>("grad_scalar", N, 1);
auto dx_scalar = MakePacked<float>("dx_scalar", K, N);
using TC = std::complex<double>;
MatStorageT<TC> c_weights("c_weights", N, 1);
MatStorageT<TC> c_x("c_x", K, N);
MatStorageT<TC> c_y("c_y", K, N);
auto c_weights = MakePacked<TC>("c_weights", N, 1);
auto c_x = MakePacked<TC>("c_x", K, N);
auto c_y = MakePacked<TC>("c_y", K, N);
for (int iter = 0; iter < 10; ++iter) {
RandInit(weights, 1.0f * (1 << iter), gen);
@ -172,19 +184,19 @@ void TestRMSNormVJP() {
Complexify(weights, c_weights);
Complexify(x, c_x);
auto func = [&]() {
RMSNormT(c_weights.data(), c_x.data(), c_y.data(), N, K);
return DotT(dy.data(), c_y.data(), K * N);
RMSNormT(c_weights.Packed(), c_x.Packed(), c_y.Packed(), N, K);
return DotT(dy.Packed(), c_y.Packed(), K * N);
};
grad.ZeroInit();
RMSNormVJP(weights.data(), x.data(), dy.data(), N, K, grad.data(),
dx.data(), pools.Pool());
ZeroInit(grad);
RMSNormVJP(weights.Packed(), x.Packed(), dy.Packed(), N, K, grad.Packed(),
dx.Packed(), pool);
TestGradient(dx, c_x, func, 5e-5f, 5e-5f, __LINE__);
TestGradient(grad, c_weights, func, 5e-5f, 5e-5f, __LINE__);
grad_scalar.ZeroInit();
RMSNormVJPT(weights.data(), x.data(), dy.data(), grad_scalar.data(),
dx_scalar.data(), N, K);
ZeroInit(grad_scalar);
RMSNormVJPT(weights.Packed(), x.Packed(), dy.Packed(), grad_scalar.Packed(),
dx_scalar.Packed(), N, K);
TestNear(dx, dx_scalar, 0, 2e-5, __LINE__);
TestNear(grad, grad_scalar, 0, 2e-5, __LINE__);
}
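The two tests above show the recurring migration in this commit: MatStorageT construction becomes the MakePacked factory, the member ZeroInit() becomes a free function, and data() becomes Packed() or Row(). A condensed sketch of the new idiom, assuming the declarations from util/mat.h; RandInitRows is a hypothetical stand-in for the RandInit helpers used above:

#include <stddef.h>

#include <random>

#include "util/mat.h"  // MakePacked, MatStorageT, ZeroInit

namespace gcpp {

// Sketch: allocate a packed matrix, zero it, then fill it row by row.
void RandInitRows(std::mt19937& gen) {
  auto w = MakePacked<float>("w", /*rows=*/64, /*cols=*/128);
  ZeroInit(w);  // free function replaces the former member ZeroInit()
  std::normal_distribution<float> dist(0.0f, 1.0f);
  for (size_t r = 0; r < w.Rows(); ++r) {
    float* row = w.Row(r);  // for a packed mat, same as w.Packed() + r * w.Cols()
    for (size_t c = 0; c < w.Cols(); ++c) row[c] = dist(gen);
  }
}

}  // namespace gcpp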
@ -215,9 +227,7 @@ static ModelConfig TestConfig() {
void TestEndToEnd() {
std::mt19937 gen(42);
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
Allocator::Init(topology);
gcpp::NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
hwy::ThreadPool& pool = ThreadHostileGetPool();
ModelConfig config = TestConfig();
WeightsWrapper<float> weights(config);
WeightsWrapper<float> grad(config);
@ -232,7 +242,7 @@ void TestEndToEnd() {
std::vector<Prompt> batch = training_task.SampleBatch(3, gen);
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
config.layer_configs[0].qkv_dim,
ThreadingContext2::Get().allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
for (const Prompt& prompt : batch) {
ReverseSequenceSampler::LogPrompt(prompt);
@ -242,13 +252,13 @@ void TestEndToEnd() {
float loss1 = CrossEntropyLossForwardPass(
prompt.tokens, prompt.context_size, weights.get(), forward1,
inv_timescale, pools.Pool());
inv_timescale, pool);
EXPECT_NEAR(loss1, loss0, std::abs(loss0) * 2e-5);
grad.ZeroInit();
CrossEntropyLossBackwardPassInl(prompt, weights.get(), forward1, grad.get(),
backward, inv_timescale, pools.Pool());
backward, inv_timescale, pool);
Complexify(weights.get(), c_weights.get());
auto func = [&]() {

View File

@ -20,7 +20,7 @@
#include <complex>
#include "compression/compress.h" // MatStorageT
#include "util/mat.h"
namespace gcpp {
@ -60,7 +60,9 @@ void MulByConstAndAddT(T c, const T* x, T* out, size_t N) {
template <typename T>
void MulByConstAndAddT(T c, const MatPtrT<T>& x, MatPtrT<T>& out) {
MulByConstAndAddT(c, x.data(), out.data(), x.NumElements());
for (size_t r = 0; r < x.Rows(); ++r) {
MulByConstAndAddT(c, x.Row(r), out.Row(r), x.Cols());
}
}
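The row loop above is the general replacement for iterating NumElements() flatly: it stays correct once a matrix's stride may exceed Cols(). A sketch of the same pattern as a hypothetical reusable helper, assuming util/mat.h:

#include <stddef.h>

#include "util/mat.h"  // MatPtrT

namespace gcpp {

// Sketch: apply func(x, out) element-wise, walking rows so that any padding
// between rows (stride larger than Cols()) is skipped rather than touched.
template <typename T, class Func>
void ForEachRowPair(const MatPtrT<T>& x, MatPtrT<T>& out, const Func& func) {
  for (size_t r = 0; r < x.Rows(); ++r) {
    const T* x_row = x.Row(r);
    T* out_row = out.Row(r);
    for (size_t c = 0; c < x.Cols(); ++c) {
      out_row[c] = func(x_row[c], out_row[c]);
    }
  }
}

}  // namespace gcpp

MulByConstAndAddT above is this pattern specialized to out += c * x.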
template<typename T>

View File

@ -28,6 +28,7 @@
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -50,16 +51,17 @@ HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <typename ArrayT>
void InputEmbedding(const ArrayT& weights, const std::vector<int>& prompt,
template <typename T>
void InputEmbedding(const MatPtrT<T>& weights, const std::vector<int>& prompt,
const float scaling, float* HWY_RESTRICT output,
size_t model_dim, size_t vocab_size) {
const hn::ScalableTag<float> df;
HWY_ASSERT(!prompt.empty());
for (size_t pos = 0; pos < prompt.size() - 1; ++pos) {
int token = prompt[pos];
DecompressAndZeroPad(df, MakeSpan(weights.data(), model_dim * vocab_size),
token * model_dim, output + pos * model_dim,
const auto span = weights.Span();
HWY_ASSERT(span.num == model_dim * vocab_size);
DecompressAndZeroPad(df, span, token * model_dim, output + pos * model_dim,
model_dim);
MulByConst(scaling, output + pos * model_dim, model_dim);
}
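The change above replaces MakeSpan(weights.data(), N) with weights.Span(), which carries its own element count, so the size check becomes an explicit assertion. A sketch of the same access pattern, assumed to sit in the same gcpp::HWY_NAMESPACE section as the code above; DecompressRow is hypothetical:

// Sketch: decompress one row of a packed weight matrix into f32 output.
template <typename T>
void DecompressRow(const MatPtrT<T>& weights, size_t row,
                   float* HWY_RESTRICT out) {
  const hn::ScalableTag<float> df;
  const auto span = weights.Span();  // pointer plus element count
  HWY_ASSERT(span.num == weights.Rows() * weights.Cols());
  DecompressAndZeroPad(df, span, row * weights.Cols(), out, weights.Cols());
}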
@ -109,27 +111,27 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
static_cast<float>(1.0 / sqrt(static_cast<double>(kQKVDim)));
HWY_ASSERT(num_tokens <= kSeqLen);
ApplyRMSNorm(weights.pre_attention_norm_scale.data(),
activations.input.data(), model_dim, num_tokens,
activations.pre_att_rms_out.data(), pool);
ApplyRMSNorm(weights.pre_attention_norm_scale.Packed(),
activations.input.Packed(), model_dim, num_tokens,
activations.pre_att_rms_out.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.qkv_einsum_w, 0, (kHeads + 2) * kQKVDim, model_dim,
activations.pre_att_rms_out.data() + pos * model_dim,
activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool);
activations.pre_att_rms_out.Packed() + pos * model_dim,
activations.qkv.Packed() + pos * (kHeads + 2) * kQKVDim, pool);
}
const size_t num_tasks = kHeads * num_tokens;
for (size_t pos = 0; pos < num_tokens; ++pos) {
float* HWY_RESTRICT k =
activations.qkv.data() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
activations.qkv.Packed() + (pos * (kHeads + 2) + kHeads) * kQKVDim;
Rope(k, kQKVDim, inv_timescale.Const(), pos);
}
pool.Run(0, num_tasks, [&](const uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
Rope(q, kQKVDim, inv_timescale.Const(), pos);
MulByConst(query_scale, q, kQKVDim);
});
@ -138,12 +140,12 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT q =
activations.qkv.data() + (pos * (kHeads + 2) + head) * kQKVDim;
activations.qkv.Packed() + (pos * (kHeads + 2) + head) * kQKVDim;
float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const float* HWY_RESTRICT k2 =
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
activations.qkv.Packed() + (pos2 * (kHeads + 2) + kHeads) * kQKVDim;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
}
@ -153,7 +155,7 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
Softmax(head_att, pos + 1);
});
@ -161,51 +163,51 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
const size_t head = task % kHeads;
const size_t pos = task / kHeads;
const float* HWY_RESTRICT head_att =
activations.att.data() + (pos * kHeads + head) * kSeqLen;
activations.att.Packed() + (pos * kHeads + head) * kSeqLen;
float* HWY_RESTRICT att_out =
activations.att_out.data() + (pos * kHeads + head) * kQKVDim;
activations.att_out.Packed() + (pos * kHeads + head) * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
float* HWY_RESTRICT v2 =
activations.qkv.data() + (pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
float* HWY_RESTRICT v2 = activations.qkv.Packed() +
(pos2 * (kHeads + 2) + kHeads + 1) * kQKVDim;
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
}
});
activations.attention_out.ZeroInit();
ZeroInit(activations.attention_out);
for (size_t pos = 0; pos < num_tokens; ++pos) {
for (size_t head = 0; head < kHeads; ++head) {
MatVec(
weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
kQKVDim,
activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim,
activations.att_post1.data() + pos * model_dim, pool);
AddFrom(activations.att_post1.data() + pos * model_dim,
activations.attention_out.data() + pos * model_dim, model_dim);
MatVec(weights.attn_vec_einsum_w, head * model_dim * kQKVDim, model_dim,
kQKVDim,
activations.att_out.Packed() + pos * kHeads * kQKVDim +
head * kQKVDim,
activations.att_post1.Packed() + pos * model_dim, pool);
AddFrom(activations.att_post1.Packed() + pos * model_dim,
activations.attention_out.Packed() + pos * model_dim, model_dim);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.input.data() + pos * model_dim,
activations.attention_out.data() + pos * model_dim, model_dim);
AddFrom(activations.input.Packed() + pos * model_dim,
activations.attention_out.Packed() + pos * model_dim, model_dim);
}
ApplyRMSNorm(weights.pre_ffw_norm_scale.data(),
activations.attention_out.data(), model_dim, num_tokens,
activations.bf_pre_ffw_rms_out.data(), pool);
ApplyRMSNorm(weights.pre_ffw_norm_scale.Packed(),
activations.attention_out.Packed(), model_dim, num_tokens,
activations.bf_pre_ffw_rms_out.Packed(), pool);
const size_t kFFHiddenDim = config.ff_hidden_dim;
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.gating_einsum_w, 0, kFFHiddenDim * 2, model_dim,
activations.bf_pre_ffw_rms_out.data() + pos * model_dim,
activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool);
activations.bf_pre_ffw_rms_out.Packed() + pos * model_dim,
activations.ffw_hidden.Packed() + pos * kFFHiddenDim * 2, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
const size_t hidden_offset = pos * kFFHiddenDim * 2;
const float* HWY_RESTRICT out =
activations.ffw_hidden.data() + hidden_offset;
activations.ffw_hidden.Packed() + hidden_offset;
const float* HWY_RESTRICT out_mul = out + kFFHiddenDim;
float* HWY_RESTRICT out_gated =
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim;
activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim;
namespace hn = hwy::HWY_NAMESPACE;
using DF = hn::ScalableTag<float>;
DF df;
@ -217,11 +219,11 @@ void ApplyForwardLayer(const LayerWeightsPtrs<T>& weights,
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.linear_w, 0, model_dim, kFFHiddenDim,
activations.ffw_hidden_gated.data() + pos * kFFHiddenDim,
activations.ffw_hidden_gated.Packed() + pos * kFFHiddenDim,
output + pos * model_dim, pool);
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
AddFrom(activations.attention_out.data() + pos * model_dim,
AddFrom(activations.attention_out.Packed() + pos * model_dim,
output + pos * model_dim, model_dim);
}
}
@ -247,44 +249,43 @@ float CrossEntropyLossForwardPass(const std::vector<int>& prompt,
const size_t num_tokens = prompt.size() - 1;
InputEmbedding(weights.embedder_input_embedding, prompt, emb_scaling,
forward.layers[0].input.data(), model_dim, vocab_size);
forward.layers[0].input.Packed(), model_dim, vocab_size);
for (size_t layer = 0; layer < config.layer_configs.size(); ++layer) {
auto type = config.layer_configs[layer].type;
// TODO(szabadka) Implement Griffin layer.
HWY_ASSERT(type == LayerAttentionType::kGemma);
float* HWY_RESTRICT output = layer + 1 < layers
? forward.layers[layer + 1].input.data()
: forward.final_layer_output.data();
? forward.layers[layer + 1].input.Packed()
: forward.final_layer_output.Packed();
ApplyForwardLayer(*weights.GetLayer(layer), forward.layers[layer],
num_tokens, output, inv_timescale, pool);
}
ApplyRMSNorm(weights.final_norm_scale.data(),
forward.final_layer_output.data(), model_dim, num_tokens,
forward.final_norm_output.data(), pool);
ApplyRMSNorm(weights.final_norm_scale.Packed(),
forward.final_layer_output.Packed(), model_dim, num_tokens,
forward.final_norm_output.Packed(), pool);
for (size_t pos = 0; pos < num_tokens; ++pos) {
MatVec(weights.embedder_input_embedding, 0, vocab_size, model_dim,
forward.final_norm_output.data() + pos * model_dim,
forward.logits.data() + pos * vocab_size, pool);
forward.final_norm_output.Packed() + pos * model_dim,
forward.logits.Packed() + pos * vocab_size, pool);
}
if (config.final_cap > 0.0f) {
for (size_t pos = 0; pos < num_tokens; ++pos) {
LogitsSoftCap(config.final_cap, forward.logits.data() + pos * vocab_size,
vocab_size);
LogitsSoftCap(config.final_cap,
forward.logits.Packed() + pos * vocab_size, vocab_size);
}
}
hwy::CopyBytes(forward.logits.data(), forward.probs.data(),
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
CopyMat(forward.logits, forward.probs);
for (size_t pos = 0; pos < num_tokens; ++pos) {
Softmax(forward.probs.data() + pos * vocab_size, vocab_size);
Softmax(forward.probs.Packed() + pos * vocab_size, vocab_size);
}
return CrossEntropyLoss(forward.probs.data(), prompt, context_size,
return CrossEntropyLoss(forward.probs.Packed(), prompt, context_size,
vocab_size, pool);
}

View File

@ -17,9 +17,7 @@
#include "backprop/activations.h"
#include "backprop/prompt.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// Compiles this file for multiple architectures via "foreach_target.h", to

View File

@ -180,54 +180,59 @@ void ApplyLayer(const LayerWeightsPtrs<T>& weights,
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
static const T query_scale = T(1.0) / std::sqrt(T(qkv_dim));
RMSNormT(weights.pre_attention_norm_scale.data(), activations.input.data(),
activations.pre_att_rms_out.data(), model_dim, num_tokens);
RMSNormT(weights.pre_attention_norm_scale.Packed(),
activations.input.Packed(), activations.pre_att_rms_out.Packed(),
model_dim, num_tokens);
MatMulT(weights.qkv_einsum_w.data(), activations.pre_att_rms_out.data(),
activations.qkv.data(), (heads + 2) * qkv_dim, model_dim, num_tokens);
MatMulT(weights.qkv_einsum_w.Packed(), activations.pre_att_rms_out.Packed(),
activations.qkv.Packed(), (heads + 2) * qkv_dim, model_dim,
num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
for (size_t h = 0; h <= heads; ++h) {
Rope(qkv + h * qkv_dim, qkv_dim, pos);
}
}
for (size_t pos = 0; pos < num_tokens; ++pos) {
T* qkv = activations.qkv.data() + pos * (heads + 2) * qkv_dim;
T* qkv = activations.qkv.Packed() + pos * (heads + 2) * qkv_dim;
MulByConstT(query_scale, qkv, heads * qkv_dim);
}
MaskedAttention(activations.qkv.data(), activations.att.data(), num_tokens,
heads, qkv_dim, seq_len);
MaskedAttention(activations.qkv.Packed(), activations.att.Packed(),
num_tokens, heads, qkv_dim, seq_len);
MaskedSoftmax(activations.att.data(), num_tokens, heads, seq_len);
MaskedSoftmax(activations.att.Packed(), num_tokens, heads, seq_len);
MixByAttention(activations.qkv.data(), activations.att.data(),
activations.att_out.data(), num_tokens, heads, qkv_dim,
MixByAttention(activations.qkv.Packed(), activations.att.Packed(),
activations.att_out.Packed(), num_tokens, heads, qkv_dim,
seq_len);
MultiHeadMatMul(weights.attn_vec_einsum_w.data(), activations.att_out.data(),
activations.attention_out.data(), heads, model_dim, qkv_dim,
MultiHeadMatMul(weights.attn_vec_einsum_w.Packed(),
activations.att_out.Packed(),
activations.attention_out.Packed(), heads, model_dim, qkv_dim,
num_tokens);
AddFromT(activations.input.data(), activations.attention_out.data(),
AddFromT(activations.input.Packed(), activations.attention_out.Packed(),
num_tokens * model_dim);
RMSNormT(weights.pre_ffw_norm_scale.data(), activations.attention_out.data(),
activations.bf_pre_ffw_rms_out.data(), model_dim, num_tokens);
RMSNormT(weights.pre_ffw_norm_scale.Packed(),
activations.attention_out.Packed(),
activations.bf_pre_ffw_rms_out.Packed(), model_dim, num_tokens);
MatMulT(weights.gating_einsum_w.data(), activations.bf_pre_ffw_rms_out.data(),
activations.ffw_hidden.data(), ff_hidden_dim * 2, model_dim,
MatMulT(weights.gating_einsum_w.Packed(),
activations.bf_pre_ffw_rms_out.Packed(),
activations.ffw_hidden.Packed(), ff_hidden_dim * 2, model_dim,
num_tokens);
GatedGelu(activations.ffw_hidden.data(), activations.ffw_hidden_gated.data(),
ff_hidden_dim, num_tokens);
GatedGelu(activations.ffw_hidden.Packed(),
activations.ffw_hidden_gated.Packed(), ff_hidden_dim, num_tokens);
MatMulT(weights.linear_w.data(), activations.ffw_hidden_gated.data(), output,
model_dim, ff_hidden_dim, num_tokens);
MatMulT(weights.linear_w.Packed(), activations.ffw_hidden_gated.Packed(),
output, model_dim, ff_hidden_dim, num_tokens);
AddFromT(activations.attention_out.data(), output, num_tokens * model_dim);
AddFromT(activations.attention_out.Packed(), output, num_tokens * model_dim);
}
template<typename T>
@ -258,35 +263,35 @@ T CrossEntropyLossForwardPass(const Prompt& prompt,
const size_t num_tokens = tokens.empty() ? 0 : tokens.size() - 1;
const T kEmbScaling = EmbeddingScaling(model_dim);
InputEmbedding(weights.embedder_input_embedding.data(), tokens, kEmbScaling,
forward.layers[0].input.data(), model_dim);
InputEmbedding(weights.embedder_input_embedding.Packed(), tokens, kEmbScaling,
forward.layers[0].input.Packed(), model_dim);
for (size_t layer = 0; layer < layers; ++layer) {
T* output = layer + 1 < layers ? forward.layers[layer + 1].input.data()
: forward.final_layer_output.data();
T* output = layer + 1 < layers ? forward.layers[layer + 1].input.Packed()
: forward.final_layer_output.Packed();
ApplyLayer(*weights.GetLayer(layer), forward.layers[layer], num_tokens,
output);
}
RMSNormT(weights.final_norm_scale.data(), forward.final_layer_output.data(),
forward.final_norm_output.data(), model_dim, num_tokens);
RMSNormT(weights.final_norm_scale.Packed(),
forward.final_layer_output.Packed(),
forward.final_norm_output.Packed(), model_dim, num_tokens);
MatMulT(weights.embedder_input_embedding.data(),
forward.final_norm_output.data(), forward.logits.data(), vocab_size,
model_dim, num_tokens);
MatMulT(weights.embedder_input_embedding.Packed(),
forward.final_norm_output.Packed(), forward.logits.Packed(),
vocab_size, model_dim, num_tokens);
for (size_t pos = 0; pos < num_tokens; ++pos) {
if (config.final_cap > 0.0f) {
Softcap(config.final_cap, forward.logits.data() + pos * vocab_size,
Softcap(config.final_cap, forward.logits.Packed() + pos * vocab_size,
vocab_size);
}
}
memcpy(forward.probs.data(), forward.logits.data(),
num_tokens * vocab_size * sizeof(forward.logits.At(0)));
Softmax(forward.probs.data(), vocab_size, num_tokens);
CopyMat(forward.logits, forward.probs);
Softmax(forward.probs.Packed(), vocab_size, num_tokens);
return CrossEntropyLoss(forward.probs.data(), prompt, vocab_size);
return CrossEntropyLoss(forward.probs.Packed(), prompt, vocab_size);
}
} // namespace gcpp

View File

@ -41,11 +41,14 @@
namespace gcpp {
TEST(OptimizeTest, GradientDescent) {
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1));
Allocator::Init(topology);
NestedPools pools(topology, 1, /*pin=*/Tristate::kFalse);
MatMulEnv env(topology, pools);
hwy::ThreadPool& pool = pools.Pool();
gcpp::ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.pin = Tristate::kFalse;
ThreadingContext2::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get());
const Allocator2& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool();
std::mt19937 gen(42);
const ModelInfo info = {
@ -64,7 +67,7 @@ TEST(OptimizeTest, GradientDescent) {
KVCache kv_cache = KVCache::Create(config, /*prefill_tbatch_size=*/16);
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
config.layer_configs[0].qkv_dim,
allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
Gemma gemma(GemmaTokenizer(), info, env);
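The block above is the new per-process setup that replaces BoundedTopology + Allocator::Init + NestedPools in each frontend. A condensed sketch of the same sequence; the SetArgs-before-Get ordering is taken from this test and treated as an assumption here:

#include "ops/matmul.h"              // MatMulEnv
#include "util/threading_context.h"  // ThreadingArgs, ThreadingContext2

namespace gcpp {

// Sketch: one-time threading setup shared by tests and small tools.
void SingleClusterSetup() {
  ThreadingArgs args;
  args.max_packages = 1;             // keep the run small and deterministic
  args.max_clusters = 1;
  args.pin = Tristate::kFalse;       // do not pin threads
  ThreadingContext2::SetArgs(args);  // called before the first Get(), as above
  MatMulEnv env(ThreadingContext2::Get());
  const Allocator2& allocator = env.ctx.allocator;
  hwy::ThreadPool& pool = env.ctx.pools.Pool();
  (void)allocator;
  (void)pool;
}

}  // namespace gcpp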

View File

@ -18,9 +18,9 @@
#include <cmath>
#include "compression/compress.h"
#include "gemma/common.h"
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -39,11 +39,11 @@ class AdamUpdater {
void operator()(const char* name, const MatPtr& grad, MatPtr& weights,
MatPtr& grad_m, MatPtr& grad_v) {
const float* HWY_RESTRICT g = grad.data<float>();
float* HWY_RESTRICT w = weights.data<float>();
float* HWY_RESTRICT m = grad_m.data<float>();
float* HWY_RESTRICT v = grad_v.data<float>();
for (size_t i = 0; i < grad.NumElements(); ++i) {
const float* HWY_RESTRICT g = grad.RowT<float>(0);
float* HWY_RESTRICT w = weights.RowT<float>(0);
float* HWY_RESTRICT m = grad_m.RowT<float>(0);
float* HWY_RESTRICT v = grad_v.RowT<float>(0);
for (size_t i = 0; i < grad.Extents().Area(); ++i) {
m[i] *= beta1_;
m[i] += cbeta1_ * g[i];
v[i] *= beta2_;

View File

@ -24,21 +24,13 @@
#include <vector>
#include "gtest/gtest.h"
#include "compression/compress.h"
#include "gemma/configs.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
template <typename T>
void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
std::normal_distribution<T> dist(0.0, stddev);
for (size_t i = 0; i < x.NumElements(); ++i) {
x.At(i) = dist(gen);
}
}
// TODO: make a member of Layer<T>.
template <typename T>
void RandInit(LayerWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
@ -62,8 +54,12 @@ void RandInit(ModelWeightsPtrs<T>& w, T stddev, std::mt19937& gen) {
template <typename T, typename U>
void Complexify(const MatPtrT<T>& x, MatPtrT<std::complex<U>>& c_x) {
for (size_t i = 0; i < x.NumElements(); ++i) {
c_x.At(i) = std::complex<U>(x.At(i), 0.0);
for (size_t r = 0; r < x.Rows(); ++r) {
const T* row = x.Row(r);
std::complex<U>* c_row = c_x.Row(r);
for (size_t c = 0; c < x.Cols(); ++c) {
c_row[c] = std::complex<U>(row[c], 0.0);
}
}
}
@ -87,14 +83,14 @@ void Complexify(const ModelWeightsPtrs<T>& w, ModelWeightsPtrs<U>& c_w) {
}
}
// Somewhat duplicates ModelWeightsStorage, but that has neither double nor
// Somewhat duplicates WeightsOwner, but that has neither double nor
// complex types allowed and it would cause code bloat to add them there.
template <typename T>
class WeightsWrapper {
public:
explicit WeightsWrapper(const ModelConfig& config)
: pool_(0), weights_(config) {
weights_.Allocate(data_, pool_);
weights_.Allocate(owners_, pool_);
}
const ModelWeightsPtrs<T>& get() const { return weights_; }
@ -106,7 +102,7 @@ class WeightsWrapper {
private:
hwy::ThreadPool pool_;
std::vector<MatStorage> data_;
std::vector<MatOwner> owners_;
ModelWeightsPtrs<T> weights_;
};
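WeightsWrapper now keeps std::vector<MatOwner> instead of owning MatStorage objects: the MatPtrT views hold only metadata, and MatOwner holds the allocation. A minimal sketch of that split, using only calls visible in this diff and assuming the declarations from util/mat.h:

#include "util/mat.h"  // MatPtrT, MatOwner, MatPadding

namespace gcpp {

// Sketch: a non-owning MatPtrT view plus a MatOwner that holds the bytes.
void OwnerSketch() {
  MatPtrT<float> mat("w", Extents2D(4, 16));    // metadata only, no allocation
  MatOwner owner;
  owner.AllocateFor(mat, MatPadding::kPacked);  // mat now points into owner's memory
  mat.Row(0)[0] = 1.0f;                         // row accessors work on the view
}

}  // namespace gcpp

The same pairing appears below in ReadFromBlobStore::ReadAll and in the SBS writer.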
@ -116,13 +112,18 @@ void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
double sum0 = 0;
double sum1 = 0;
double sum01 = 0;
for (size_t i = 0; i < actual.NumElements(); ++i) {
sum0 += actual.At(i) * actual.At(i);
sum1 += expected.At(i) * expected.At(i);
sum01 += actual.At(i) * expected.At(i);
ASSERT_NEAR(actual.At(i), expected.At(i),
std::max(max_abs_err, std::abs(expected.At(i)) * max_rel_err))
<< "line: " << line << " dim=" << expected.NumElements() << " i=" << i;
for (size_t r = 0; r < actual.Rows(); ++r) {
const T* actual_row = actual.Row(r);
const U* expected_row = expected.Row(r);
for (size_t c = 0; c < actual.Cols(); ++c) {
sum0 += actual_row[c] * actual_row[c];
sum1 += expected_row[c] * expected_row[c];
sum01 += actual_row[c] * expected_row[c];
ASSERT_NEAR(
actual_row[c], expected_row[c],
std::max(max_abs_err, std::abs(expected_row[c]) * max_rel_err))
<< "line: " << line << " r " << r << " c " << c;
}
}
if (sum0 > 1e-40) {
double norm_dot = sum01 / std::sqrt(sum0) / std::sqrt(sum1);
@ -148,15 +149,19 @@ void TestNear(const MatPtrT<T>& actual, const MatPtrT<U>& expected,
template <typename FUNC, typename T, typename U>
void TestGradient(const MatPtrT<T>& grad, MatPtrT<std::complex<U>>& x,
FUNC func, U step, T max_abs_err, T max_rel_err, int line) {
MatStorageT<T> exp_grad("exp_grad", x.Rows(), x.Cols());
MatStorageT<T> exp_grad = MakePacked<T>("exp_grad", x.Rows(), x.Cols());
const U inv_step = 1.0 / step;
for (size_t i = 0; i < x.NumElements(); ++i) {
const U x0 = std::real(x.At(i));
const std::complex<U> x1 = std::complex<U>(x0, step);
x.At(i) = x1;
const std::complex<U> f1 = func();
exp_grad.At(i) = std::imag(f1) * inv_step;
x.At(i) = x0;
for (size_t r = 0; r < x.Rows(); ++r) {
std::complex<U>* x_row = x.Row(r);
T* exp_row = exp_grad.Row(r);
for (size_t c = 0; c < x.Cols(); ++c) {
const U x0 = std::real(x_row[c]);
const std::complex<U> x1 = std::complex<U>(x0, step);
x_row[c] = x1;
const std::complex<U> f1 = func();
exp_row[c] = std::imag(f1) * inv_step;
x_row[c] = x0;
}
}
TestNear(grad, exp_grad, max_abs_err, max_rel_err, line);
}

View File

@ -146,6 +146,7 @@ cc_library(
":distortion",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:thread_pool",
],
)
@ -209,6 +210,7 @@ cc_library(
"//:allocator",
"//:basics",
"//:common",
"//:mat",
"@highway//:hwy",
"@highway//:nanobenchmark",
"@highway//:stats",
@ -252,22 +254,6 @@ cc_library(
],
)
cc_binary(
name = "compress_weights",
srcs = ["compress_weights.cc"],
deps = [
":compress",
":io",
"//:allocator",
"//:args",
"//:common",
"//:tokenizer",
"//:weights",
"@highway//:hwy",
"@highway//:thread_pool",
],
)
cc_binary(
name = "blob_compare",
srcs = ["blob_compare.cc"],
@ -277,9 +263,11 @@ cc_binary(
"//:allocator",
"//:basics",
"//:threading",
"//:threading_context",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
"@highway//:thread_pool",
],
)
@ -287,7 +275,6 @@ cc_binary(
name = "migrate_weights",
srcs = ["migrate_weights.cc"],
deps = [
"//:app",
"//:args",
"//:benchmark_helper",
"//:gemma_lib",

View File

@ -25,8 +25,10 @@
#include "util/allocator.h"
#include "util/basics.h" // IndexRange
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
namespace gcpp {
@ -202,15 +204,13 @@ void ReadAndCompareBlobs(const char* path1, const char* path2) {
if (!CompareKeys(reader1, reader2)) return;
// Single allocation, avoid initializing the memory.
BoundedTopology topology;
Allocator::Init(topology);
NestedPools pools(topology);
const size_t total_bytes = TotalBytes(reader1) + TotalBytes(reader2);
BytePtr all_blobs = hwy::AllocateAligned<uint8_t>(total_bytes);
size_t pos = 0;
BlobVec blobs1 = ReserveMemory(reader1, all_blobs, pos);
BlobVec blobs2 = ReserveMemory(reader2, all_blobs, pos);
NestedPools& pools = ThreadingContext2::Get().pools;
ReadBothBlobs(reader1, reader2, total_bytes, blobs1, blobs2, pools);
CompareBlobs(reader1.Keys(), blobs1, blobs2, total_bytes, pools);

View File

@ -29,6 +29,7 @@
#include "compression/compress.h" // IWYU pragma: export
#include "compression/distortion.h"
#include "gemma/configs.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -379,7 +380,7 @@ struct CompressTraits<SfpStream> {
using Packed = SfpStream;
// Callers are responsible for scaling `raw` such that its magnitudes do not
// exceed `SfpStream::kMax`. See CompressedArray::scale().
// exceed `SfpStream::kMax`. See CompressedArray::Scale().
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT raw,
size_t num, CompressPerThread& tls,
@ -522,8 +523,7 @@ HWY_INLINE void CompressScaled(const float* HWY_RESTRICT raw, size_t num,
CompressWorkingSet& work,
MatStorageT<Packed>& compressed,
hwy::ThreadPool& pool) {
Compress(raw, num, work,
MakeSpan(compressed.data(), compressed.NumElements()),
Compress(raw, num, work, compressed.Span(),
/*packed_ofs=*/0, pool);
}
@ -717,11 +717,9 @@ class Compressor {
template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name,
const float* HWY_RESTRICT weights) {
size_t num_weights = compressed->NumElements();
if (num_weights == 0 || weights == nullptr || compressed->Ptr() == nullptr)
return;
size_t num_compressed = compressed->NumElements();
PackedSpan<Packed> packed = MakeSpan(compressed->data(), num_compressed);
size_t num_weights = compressed->Extents().Area();
if (num_weights == 0 || weights == nullptr || !compressed->HasPtr()) return;
PackedSpan<Packed> packed = compressed->Span();
fprintf(stderr, "Compressing %s (%zuM), please wait\n", decorated_name,
num_weights / (1000 * 1000));
Compress(weights, num_weights, work_, packed, /*packed_ofs=*/0,

View File

@ -17,6 +17,6 @@
namespace gcpp {
MatPtr::~MatPtr() {}
// TODO: move ScaleWeights here.
} // namespace gcpp

View File

@ -41,7 +41,8 @@
// IWYU pragma: end_exports
#include "gemma/configs.h"
#include "util/allocator.h"
#include "hwy/per_target.h"
#include "util/mat.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#if COMPRESS_STATS
#include "compression/distortion.h"
#include "hwy/stats.h"
@ -49,322 +50,6 @@
namespace gcpp {
// Base class for rank-1 or 2 tensors (vector or matrix).
// Supports both dynamic and compile-time sizing.
// Holds metadata and a non-owning pointer to the data, owned by the derived
// MatStorageT class.
// This class also provides easy conversion from/to a table of contents for a
// BlobStore file, and a templated (compile-time) accessor for a 2-d array of
// fixed inner dimension and type.
// It is designed to be put in a vector, and has default copy and operator=, so
// it is easy to read/write a blob_store file.
class MatPtr : public IFields {
public:
// Full constructor for dynamic sizing.
MatPtr(const std::string& name, Type type, size_t element_size, size_t rows,
size_t cols)
: name_(name),
type_(type),
element_size_(element_size),
num_elements_(rows * cols),
rows_(rows),
cols_(cols),
ptr_(nullptr) {
stride_ = cols;
}
// Default is to leave all fields default-initialized.
MatPtr() = default;
virtual ~MatPtr();
// Compatibility interface for CompressedArray.
// TODO: remove.
template <typename T>
T* data() {
return HWY_RCAST_ALIGNED(T*, ptr_);
}
template <typename T>
const T* data() const {
return HWY_RCAST_ALIGNED(const T*, ptr_);
}
const void* Ptr() const { return ptr_; }
void* Ptr() { return ptr_; }
// Sets the pointer from another MatPtr.
void SetPtr(const MatPtr& other) { ptr_ = other.ptr_; }
// Copying allowed as the metadata is small.
MatPtr(const MatPtr& other) = default;
MatPtr& operator=(const MatPtr& other) = default;
// Returns the name of the blob.
const char* Name() const override { return name_.c_str(); }
void SetName(const std::string& name) { name_ = name; }
// Returns the type of the blob.
Type GetType() const { return type_; }
// Returns the size of each element in bytes.
size_t ElementSize() const { return element_size_; }
// Returns the number of elements in the array.
size_t NumElements() const { return num_elements_; }
// Returns the number of bytes in the array.
size_t SizeBytes() const {
if (this->GetType() == TypeEnum<NuqStream>()) {
return NuqStream::PackedEnd(num_elements_);
}
return num_elements_ * element_size_;
}
// Returns the number of rows in the 2-d array (outer dimension).
size_t Rows() const { return rows_; }
// Returns the number of columns in the 2-d array (inner dimension).
size_t Cols() const { return cols_; }
Extents2D Extents() const { return Extents2D(rows_, cols_); }
// Currently same as cols, but may differ in the future. This is the offset by
// which to advance pointers to the next row.
size_t Stride() const { return stride_; }
// Decoded elements should be multiplied by this to restore their original
// range. This is required because SfpStream can only encode a limited range
// of magnitudes.
float scale() const { return scale_; }
void set_scale(float scale) { scale_ = scale; }
std::string LayerName(int layer) const {
std::string name = name_ + std::to_string(layer);
HWY_ASSERT(name.size() <= sizeof(hwy::uint128_t));
return name;
}
// Sets all data to zero.
void ZeroInit() {
if (ptr_ == nullptr)
HWY_ABORT("ptr_ is null on tensor %s\n", name_.c_str());
hwy::ZeroBytes(ptr_, SizeBytes());
}
void VisitFields(IFieldsVisitor& visitor) override {
visitor(name_);
visitor(type_);
visitor(element_size_);
visitor(num_elements_);
visitor(rows_);
visitor(cols_);
visitor(scale_);
visitor(stride_);
}
// Calls func on the upcasted type. Since MatPtr by design is not templated,
// here we provide a way to get to the derived type, provided that `Type()`
// is one of the strings returned by `TypeName()`.
template <class FuncT, typename... TArgs>
decltype(auto) CallUpcasted(FuncT& func, TArgs&&... args);
protected:
// Arbitrary name for the array of preferably <= 16 characters.
std::string name_;
// Should be the result of TypeEnum<T> for CallUpcasted() to work.
Type type_;
// sizeof(T)
uint32_t element_size_ = 0;
// Number of elements in the array.
uint32_t num_elements_ = 0; // In element_size units.
// Number of rows in the 2-d array (outer dimension).
uint32_t rows_ = 0;
// Number of columns in the 2-d array (inner dimension).
uint32_t cols_ = 0;
// Scaling to apply to each element.
float scale_ = 1.0f;
// Aligned data array. This is always a borrowed pointer. It should never be
// freed. The underlying memory is owned by a subclass or some external class
// and must outlive this object.
void* ptr_ = nullptr;
uint32_t stride_;
};
// MatPtrT adds a single template argument to MatPtr for an explicit type.
// Use this class as a function argument where the type needs to be known.
// Use MatPtr where the type does not need to be known.
template <typename MatT>
class MatPtrT : public MatPtr {
public:
// Full constructor for dynamic sizing.
MatPtrT(const std::string& name, size_t rows, size_t cols)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), rows, cols) {}
// Construction from TensorIndex entry to remove duplication of sizes.
MatPtrT(const std::string& name, const TensorIndex& tensor_index)
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
MatPtrT(const std::string& name, const TensorInfo* tensor)
: MatPtr(name, TypeEnum<MatT>(), sizeof(MatT), 0, 0) {
if (tensor == nullptr) {
cols_ = 0;
rows_ = 0;
} else {
cols_ = tensor->shape.back();
rows_ = 1;
if (tensor->cols_take_extra_dims) {
// The columns eat the extra dimensions.
rows_ = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols_ *= tensor->shape[i];
}
} else {
// The rows eat the extra dimensions.
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows_ *= tensor->shape[i];
}
}
}
stride_ = cols_;
num_elements_ = rows_ * cols_;
}
// Copying allowed as the metadata is small.
MatPtrT(const MatPtr& other) : MatPtr(other) {}
MatPtrT& operator=(const MatPtr& other) {
MatPtr::operator=(other);
return *this;
}
MatPtrT(const MatPtrT& other) = default;
MatPtrT& operator=(const MatPtrT& other) = default;
std::string CacheName(int layer = -1, char separator = ' ',
int index = -1) const {
// Already used/retired: s, S, n, 1
const char prefix = hwy::IsSame<MatT, float>() ? 'F'
: hwy::IsSame<MatT, BF16>() ? 'B'
: hwy::IsSame<MatT, SfpStream>() ? '$'
: hwy::IsSame<MatT, NuqStream>() ? '2'
: '?';
std::string name = std::string(1, prefix) + name_;
if (layer >= 0 || index >= 0) {
name += '_';
if (layer >= 0) name += std::to_string(layer);
if (index >= 0) {
name += separator + std::to_string(index);
}
}
return name;
}
// Sets the number of elements in the array. For use when the number of
// elements is != rows * cols ONLY.
void SetNumElements(size_t num_elements) {
num_elements_ = CompressedArrayElements<MatT>(num_elements);
}
// 2-d Accessor for a specific type but with a dynamic inner dimension.
template <typename T = MatT>
const T& At(size_t row, size_t col) const {
size_t index = row * cols_ + col;
HWY_DASSERT(index < num_elements_);
return HWY_RCAST_ALIGNED(const T*, ptr_)[index];
}
// 1-d Accessor for a specific type.
// TODO: replace this with a Foreach(), or at least a ForEachRow().
const MatT& At(size_t index) const {
HWY_DASSERT(index < num_elements_);
return HWY_RCAST_ALIGNED(const MatT*, ptr_)[index];
}
MatT& At(size_t index) { return HWY_RCAST_ALIGNED(MatT*, ptr_)[index]; }
// Compatibility interface for CompressedArray.
// TODO: remove
template <typename T = MatT>
T* data() {
return HWY_RCAST_ALIGNED(T*, ptr_);
}
template <typename T = MatT>
const T* data() const {
return HWY_RCAST_ALIGNED(const T*, ptr_);
}
// The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
// calling it means "I am sure the scale is 1 and therefore ignore the scale".
// A scale of 0 indicates that the scale has likely never been set, so is
// "implicitly 1".
const MatT* data_scale1() const {
HWY_ASSERT(scale() == 1.f);
return HWY_RCAST_ALIGNED(const MatT*, ptr_);
}
};
template <class FuncT, typename... TArgs>
decltype(auto) MatPtr::CallUpcasted(FuncT& func, TArgs&&... args) {
if (type_ == TypeEnum<float>()) {
return func(dynamic_cast<MatPtrT<float>*>(this),
std::forward<TArgs>(args)...);
} else if (type_ == TypeEnum<BF16>()) {
return func(dynamic_cast<MatPtrT<BF16>*>(this),
std::forward<TArgs>(args)...);
} else if (type_ == TypeEnum<SfpStream>()) {
return func(dynamic_cast<MatPtrT<SfpStream>*>(this),
std::forward<TArgs>(args)...);
} else if (type_ == TypeEnum<NuqStream>()) {
return func(dynamic_cast<MatPtrT<NuqStream>*>(this),
std::forward<TArgs>(args)...);
} else {
HWY_ABORT("Type %d unknown.", type_);
}
}
// MatStorageT adds the actual data storage to MatPtrT.
// TODO: use Extents2D instead of rows and cols.
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
public:
// Full constructor for dynamic sizing.
MatStorageT(const std::string& name, size_t rows, size_t cols)
: MatPtrT<MatT>(name, rows, cols) {
Allocate();
}
// Can copy the metadata, from a MatPtr, and allocate later.
MatStorageT(const MatPtr& other) : MatPtrT<MatT>(other) {}
~MatStorageT() = default;
// Move-only because this contains a unique_ptr.
MatStorageT(const MatStorageT& other) = delete;
MatStorageT& operator=(const MatStorageT& other) = delete;
MatStorageT(MatStorageT&& other) = default;
MatStorageT& operator=(MatStorageT&& other) = default;
// Allocate the memory and copy the pointer to the MatPtr.
// num_elements is in elements. In the default (zero) case, it is computed
// from the current num_elements_ which was set by the constructor from the
// rows and cols.
void Allocate(size_t num_elements = 0) {
if (num_elements == 0) {
num_elements = hwy::DivCeil(this->SizeBytes(), sizeof(MatT));
} else {
this->num_elements_ = num_elements;
}
// Pad to allow overrunning the last row by 2 BF16 vectors, hence at most
// `2 * VectorBytes / sizeof(BF16)` elements of MatT.
const size_t padding = hwy::VectorBytes();
data_ = Allocator::Alloc<MatT>(num_elements + padding);
hwy::ZeroBytes(&data_[num_elements], padding * sizeof(MatT));
this->ptr_ = data_.get();
}
// Zeros the content.
void ZeroInit() {
HWY_ASSERT(data_ != nullptr);
hwy::ZeroBytes(data_.get(), this->SizeBytes());
}
private:
AlignedPtr<MatT> data_;
};
// MatStorage allows heterogeneous tensors to be stored in a single vector.
using MatStorage = MatStorageT<hwy::uint128_t>;
// Table of contents for a blob store file. Full metadata, but not actual data.
class BlobToc {
public:
@ -389,7 +74,7 @@ class BlobToc {
blob.Read(hwy::Span<const uint32_t>(toc), consumed);
prev_consumed = consumed;
consumed = result.pos;
if (blob.NumElements() > 0) {
if (!blob.IsEmpty()) {
AddToToc(blob);
}
}
@ -503,10 +188,11 @@ class WriteToBlobStore {
explicit WriteToBlobStore(hwy::ThreadPool& pool) : pool_(pool) {}
template <typename Packed>
void operator()(MatPtrT<Packed>* compressed, const char* decorated_name) {
if (compressed->Ptr() == nullptr) return;
writer_.Add(MakeKey(decorated_name), compressed->Ptr(),
compressed->SizeBytes());
void operator()(MatPtrT<Packed>* compressed,
const char* decorated_name) const {
if (!compressed->HasPtr()) return;
writer_.Add(MakeKey(decorated_name), compressed->Packed(),
compressed->PackedBytes());
MatPtr renamed_tensor(*compressed);
renamed_tensor.SetName(decorated_name);
renamed_tensor.AppendTo(toc_);
@ -519,9 +205,8 @@ class WriteToBlobStore {
void AddScales(const float* scales, size_t len) {
if (len) {
MatPtrT<float> scales_ptr("scales", 0, 1);
writer_.Add(MakeKey(scales_ptr.CacheName().c_str()), scales,
len * sizeof(scales[0]));
MatPtrT<float> scales_ptr("scales", Extents2D(0, 1));
writer_.Add(MakeKey(scales_ptr.Name()), scales, len * sizeof(scales[0]));
}
}
@ -554,9 +239,9 @@ class WriteToBlobStore {
hwy::ThreadPool& pool_;
private:
std::vector<uint32_t> toc_;
BlobWriter writer_;
std::vector<uint32_t> config_buffer_;
mutable std::vector<uint32_t> toc_;
mutable BlobWriter writer_;
mutable std::vector<uint32_t> config_buffer_;
};
// Functor called for each tensor, which loads them and their scaling factors
@ -613,6 +298,7 @@ class ReadFromBlobStore {
// Called for each tensor, enqueues read requests.
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
if (file_toc_.Empty() || file_toc_.Contains(name)) {
HWY_ASSERT(tensors[0]);
model_toc_.push_back(tensors[0]);
file_keys_.push_back(name);
}
@ -622,15 +308,15 @@ class ReadFromBlobStore {
for (size_t i = 0; i < len; ++i) {
scales[i] = 1.0f;
}
MatPtrT<float> scales_ptr("scales", 0, 1);
auto key = MakeKey(scales_ptr.CacheName().c_str());
MatPtrT<float> scales_ptr("scales", Extents2D(0, 1));
auto key = MakeKey(scales_ptr.Name());
if (reader_.BlobSize(key) == 0) return 0;
return reader_.Enqueue(key, scales, len * sizeof(scales[0]));
}
// Returns whether all tensors are successfully loaded from cache.
BlobError ReadAll(hwy::ThreadPool& pool,
std::vector<MatStorage>& model_memory) {
std::vector<MatOwner>& model_memory) {
// reader_ invalid or any Enqueue failed
if (err_ != 0) return err_;
// Setup the model_memory.
@ -650,26 +336,27 @@ class ReadFromBlobStore {
}
std::string name = blob->Name();
*blob = *toc_blob;
blob->SetName(name);
blob->SetName(name.c_str());
}
model_memory.emplace_back(*blob);
model_memory.back().SetName(file_key);
model_memory.push_back(MatOwner());
}
// Allocate in parallel using the pool.
pool.Run(0, model_memory.size(),
[this, &model_memory](uint64_t task, size_t /*thread*/) {
model_memory[task].Allocate();
model_toc_[task]->SetPtr(model_memory[task]);
model_memory[task].AllocateFor(*model_toc_[task],
MatPadding::kPacked);
});
// Enqueue the read requests.
for (auto& blob : model_memory) {
err_ =
reader_.Enqueue(MakeKey(blob.Name()), blob.data(), blob.SizeBytes());
for (size_t b = 0; b < model_toc_.size(); ++b) {
err_ = reader_.Enqueue(MakeKey(file_keys_[b].c_str()),
model_toc_[b]->RowT<uint8_t>(0),
model_toc_[b]->PackedBytes());
if (err_ != 0) {
fprintf(stderr,
"Failed to read blob %s (error %d) of size %zu x %zu x %zu\n",
blob.Name(), err_, blob.Rows(), blob.Cols(),
blob.ElementSize());
fprintf(
stderr,
"Failed to read blob %s (error %d) of size %zu x %zu, type %d\n",
file_keys_[b].c_str(), err_, model_toc_[b]->Rows(),
model_toc_[b]->Cols(), static_cast<int>(model_toc_[b]->GetType()));
return err_;
}
}

View File

@ -1,286 +0,0 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Command line tool to create compressed weights.
// Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'.
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
"compression/compress_weights.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "gemma/configs.h"
#include "gemma/tokenizer.h"
#ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
#define GEMMA_COMPRESS_WEIGHTS_ONCE
#include <stddef.h>
#include <stdio.h>
#include <algorithm> // std::clamp
#include <cstdlib>
#include <iostream>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "compression/compress.h"
#include "compression/io.h" // Path
#include "compression/shared.h" // PromptWrapping
#include "gemma/common.h" // Model
#include "gemma/weights.h"
#include "util/allocator.h"
#include "util/args.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
namespace {
} // namespace
struct Args : public ArgsBase<Args> {
static constexpr size_t kDefaultNumThreads = ~size_t{0};
void ChooseNumThreads() {
if (num_threads == kDefaultNumThreads) {
// This is a rough heuristic, replace with something better in the future.
num_threads = static_cast<size_t>(std::clamp(
static_cast<int>(std::thread::hardware_concurrency()) - 2, 1, 18));
}
}
public:
Args(int argc, char* argv[]) {
InitAndParse(argc, argv);
ChooseNumThreads();
}
// Returns error string or nullptr if OK.
const char* Validate() {
if (const char* err = ParseModelTypeAndWrapping(model_type_str, model_type_,
prompt_wrapping_)) {
return err;
}
if (const char* err = ParseType(weight_type_str, weight_type_)) {
return err;
}
if (weights.path.empty()) {
return "Missing --weights flag, a file for the uncompressed model.";
}
if (compressed_weights.path.empty()) {
return "Missing --compressed_weights flag, a file for the compressed "
"model.";
}
if (!weights.Exists()) {
return "Can't open file specified with --weights flag.";
}
return nullptr;
}
Path weights; // uncompressed weights file location
Path compressed_weights; // compressed weights file location
std::string model_type_str;
std::string weight_type_str;
size_t num_threads;
// If non-empty, whether to include the config and TOC in the output file, as
// well as the tokenizer.
Path tokenizer;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(weights, "weights", Path(),
"Path to model weights (.bin) file.\n"
" Required argument.");
visitor(model_type_str, "model", std::string(),
"Model type\n 2b-it = 2B parameters, instruction-tuned\n "
"2b-pt = 2B parameters, pretrained\n 7b-it = 7B parameters "
"instruction-tuned\n 7b-pt = 7B parameters, pretrained\n "
"gr2b-it = griffin 2B parameters, instruction-tuned\n "
"gr2b-pt = griffin 2B parameters, pretrained\n "
" Required argument.");
visitor(weight_type_str, "weight_type", std::string("sfp"),
"Weight type\n f32 = float, bf16 = bfloat16, SFP = 8-bit FP\n"
" Required argument.");
visitor(compressed_weights, "compressed_weights", Path(),
"Path name where compressed weights (.sbs) file will be written.\n"
" Required argument.");
visitor(num_threads, "num_threads",
kDefaultNumThreads, // see ChooseNumThreads
"Number of threads to use.\n Default = Estimate of the "
"number of supported concurrent threads.",
2);
visitor(tokenizer, "tokenizer", Path(),
"Path to tokenizer file. If given, the config and TOC are also "
"added to the output file.");
}
// Uninitialized before Validate, must call after that.
gcpp::Model ModelType() const { return model_type_; }
gcpp::PromptWrapping PromptWrappingType() const { return prompt_wrapping_; }
gcpp::Type WeightType() const { return weight_type_; }
private:
Model model_type_;
PromptWrapping prompt_wrapping_;
Type weight_type_;
};
void ShowHelp(gcpp::Args& args) {
std::cerr
<< "Usage:\n./compress_weights --weights <path to uncompressed weights> "
" --model <model type> --compressed_weights <output path>\n";
std::cerr << "\n*Arguments*\n\n";
args.Help();
std::cerr << "\n";
}
} // namespace gcpp
#endif // GEMMA_COMPRESS_WEIGHTS_ONCE
// SIMD code, compiled once per target.
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
template <typename T>
void CompressWeights(const Path& weights_path,
const Path& compressed_weights_path, Model model_type,
Type weight_type, PromptWrapping wrapping,
const Path& tokenizer_path, hwy::ThreadPool& pool) {
if (!weights_path.Exists()) {
HWY_ABORT("The model weights file '%s' does not exist.",
weights_path.path.c_str());
}
printf("Compressing weights from %s to %s\n", weights_path.path.c_str(),
compressed_weights_path.path.c_str());
ModelConfig config = ConfigFromModel(model_type);
config.weight = weight_type;
config.wrapping = wrapping;
std::vector<MatStorage> model_storage;
ModelWeightsPtrs<T> c_weights(config);
c_weights.Allocate(model_storage, pool);
ModelWeightsPtrs<float> uc_weights(config);
uc_weights.Allocate(model_storage, pool);
// Get uncompressed weights, compress, and store.
FILE* fptr = fopen(weights_path.path.c_str(), "rb");
if (fptr == nullptr) {
HWY_ABORT("Failed to open model file %s - does it exist?",
weights_path.path.c_str());
}
bool ok = true;
uint64_t total_size = 0;
ModelWeightsPtrs<float>::ForEachTensor(
{&uc_weights}, ForEachType::kLoadNoToc,
[&](const char* name, hwy::Span<MatPtr*> tensors) {
fprintf(stderr, "Loading Parameters (size %zu): %s\n",
tensors[0]->SizeBytes(), name);
ok &= 1 == fread(tensors[0]->Ptr(), tensors[0]->SizeBytes(), 1, fptr);
total_size += tensors[0]->SizeBytes();
});
if (!tokenizer_path.path.empty()) {
uc_weights.AllocAndCopyWithTranspose(pool, model_storage);
}
const bool scale_for_compression = config.num_tensor_scales > 0;
std::vector<float> scales;
if (scale_for_compression) {
uc_weights.GetOrApplyScales(scales);
}
Compressor compressor(pool);
ModelWeightsPtrs<T>::ForEachTensor(
{reinterpret_cast<ModelWeightsPtrs<T>*>(&uc_weights), &c_weights},
tokenizer_path.path.empty() ? ForEachType::kLoadNoToc
: ForEachType::kLoadWithToc,
[&compressor](const char* name, hwy::Span<MatPtr*> tensors) {
tensors[1]->CallUpcasted(
compressor, name,
reinterpret_cast<const float*>(tensors[0]->Ptr()));
});
if (!tokenizer_path.path.empty()) {
std::string tokenizer_proto = ReadFileToString(tokenizer_path);
compressor.AddTokenizer(tokenizer_proto);
} else {
compressor.AddScales(scales.data(), scales.size() * sizeof(scales[0]));
}
compressor.WriteAll(compressed_weights_path,
tokenizer_path.path.empty() ? nullptr : &config);
}
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#if HWY_ONCE
namespace gcpp {
void Run(Args& args) {
hwy::ThreadPool pool(args.num_threads);
if (args.PromptWrappingType() == PromptWrapping::PALIGEMMA) {
HWY_ABORT("PaliGemma is not supported in compress_weights.");
}
const Model model_type = args.ModelType();
const Type weight_type = args.WeightType();
switch (weight_type) {
case Type::kF32:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<float>)
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kBF16:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<BF16>)
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kSFP:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<SfpStream>)
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
case Type::kNUQ:
HWY_EXPORT_AND_DYNAMIC_DISPATCH_T(CompressWeights<NuqStream>)
(args.weights, args.compressed_weights, model_type, weight_type,
args.PromptWrappingType(), args.tokenizer, pool);
break;
default:
HWY_ABORT("Weight type %d unsupported.", static_cast<int>(weight_type));
}
}
} // namespace gcpp
int main(int argc, char** argv) {
gcpp::Args args(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
gcpp::ShowHelp(args);
return 0;
}
if (const char* error = args.Validate()) {
gcpp::ShowHelp(args);
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(args);
return 0;
}
#endif // HWY_ONCE

View File

@ -57,6 +57,6 @@ int main(int argc, char** argv) {
}
gcpp::GemmaEnv env(argc, argv);
hwy::ThreadPool pool(0);
env.GetModel()->Save(args.output_weights, pool);
env.GetGemma()->Save(args.output_weights, pool);
return 0;
}

View File

@ -16,7 +16,9 @@ cc_library(
deps = [
"@abseil-cpp//absl/types:span",
"//:common",
"//:mat",
"//:tokenizer",
"//:weights",
"//compression:compress",
"//compression:io",
"@highway//:hwy",
@ -30,7 +32,6 @@ pybind_extension(
deps = [
":compression_clif_aux",
"@abseil-cpp//absl/types:span",
"//:common",
"//compression:sfp",
],
)

View File

@ -22,7 +22,8 @@
#include "compression/compress.h"
#include "compression/shared.h"
#include "hwy/aligned_allocator.h"
#include "gemma/weights.h"
#include "util/mat.h"
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE \
@ -81,30 +82,23 @@ class SbsWriterImpl : public WriterInterface {
template <typename Packed>
void AllocateAndCompress(const std::string& name,
absl::Span<const float> weights) {
MatPtrT<Packed> storage(name, 1, weights.size());
model_memory_.push_back(storage);
model_memory_.back().Allocate();
storage.SetPtr(model_memory_.back());
std::string decorated_name = storage.CacheName();
MatPtrT<Packed> storage(name.c_str(), Extents2D(1, weights.size()));
model_memory_.push_back(MatOwner());
model_memory_.back().AllocateFor(storage, MatPadding::kPacked);
std::string decorated_name = CacheName(storage);
compressor_(&storage, decorated_name.c_str(), weights.data());
}
template <typename Packed>
void AllocateWithShape(const std::string& name,
absl::Span<const float> weights,
const TensorInfo& tensor_info, float scale) {
MatPtrT<Packed> storage(name, &tensor_info);
storage.set_scale(scale);
MatPtrT<Packed> storage(name.c_str(), &tensor_info);
storage.SetScale(scale);
// Don't reset num_elements for NUQ.
if (!hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {
storage.SetNumElements(CompressedArrayElements<Packed>(weights.size()));
}
model_memory_.push_back(storage);
model_memory_.push_back(MatOwner());
if (mode_ == CompressorMode::kTEST_ONLY) return;
model_memory_.back().Allocate();
storage.SetPtr(model_memory_.back());
std::string decorated_name = storage.CacheName();
model_memory_.back().AllocateFor(storage, MatPadding::kPacked);
std::string decorated_name = CacheName(storage);
compressor_(&storage, decorated_name.c_str(), weights.data());
}
@ -176,7 +170,7 @@ class SbsWriterImpl : public WriterInterface {
hwy::ThreadPool pool_;
Compressor compressor_;
CompressWorkingSet working_set_;
std::vector<MatStorage> model_memory_;
std::vector<MatOwner> model_memory_;
std::vector<float> scales_;
CompressorMode mode_;
};

View File

@ -201,15 +201,24 @@ inline bool EnumValid(PromptWrapping type) {
// Tensor types for loading weights. Note that not all types are supported as
// weights for a model, but can be used for other purposes, such as types for
// ModelWeightsPtrs. When adding a new type that is supported, also
// `WeightsPtrs`. When adding a new type that is supported, also
// update gemma.cc, weights.*, and add instantiations/new_one.cc.
enum class Type { kUnknown, kF32, kBF16, kSFP, kNUQ, kF64, kC64, kU128 };
constexpr const char* kTypeStrings[] = {"unknown", "f32", "bf16", "sfp",
"nuq", "f64", "c64", "u128"};
static constexpr const char* kTypeStrings[] = {
"unknown", "f32", "bf16", "sfp", "nuq", "f64", "c64", "u128"};
static constexpr size_t kNumTypes =
sizeof(kTypeStrings) / sizeof(kTypeStrings[0]);
static constexpr size_t kTypeBits[] = {0,
8 * sizeof(float),
8 * sizeof(BF16),
8 * sizeof(SfpStream),
4 /* NuqStream, actually 4.5 */,
8 * sizeof(double),
8 * sizeof(std::complex<double>),
8 * sizeof(hwy::uint128_t)};
inline bool EnumValid(Type type) {
return static_cast<int>(type) >= 0 &&
static_cast<int>(type) <= static_cast<int>(Type::kU128);
static inline bool EnumValid(Type type) {
return static_cast<size_t>(type) < kNumTypes;
}
// Returns a Type enum for the type of the template parameter.
@ -236,10 +245,16 @@ Type TypeEnum() {
}
}
// Returns a string name for the type of the template parameter.
static inline size_t TypeBits(Type type) {
return kTypeBits[static_cast<int>(type)];
}
static inline const char* TypeName(Type type) {
return kTypeStrings[static_cast<int>(type)];
}
template <typename PackedT>
const char* TypeName() {
return kTypeStrings[static_cast<int>(TypeEnum<PackedT>())];
return TypeName(TypeEnum<PackedT>());
}
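TypeBits(Type) is the new per-type bit-width lookup; note that kTypeBits records NUQ as 4 even though the comment marks it as really 4.5 bits per element. A sketch of a hypothetical helper built on it:

#include <stddef.h>

#include "compression/shared.h"  // Type, TypeBits
#include "hwy/base.h"            // DivCeil

namespace gcpp {

// Sketch: rough packed byte count of a rows x cols tensor from its Type.
// Approximate for NUQ (the table stores 4 bits, the true cost is about 4.5
// plus tables); exact sizes still come from NuqStream::PackedEnd.
static inline size_t ApproxPackedBytes(Type type, size_t rows, size_t cols) {
  return hwy::DivCeil(rows * cols * TypeBits(type), size_t{8});
}

}  // namespace gcpp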
template <typename Packed>
@ -248,7 +263,9 @@ constexpr bool IsCompressed() {
}
// Returns the number of `MatT` elements required to store `capacity` values,
// which must not be zero.
// which must not be zero. This is only intended to support the extra tables
// required for NUQ. `capacity` includes any padding and is `rows * stride`.
// Deprecated, replaced by fixup within `MatPtr`. Only used by tests.
template <typename Packed>
constexpr size_t CompressedArrayElements(size_t capacity) {
if constexpr (hwy::IsSame<hwy::RemoveCvRef<Packed>, NuqStream>()) {

View File

@ -18,10 +18,13 @@
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
// IWYU pragma: begin_exports
#include "compression/compress.h"
#include "compression/distortion.h"
#include "util/mat.h"
// IWYU pragma: end_exports
#include "compression/compress.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_TEST_UTIL_INL_H_
// Include guard for (potentially) SIMD code.
@ -62,6 +65,52 @@ void ForeachPackedAndRawType() {
ForeachRawType<NuqStream, TestT>();
}
// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>
MatStorageT<MatT> GenerateMat(const Extents2D& extents, hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
MatStorageT<MatT> compressed("mat", extents, MatPadding::kPacked);
const float scale = SfpStream::kMax / extents.Area();
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(r * extents.cols + c) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
row[c] = f;
}
});
Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
/*packed_ofs=*/0, pool);
compressed.SetScale(0.6f); // Arbitrary value, different from 1.
return compressed;
}
// `extents` describes the transposed matrix.
template <typename MatT>
MatStorageT<MatT> GenerateTransposedMat(const Extents2D extents,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
MatStorageT<float> raw("raw", extents, MatPadding::kPacked);
MatStorageT<MatT> compressed("trans", extents, MatPadding::kPacked);
const float scale = SfpStream::kMax / extents.Area();
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
float* HWY_RESTRICT row = raw.Row(r);
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(c * extents.rows + r) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
row[c] = f;
}
});
Compress(raw.Packed(), raw.Extents().Area(), ws, compressed.Span(),
/*packed_ofs=*/0, pool);
// Arbitrary value, different from 1, must match `GenerateMat`.
compressed.SetScale(0.6f);
return compressed;
}
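
A possible call site in a matmul test or benchmark, sketched under the assumptions that `MatStorageT` exposes a `Scale()` accessor (as used elsewhere in this change) and that the transposed extents describe B^T, as noted above:

```cpp
// Illustrative usage; dimensions are arbitrary (A is M x K, B^T is N x K).
void ExampleGenerate(hwy::ThreadPool& pool) {
  const Extents2D a_extents(64, 128);
  const Extents2D bt_extents(256, 128);
  MatStorageT<SfpStream> a = GenerateMat<SfpStream>(a_extents, pool);
  MatStorageT<SfpStream> b_trans =
      GenerateTransposedMat<SfpStream>(bt_extents, pool);
  HWY_ASSERT(a.Scale() == b_trans.Scale());  // both set to 0.6f above
}
```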
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp

View File

@ -128,10 +128,10 @@ int BenchmarkCrossEntropy(GemmaEnv& env, const Path& text,
size_t num_tokens = std::min<size_t>(prompt.size() - pos, batch_tokens);
std::vector<int> prompt_slice(prompt.begin() + pos,
prompt.begin() + pos + num_tokens);
KVCache kv_cache = KVCache::Create(env.GetModel()->GetModelConfig(),
KVCache kv_cache = KVCache::Create(env.GetGemma()->GetModelConfig(),
env.MutableConfig().prefill_tbatch_size);
float entropy = ComputeCrossEntropy(
*env.GetModel(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
*env.GetGemma(), num_tokens, prompt_slice, kv_cache, env.Verbosity());
total_entropy += entropy;
LogSpeedStats(time_start, pos + num_tokens);
std::string text_slice = env.StringFromTokens(prompt_slice);
@ -186,8 +186,8 @@ int main(int argc, char** argv) {
if (!benchmark_args.goldens.Empty()) {
const std::string golden_path =
benchmark_args.goldens.path + "/" +
gcpp::ModelString(env.GetModel()->Info().model,
env.GetModel()->Info().wrapping) +
gcpp::ModelString(env.GetGemma()->Info().model,
env.GetGemma()->Info().wrapping) +
".txt";
return BenchmarkGoldens(env, golden_path);
} else if (!benchmark_args.summarize_text.Empty()) {

View File

@ -18,27 +18,20 @@
#include <stdio.h>
#include <time.h>
#include <cstdio>
#include <iostream>
#include <memory>
#include <ostream>
#include <random>
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/compress.h" // TypeName
#include "compression/shared.h" // TypeName
#include "evals/cross_entropy.h"
#include "gemma/common.h" // StringFromType
#include "gemma/gemma.h"
#include "gemma/kv_cache.h"
#include "util/app.h"
#include "util/args.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/topology.h"
#include "gemma/gemma_args.h"
#include "ops/matmul.h" // MatMulEnv
#include "util/threading_context.h"
#include "hwy/highway.h"
#include "hwy/per_target.h" // VectorBytes
#include "hwy/per_target.h" // DispatchedTarget
#include "hwy/timer.h"
namespace gcpp {
@ -54,11 +47,9 @@ void InitGenerator(const InferenceArgs& inference, std::mt19937& gen) {
}
}
GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app)
: topology_(CreateTopology(app)),
pools_(CreatePools(topology_, app)),
env_(topology_, pools_) {
GemmaEnv::GemmaEnv(const ThreadingArgs& threading_args,
const LoaderArgs& loader, const InferenceArgs& inference)
: env_(MakeMatMulEnv(threading_args)) {
InferenceArgs mutable_inference = inference;
AbortIfInvalidArgs(mutable_inference);
LoaderArgs mutable_loader = loader;
@ -67,10 +58,10 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
fprintf(stderr, "Skipping model load because: %s\n", err);
} else {
fprintf(stderr, "Loading model...\n");
model_ = AllocateGemma(mutable_loader, env_);
gemma_ = AllocateGemma(mutable_loader, env_);
// Only allocate one for starters because GenerateBatch might not be called.
kv_caches_.resize(1);
kv_caches_[0] = KVCache::Create(model_->GetModelConfig(),
kv_caches_[0] = KVCache::Create(gemma_->GetModelConfig(),
inference.prefill_tbatch_size);
}
InitGenerator(inference, gen_);
@ -78,24 +69,13 @@ GemmaEnv::GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
.max_generated_tokens = inference.max_generated_tokens,
.temperature = inference.temperature,
.gen = &gen_,
.verbosity = app.verbosity,
.verbosity = inference.verbosity,
};
}
// Internal init must run before the GemmaEnv ctor above, hence it cannot occur
// in the argv ctor below because its body runs *after* the delegating ctor.
// This helper function takes care of the init, and could be applied to any of
// the *Args classes, it does not matter which.
static AppArgs MakeAppArgs(int argc, char** argv) {
{ // So that indentation matches expectations.
// Placeholder for internal init, do not modify.
}
return AppArgs(argc, argv);
}
GemmaEnv::GemmaEnv(int argc, char** argv)
: GemmaEnv(LoaderArgs(argc, argv), InferenceArgs(argc, argv),
MakeAppArgs(argc, argv)) {}
: GemmaEnv(ThreadingArgs(argc, argv), LoaderArgs(argc, argv),
InferenceArgs(argc, argv)) {}
QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
QueryResult result;
@ -117,7 +97,7 @@ QueryResult GemmaEnv::QueryModel(const std::vector<int>& tokens) {
}
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
runtime_config_.batch_stream_token = batch_stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info);
return result;
}
@ -127,7 +107,7 @@ void GemmaEnv::QueryModel(
gcpp::TimingInfo timing_info { .verbosity = runtime_config_.verbosity };
const StreamFunc previous_stream_token = runtime_config_.stream_token;
runtime_config_.stream_token = stream_token;
model_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
gemma_->Generate(runtime_config_, tokens, /*start_pos=*/0, kv_caches_[0],
timing_info);
runtime_config_.stream_token = previous_stream_token;
}
@ -142,7 +122,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
int token, float) {
std::string token_text;
HWY_ASSERT(
model_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
gemma_->Tokenizer().Decode(std::vector<int>{token}, &token_text));
res[query_index].response.append(token_text);
res[query_index].tokens_generated += 1;
if (res[query_index].tokens_generated ==
@ -164,7 +144,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
}
for (size_t i = 1; i < num_queries; ++i) {
if (kv_caches_[i].seq_len == 0) {
kv_caches_[i] = KVCache::Create(model_->GetModelConfig(),
kv_caches_[i] = KVCache::Create(gemma_->GetModelConfig(),
runtime_config_.prefill_tbatch_size);
}
}
@ -172,7 +152,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
gcpp::TimingInfo timing_info = {.verbosity = runtime_config_.verbosity};
runtime_config_.batch_stream_token = batch_stream_token;
std::vector<size_t> queries_pos(num_queries, 0);
model_->GenerateBatch(runtime_config_, queries_prompt,
gemma_->GenerateBatch(runtime_config_, queries_prompt,
QueriesPos(queries_pos.data(), num_queries),
KVCaches(&kv_caches_[0], num_queries), timing_info);
return res;
@ -203,7 +183,7 @@ std::vector<QueryResult> GemmaEnv::BatchQueryModel(
float GemmaEnv::CrossEntropy(const std::string& input) {
std::vector<int> prompt = Tokenize(input);
prompt.insert(prompt.begin(), BOS_ID);
return ComputeCrossEntropy(*GetModel(), /*max_generated_tokens=*/3072, prompt,
return ComputeCrossEntropy(*GetGemma(), /*max_generated_tokens=*/3072, prompt,
MutableKVCache(),
/*verbosity=*/0) /
static_cast<int>(input.size());
@ -236,17 +216,36 @@ std::string CacheString() {
return buf;
}
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
const BoundedTopology& topology, NestedPools& pools) {
loader.Print(app.verbosity);
inference.Print(app.verbosity);
app.Print(app.verbosity);
static constexpr const char* CompiledConfig() {
if constexpr (HWY_IS_ASAN) {
return "asan";
} else if constexpr (HWY_IS_MSAN) {
return "msan";
} else if constexpr (HWY_IS_TSAN) {
return "tsan";
} else if constexpr (HWY_IS_HWASAN) {
return "hwasan";
} else if constexpr (HWY_IS_UBSAN) {
return "ubsan";
} else if constexpr (HWY_IS_DEBUG_BUILD) {
return "dbg";
} else {
return "opt";
}
}
if (app.verbosity >= 2) {
void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
InferenceArgs& inference) {
threading.Print(inference.verbosity);
loader.Print(inference.verbosity);
inference.Print(inference.verbosity);
if (inference.verbosity >= 2) {
time_t now = time(nullptr);
char* dt = ctime(&now); // NOLINT
char cpu100[100] = "unknown";
(void)hwy::platform::GetCpuString(cpu100);
const ThreadingContext2& ctx = ThreadingContext2::Get();
fprintf(stderr,
"Date & Time : %s" // dt includes \n
@ -254,16 +253,18 @@ void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
"CPU topology : %s, %s, %s\n"
"Instruction set : %s (%zu bits)\n"
"Compiled config : %s\n"
"Weight Type : %s\n"
"EmbedderInput Type : %s\n",
dt, cpu100, topology.TopologyString(), pools.PinString(),
"Memory MiB : %4zu, %4zu free\n"
"Weight Type : %s\n",
dt, cpu100, ctx.topology.TopologyString(), ctx.pools.PinString(),
CacheString().c_str(), hwy::TargetName(hwy::DispatchedTarget()),
hwy::VectorBytes() * 8, CompiledConfig(),
StringFromType(loader.Info().weight), TypeName<EmbedderInputT>());
ctx.allocator.VectorBytes() * 8, CompiledConfig(),
ctx.allocator.TotalMiB(), ctx.allocator.FreeMiB(),
StringFromType(loader.Info().weight));
}
}
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader,
InferenceArgs& inference) {
std::cerr
<< "\n\ngemma.cpp : a lightweight, standalone C++ inference engine\n"
"==========================================================\n\n"
@ -272,16 +273,16 @@ void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
" --tokenizer\n"
" --weights\n"
" --model,\n"
" or with the newer weights format, specify just:\n"
" or with the single-file weights format, specify just:\n"
" --weights\n";
std::cerr << "\n*Example Usage*\n\n./gemma --tokenizer tokenizer.spm "
"--weights 2b-it-sfp.sbs --model 2b-it\n";
std::cerr << "\n*Threading Arguments*\n\n";
threading.Help();
std::cerr << "\n*Model Loading Arguments*\n\n";
loader.Help();
std::cerr << "\n*Inference Arguments*\n\n";
inference.Help();
std::cerr << "\n*Application Arguments*\n\n";
app.Help();
std::cerr << "\n";
}

View File

@ -24,9 +24,9 @@
#include <vector>
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "ops/matmul.h"
#include "util/app.h"
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/base.h"
namespace gcpp {
@ -46,8 +46,10 @@ class GemmaEnv {
public:
// Calls the other constructor with *Args arguments initialized from argv.
GemmaEnv(int argc, char** argv);
GemmaEnv(const LoaderArgs& loader, const InferenceArgs& inference,
const AppArgs& app);
GemmaEnv(const ThreadingArgs& threading, const LoaderArgs& loader,
const InferenceArgs& inference);
MatMulEnv& Env() { return env_; }
size_t MaxGeneratedTokens() const {
return runtime_config_.max_generated_tokens;
@ -58,7 +60,7 @@ class GemmaEnv {
std::vector<int> Tokenize(const std::string& input) const {
std::vector<int> tokens;
HWY_ASSERT(model_->Tokenizer().Encode(input, &tokens));
HWY_ASSERT(gemma_->Tokenizer().Encode(input, &tokens));
return tokens;
}
@ -69,13 +71,13 @@ class GemmaEnv {
}
std::vector<int> WrapAndTokenize(std::string& input) const {
return gcpp::WrapAndTokenize(model_->Tokenizer(), model_->ChatTemplate(),
model_->Info(), 0, input);
return gcpp::WrapAndTokenize(gemma_->Tokenizer(), gemma_->ChatTemplate(),
gemma_->Info(), 0, input);
}
std::string StringFromTokens(const std::vector<int>& tokens) const {
std::string string;
HWY_ASSERT(model_->Tokenizer().Decode(tokens, &string));
HWY_ASSERT(gemma_->Tokenizer().Decode(tokens, &string));
return string;
}
@ -99,7 +101,7 @@ class GemmaEnv {
float CrossEntropy(const std::string& input);
// Returns nullptr if the model failed to load.
Gemma* GetModel() const { return model_.get(); }
Gemma* GetGemma() const { return gemma_.get(); }
int Verbosity() const { return runtime_config_.verbosity; }
RuntimeConfig& MutableConfig() { return runtime_config_; }
@ -107,11 +109,9 @@ class GemmaEnv {
KVCache& MutableKVCache() { return kv_caches_[0]; }
private:
BoundedTopology topology_;
NestedPools pools_; // Thread pool.
MatMulEnv env_;
std::mt19937 gen_; // Random number generator.
std::unique_ptr<Gemma> model_;
std::unique_ptr<Gemma> gemma_;
std::vector<KVCache> kv_caches_; // Same number as query batch.
RuntimeConfig runtime_config_;
};
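
A hedged sketch of a frontend using the new three-way argument split; it mirrors the constructors above, with the header path assumed and error handling kept minimal:

```cpp
#include <cstdio>
#include <string>
#include <vector>
#include "evals/benchmark_helper.h"  // GemmaEnv; path assumed

int main(int argc, char** argv) {
  gcpp::ThreadingArgs threading(argc, argv);
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);
  gcpp::GemmaEnv env(threading, loader, inference);
  if (env.GetGemma() == nullptr) return 1;  // model failed to load
  std::string prompt = "Write a haiku about thread pools.";
  const std::vector<int> tokens = env.WrapAndTokenize(prompt);
  const gcpp::QueryResult result = env.QueryModel(tokens);
  fprintf(stderr, "%s\n", result.response.c_str());
  return 0;
}
```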
@ -119,9 +119,10 @@ class GemmaEnv {
// Logs the inference speed in tokens/sec.
void LogSpeedStats(double time_start, size_t total_tokens);
void ShowConfig(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app,
const BoundedTopology& topology, NestedPools& pools);
void ShowHelp(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app);
void ShowConfig(ThreadingArgs& threading, LoaderArgs& loader,
InferenceArgs& inference);
void ShowHelp(ThreadingArgs& threading, LoaderArgs& loader,
InferenceArgs& inference);
} // namespace gcpp

View File

@ -51,8 +51,8 @@ class GemmaTest : public ::testing::Test {
// Using the turn structure worsens results sometimes.
// However, some models need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
@ -76,7 +76,7 @@ class GemmaTest : public ::testing::Test {
}
void GenerateTokens(std::vector<std::string> &kQA, size_t num_questions) {
ASSERT_NE(s_env->GetModel(), nullptr);
ASSERT_NE(s_env->GetGemma(), nullptr);
std::vector<std::string> inputs;
for (size_t i = 0; i < num_questions; ++i) {

View File

@ -50,8 +50,8 @@ class GemmaTest : public ::testing::Test {
// Using the turn structure worsens results sometimes.
// However, some models need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
std::string mutable_prompt = prompt;
QueryResult result = s_env->QueryModel(mutable_prompt); // Uses turns.
return result.response;
@ -71,8 +71,8 @@ class GemmaTest : public ::testing::Test {
// Using the turn structure worsens results sometimes.
// However, some models need the turn structure to work.
// It would be good to make these tests more consistent.
if (s_env->GetModel()->Info().model == Model::GEMMA2_27B ||
s_env->GetModel()->Info().model == Model::GRIFFIN_2B) {
if (s_env->GetGemma()->Info().model == Model::GEMMA2_27B ||
s_env->GetGemma()->Info().model == Model::GRIFFIN_2B) {
for (QueryResult result : s_env->BatchQueryModel(inputs)) {
replies.push_back(result.response);
}
@ -96,7 +96,7 @@ class GemmaTest : public ::testing::Test {
}
void TestQuestions(const char* kQA[][2], size_t num_questions, bool batch) {
ASSERT_NE(s_env->GetModel(), nullptr);
HWY_ASSERT(s_env->GetGemma() != nullptr);
if (batch) {
std::vector<std::string> inputs;
for (size_t i = 0; i < num_questions; ++i) {
@ -155,8 +155,8 @@ TEST_F(GemmaTest, Arithmetic) {
}
TEST_F(GemmaTest, Multiturn) {
Gemma* model = s_env->GetModel();
ASSERT_NE(model, nullptr);
Gemma* model = s_env->GetGemma();
HWY_ASSERT(model != nullptr);
size_t abs_pos = 0;
std::string response;
auto stream_token = [&](int token, float) {
@ -239,12 +239,12 @@ static const char kGettysburg[] = {
"people, for the people, shall not perish from the earth.\n"};
TEST_F(GemmaTest, CrossEntropySmall) {
ASSERT_NE(s_env->GetModel(), nullptr);
HWY_ASSERT(s_env->GetGemma() != nullptr);
static const char kSmall[] =
"The capital of Hungary is Budapest which is located in Europe.";
float entropy = s_env->CrossEntropy(kSmall);
fprintf(stderr, "per-token entropy: %f\n", entropy);
switch (s_env->GetModel()->Info().model) {
switch (s_env->GetGemma()->Info().model) {
case gcpp::Model::GEMMA_2B:
// 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 2.6f, 0.2f);
@ -272,10 +272,10 @@ TEST_F(GemmaTest, CrossEntropySmall) {
}
TEST_F(GemmaTest, CrossEntropyJingleBells) {
ASSERT_NE(s_env->GetModel(), nullptr);
HWY_ASSERT(s_env->GetGemma() != nullptr);
float entropy = s_env->CrossEntropy(kJingleBells);
fprintf(stderr, "per-token entropy: %f\n", entropy);
switch (s_env->GetModel()->Info().model) {
switch (s_env->GetGemma()->Info().model) {
case gcpp::Model::GEMMA_2B:
// 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.9f, 0.2f);
@ -303,10 +303,10 @@ TEST_F(GemmaTest, CrossEntropyJingleBells) {
}
TEST_F(GemmaTest, CrossEntropyGettysburg) {
ASSERT_NE(s_env->GetModel(), nullptr);
HWY_ASSERT(s_env->GetGemma() != nullptr);
float entropy = s_env->CrossEntropy(kGettysburg);
fprintf(stderr, "per-token entropy: %f\n", entropy);
switch (s_env->GetModel()->Info().model) {
switch (s_env->GetGemma()->Info().model) {
case gcpp::Model::GEMMA_2B:
// 2B v.1 and v.1.1 produce slightly different results.
EXPECT_NEAR(entropy, 1.1f, 0.1f);

View File

@ -89,7 +89,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
"A", "B", "C", "D", //
" A", " B", " C", " D", //
"**", "**:", ":**", "The", "Answer", "is", ":", "."};
const TokenSet accept_set(env.GetModel()->Tokenizer(), accept_strings);
const TokenSet accept_set(env.GetGemma()->Tokenizer(), accept_strings);
for (auto sample : json_data["samples"]) {
const int id = sample["i"];
@ -131,7 +131,7 @@ void Run(GemmaEnv& env, JsonArgs& json) {
.verbosity = env.Verbosity(),
.stream_token = stream_token,
};
env.GetModel()->Generate(runtime_config, prompt, /*pos=*/0,
env.GetGemma()->Generate(runtime_config, prompt, /*pos=*/0,
env.MutableKVCache(), timing_info);
std::string output_string = env.StringFromTokens(predicted_token_ids);

View File

@ -10,13 +10,11 @@ cc_binary(
name = "hello_world",
srcs = ["run.cc"],
deps = [
# Placeholder for internal dep, do not remove.,
"//:app",
"//:args",
"//:gemma_args",
"//:gemma_lib",
"//:threading",
"//:threading_context",
"//:tokenizer",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

View File

@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
example:
```sh
./hello_world --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
./hello_world --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
```
Should print a greeting to the terminal:

View File

@ -23,23 +23,17 @@
#include <string>
#include <vector>
// Placeholder for internal header, do not modify.
#include "gemma/gemma.h"
#include "gemma/gemma_args.h" // LoaderArgs
#include "gemma/tokenizer.h"
#include "util/app.h" // LoaderArgs
#include "util/args.h"
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
gcpp::ThreadingArgs threading(argc, argv);
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
loader.Help();
return 0;
@ -53,14 +47,14 @@ int main(int argc, char** argv) {
for (int arg = 0; arg < argc; ++arg) {
// Find a --reject flag and consume everything after it.
if (strcmp(argv[arg], "--reject") == 0) {
while (++arg < argc) reject_tokens.insert(atoi(argv[arg]));
while (++arg < argc) {
reject_tokens.insert(atoi(argv[arg])); // NOLINT
}
}
}
// Instantiate model and KV Cache
gcpp::BoundedTopology topology(gcpp::CreateTopology(app));
gcpp::NestedPools pools = gcpp::CreatePools(topology, app);
gcpp::MatMulEnv env(topology, pools);
gcpp::MatMulEnv env(MakeMatMulEnv(threading));
gcpp::Gemma model = gcpp::CreateGemma(loader, env);
gcpp::KVCache kv_cache =
gcpp::KVCache::Create(model.GetModelConfig(),

View File

@ -10,10 +10,10 @@ cc_library(
name = "gemma",
hdrs = ["gemma.hpp"],
deps = [
"//:app",
"//:gemma_args",
"//:gemma_lib",
"//:ops",
"//:threading",
"//:threading_context",
"//:tokenizer",
"@highway//:hwy",
],
@ -24,15 +24,6 @@ cc_binary(
srcs = ["run.cc"],
deps = [
":gemma",
# Placeholder for internal dep, do not remove.,
"//:app",
"//:args",
"//:common",
"//:gemma_lib",
"//:ops",
"//:threading",
"//:tokenizer",
"@highway//:hwy",
"@highway//:thread_pool",
"//:gemma_args",
],
)

View File

@ -41,7 +41,7 @@ gemma.cpp specifying the tokenizer, compressed weights file, and model type, for
example:
```sh
./simplified_gemma --tokenizer tokenizer.spm --compressed_weights 2b-it-sfp.sbs --model 2b-it
./simplified_gemma --tokenizer tokenizer.spm --weights 2b-it-sfp.sbs --model 2b-it
```
Should print a greeting to the terminal:

View File

@ -24,39 +24,22 @@
#include <vector>
#include "third_party/gemma_cpp/gemma/gemma.h"
#include "third_party/gemma_cpp/gemma/gemma_args.h" // LoaderArgs
#include "third_party/gemma_cpp/gemma/tokenizer.h"
#include "third_party/gemma_cpp/ops/matmul.h"
#include "third_party/gemma_cpp/util/app.h" // LoaderArgs
#include "third_party/gemma_cpp/util/threading.h"
#include "third_party/gemma_cpp/util/threading_context.h"
#include "third_party/highway/hwy/base.h"
class SimplifiedGemma {
public:
SimplifiedGemma(const gcpp::LoaderArgs& loader,
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs(),
const gcpp::AppArgs& app = gcpp::AppArgs())
const gcpp::ThreadingArgs& threading = gcpp::ThreadingArgs(),
const gcpp::InferenceArgs& inference = gcpp::InferenceArgs())
: loader_(loader),
threading_(threading),
inference_(inference),
app_(app),
topology_(gcpp::CreateTopology(app_)),
pools_(gcpp::CreatePools(topology_, app_)),
env_(topology_, pools_),
env_(MakeMatMulEnv(threading_)),
model_(gcpp::CreateGemma(loader_, env_)) {
Init();
}
SimplifiedGemma(int argc, char** argv)
: loader_(argc, argv, /*validate=*/true),
inference_(argc, argv),
app_(argc, argv),
topology_(gcpp::CreateTopology(app_)),
pools_(gcpp::CreatePools(topology_, app_)),
env_(topology_, pools_),
model_(gcpp::CreateGemma(loader_, env_)) {
Init();
}
void Init() {
// Instantiate model and KV Cache
kv_cache_ = gcpp::KVCache::Create(model_.GetModelConfig(),
inference_.prefill_tbatch_size);
@ -66,6 +49,11 @@ class SimplifiedGemma {
gen_.seed(rd());
}
SimplifiedGemma(int argc, char** argv)
: SimplifiedGemma(gcpp::LoaderArgs(argc, argv, /*validate=*/true),
gcpp::ThreadingArgs(argc, argv),
gcpp::InferenceArgs(argc, argv)) {}
void Generate(std::string& prompt, size_t max_generated_tokens = 1024,
float temperature = 0.7,
const std::set<int>& reject_tokens = {}) {
@ -107,10 +95,8 @@ class SimplifiedGemma {
private:
gcpp::LoaderArgs loader_;
gcpp::ThreadingArgs threading_;
gcpp::InferenceArgs inference_;
gcpp::AppArgs app_;
gcpp::BoundedTopology topology_;
gcpp::NestedPools pools_;
gcpp::MatMulEnv env_;
gcpp::Gemma model_;
gcpp::KVCache kv_cache_;

View File

@ -17,15 +17,10 @@
#include <string>
// Placeholder for internal header, do not modify.
#include "third_party/gemma_cpp/examples/simplified_gemma/gemma.hpp"
#include "util/app.h" // LoaderArgs
#include "gemma/gemma_args.h" // LoaderArgs
int main(int argc, char** argv) {
{
// Placeholder for internal init, do not modify.
}
// Standard usage: LoaderArgs takes argc and argv as input, then parses
// necessary flags.
gcpp::LoaderArgs loader(argc, argv, /*validate=*/true);
@ -35,12 +30,12 @@ int main(int argc, char** argv) {
// gcpp::LoaderArgs loader("/path/to/tokenizer", "/path/to/weights",
// "model_identifier");
// Optional: InferenceArgs and AppArgs can be passed in as well. If not
// Optional: ThreadingArgs and InferenceArgs can be passed in as well. If not
// specified, default values will be used.
//
// gcpp::InferenceArgs inference(argc, argv);
// gcpp::AppArgs app(argc, argv);
// SimplifiedGemma gemma(loader, inference, app);
// gcpp::ThreadingArgs threading(argc, argv);
// SimplifiedGemma gemma(loader, threading, inference);
SimplifiedGemma gemma(loader);
std::string prompt = "Write a greeting to the world.";

View File

@ -18,14 +18,12 @@
#include <stddef.h>
#include "compression/shared.h" // BF16
#include "gemma/configs.h"
#include "ops/matmul.h" // MatMulEnv
#include "ops/ops.h" // CreateInvTimescale
#include "util/allocator.h" // RowVectorBatch
#include "util/threading.h"
#include "hwy/base.h" // HWY_DASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "gemma/configs.h" // ModelConfig
#include "ops/matmul.h" // MatMulEnv
#include "ops/ops.h" // CreateInvTimescale
#include "util/allocator.h" // Allocator
#include "util/basics.h" // BF16
#include "util/mat.h" // RowVectorBatch
namespace gcpp {
@ -74,6 +72,8 @@ struct Activations {
size_t cache_pos_size = 0;
void Allocate(size_t batch_size, MatMulEnv* env) {
const Allocator2& allocator = env->ctx.allocator;
post_qk = layer_config.post_qk;
const size_t model_dim = weights_config.model_dim;
const size_t ff_hidden_dim = layer_config.ff_hidden_dim;
@ -81,36 +81,45 @@ struct Activations {
const size_t qkv_dim = layer_config.qkv_dim;
const size_t heads = layer_config.heads;
x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
x = RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
q = RowVectorBatch<float>(
Extents2D(batch_size, heads * layer_config.QStride()));
allocator, Extents2D(batch_size, heads * layer_config.QStride()));
if (vocab_size > 0) {
logits = RowVectorBatch<float>(Extents2D(batch_size, vocab_size));
logits =
RowVectorBatch<float>(allocator, Extents2D(batch_size, vocab_size));
}
pre_att_rms_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
pre_att_rms_out =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
att = RowVectorBatch<float>(
Extents2D(batch_size, heads * weights_config.seq_len));
att_out = RowVectorBatch<float>(Extents2D(batch_size, heads * qkv_dim));
att_sums = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
allocator, Extents2D(batch_size, heads * weights_config.seq_len));
att_out = RowVectorBatch<float>(allocator,
Extents2D(batch_size, heads * qkv_dim));
att_sums =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
bf_pre_ffw_rms_out = RowVectorBatch<BF16>(Extents2D(batch_size, model_dim));
C1 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
C2 = RowVectorBatch<float>(Extents2D(batch_size, ff_hidden_dim));
ffw_out = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
bf_pre_ffw_rms_out =
RowVectorBatch<BF16>(allocator, Extents2D(batch_size, model_dim));
C1 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
C2 = RowVectorBatch<float>(allocator, Extents2D(batch_size, ff_hidden_dim));
ffw_out =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
if (layer_config.type == LayerAttentionType::kGriffinRecurrentBlock) {
griffin_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
griffin_y = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
griffin_gate_x = RowVectorBatch<float>(Extents2D(batch_size, model_dim));
griffin_x =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_y =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_gate_x =
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
griffin_multiplier =
RowVectorBatch<float>(Extents2D(batch_size, model_dim));
RowVectorBatch<float>(allocator, Extents2D(batch_size, model_dim));
}
inv_timescale = CreateInvTimescale(layer_config.qkv_dim,
inv_timescale = CreateInvTimescale(allocator, layer_config.qkv_dim,
post_qk == PostQKType::HalfRope);
inv_timescale_global =
CreateInvTimescale(qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
inv_timescale_global = CreateInvTimescale(
allocator, qkv_dim, post_qk == PostQKType::HalfRope, 1000000.0);
this->env = env;
}
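
The recurring change in this hunk is mechanical: fetch the context's allocator once, then pass it explicitly to every `RowVectorBatch`. A condensed sketch of the pattern (names as in the code above):

```cpp
// Allocator is passed explicitly; Batch() gives row-based access that
// respects any stride/padding.
void AllocateScratch(MatMulEnv* env, size_t batch_size, size_t model_dim) {
  const Allocator2& allocator = env->ctx.allocator;
  RowVectorBatch<float> scratch(allocator, Extents2D(batch_size, model_dim));
  float* HWY_RESTRICT row0 = scratch.Batch(0);
  (void)row0;  // e.g. fill or hand to MatMul
}
```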

View File

@ -17,13 +17,13 @@
#include <math.h> // sqrtf
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <algorithm> // std::min
#include <cstdio>
#include <vector>
#include "compression/compress.h"
#include "gemma/activations.h"
#include "gemma/common.h"
#include "gemma/configs.h"
@ -32,11 +32,9 @@
#include "gemma/weights.h"
#include "paligemma/image.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/bit_set.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
@ -82,7 +80,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
const KVCaches& kv_caches) {
PROFILER_ZONE("Gen.Griffin");
KVCache& kv_cache = kv_caches[0];
hwy::ThreadPool& pool = activations.env->parallel.Pools().Pool(0);
hwy::ThreadPool& pool = activations.env->ctx.pools.Pool(0);
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
const size_t model_dim = layer_weights->layer_config.model_dim;
@ -96,8 +94,8 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
TwoMatVecAdd(layer_weights->griffin.linear_x_w,
layer_weights->griffin.linear_y_w, 0, model_dim, model_dim,
activations.pre_att_rms_out.Batch(batch_idx),
/*add0=*/layer_weights->griffin.linear_x_biases.data_scale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.data_scale1(),
/*add0=*/layer_weights->griffin.linear_x_biases.PackedScale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.PackedScale1(),
/*out0=*/x, /*out1=*/y, pool);
Gelu(y, model_dim);
}
@ -121,15 +119,15 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
for (size_t i = 0; i < model_dim; i += hn::Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.data_scale1() + i);
hn::Load(df, layer_weights->griffin.conv_biases.PackedScale1() + i);
auto accum1 = hn::Zero(df);
HWY_ASSERT_M(conv_1d_width % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < conv_1d_width; l++) {
auto wv0 =
hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 1 - 2 * l) * model_dim + i);
auto wv1 =
hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
hn::Load(df, layer_weights->griffin.conv_w.PackedScale1() +
(conv_1d_width - 2 - 2 * l) * model_dim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
@ -156,9 +154,9 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
TwoOfsMatVecAddLoop(
layer_weights->griffin.gate_w, kMatrixSize * head,
kMatrixSize * (heads + head), kHeadDim, kHeadDim, x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.data_scale1() +
/*add0=*/layer_weights->griffin.gate_biases.PackedScale1() +
head_offset,
/*add1=*/layer_weights->griffin.gate_biases.data_scale1() +
/*add1=*/layer_weights->griffin.gate_biases.PackedScale1() +
model_dim + head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim);
@ -166,7 +164,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.data_scale1() + head_offset,
layer_weights->griffin.a.PackedScale1() + head_offset,
fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
@ -198,7 +196,7 @@ HWY_NOINLINE void GriffinRecurrent(size_t batch_start, size_t num_tokens,
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
float* out_ptr = activations.att_sums.Batch(batch_idx);
MatVecAdd(layer_weights->griffin.linear_out_w, 0, model_dim, model_dim, x,
layer_weights->griffin.linear_out_biases.data_scale1(), out_ptr,
layer_weights->griffin.linear_out_biases.PackedScale1(), out_ptr,
pool);
}
}
@ -253,7 +251,7 @@ class GemmaAttention {
const auto pre_att_rms_out =
ConstMatFromBatch(num_interleaved, activations_.pre_att_rms_out);
auto w_q1 = layer_weights_.qkv_einsum_w.data()
auto w_q1 = layer_weights_.qkv_einsum_w.HasPtr()
? ConstMatFromWeights(layer_weights_.qkv_einsum_w)
: ConstMatFromWeights(layer_weights_.qkv_einsum_w1);
// The original qkv_einsum_w has shape [(heads + kv_heads * 2), kKQVDim,
@ -265,15 +263,20 @@ class GemmaAttention {
const size_t w1_rows = heads * layer_config_.QStride();
w_q1.ShrinkRows(w1_rows);
MatMul(pre_att_rms_out, w_q1,
/*add=*/nullptr, *activations_.env, RowPtrFromBatch(activations_.q));
/*add=*/nullptr, *activations_.env,
RowPtrFromBatch(allocator_, activations_.q));
if (is_mha_) {
// Multi-Head Attention a.k.a. "use_qkv_einsum" computed QKV already.
} else {
auto w_q2 = layer_weights_.qkv_einsum_w.data()
? ConstMatFromWeights(layer_weights_.qkv_einsum_w,
w1_rows * model_dim)
: ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
decltype(w_q1) w_q2;
if (layer_weights_.qkv_einsum_w.HasPtr()) {
w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w);
// Skip first half of the matrix.
w_q2.ofs = w_q2.Row(w1_rows);
} else {
w_q2 = ConstMatFromWeights(layer_weights_.qkv_einsum_w2);
}
// KV structure is [k, v, k, v, ....] = kv_heads pairs of (k, v).
const size_t w_rows_kv_cols = kv_heads * 2 * qkv_dim;
w_q2.ShrinkRows(w_rows_kv_cols);
@ -285,7 +288,7 @@ class GemmaAttention {
const size_t kv_ofs =
queries_pos_[0] * cache_pos_size_ + layer_ * cache_layer_size_;
float* HWY_RESTRICT kv = kv_caches_[0].kv_cache.get() + kv_ofs;
RowPtrF kv_rows(kv, w_rows_kv_cols);
RowPtrF kv_rows(allocator_, kv, w_rows_kv_cols);
kv_rows.SetStride(cache_pos_size_);
MatMul(pre_att_rms_out, w_q2,
/*add=*/nullptr, *activations_.env, kv_rows);
@ -302,7 +305,7 @@ class GemmaAttention {
const size_t kv_offset =
cache_pos * cache_pos_size_ + layer_ * cache_layer_size_;
float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
if (layer_weights_.qkv_einsum_w.data()) {
if (layer_weights_.qkv_einsum_w.HasPtr()) {
MatVec(layer_weights_.qkv_einsum_w, heads * qkv_dim * model_dim,
w_rows_kv_cols, model_dim, x, kv, pool_);
} else {
@ -336,8 +339,8 @@ class GemmaAttention {
}
// Apply further processing to K.
if (layer_weights_.key_norm_scale.data()) {
RMSNormInplace(layer_weights_.key_norm_scale.data(), kv,
if (layer_weights_.key_norm_scale.HasPtr()) {
RMSNormInplace(layer_weights_.key_norm_scale.Row(0), kv,
qkv_dim);
}
PositionalEncodingQK(kv, pos, layer_, /*mul=*/1.0f);
@ -427,8 +430,8 @@ class GemmaAttention {
// Apply rope and scaling to Q.
const size_t pos = queries_pos_[query_idx] + batch_idx;
if (layer_weights_.query_norm_scale.data()) {
RMSNormInplace(layer_weights_.query_norm_scale.data(), q,
if (layer_weights_.query_norm_scale.HasPtr()) {
RMSNormInplace(layer_weights_.query_norm_scale.Row(0), q,
qkv_dim);
}
PositionalEncodingQK(q, pos, layer_, query_scale);
@ -473,17 +476,18 @@ class GemmaAttention {
HWY_DASSERT(layer_config_.model_dim > 0);
HWY_DASSERT(layer_config_.heads > 0);
HWY_DASSERT(layer_config_.qkv_dim > 0);
HWY_DASSERT(layer_weights_.att_weights.data() != nullptr);
HWY_DASSERT(layer_weights_.att_weights.HasPtr());
HWY_DASSERT(activations_.att_out.All() != nullptr);
HWY_DASSERT(activations_.att_sums.All() != nullptr);
const float* add =
layer_weights_.layer_config.softmax_attn_output_biases
? layer_weights_.attention_output_biases.data_scale1()
? layer_weights_.attention_output_biases.PackedScale1()
: nullptr;
MatMul(ConstMatFromBatch(num_interleaved, activations_.att_out),
ConstMatFromWeights(layer_weights_.att_weights), add,
*activations_.env, RowPtrFromBatch(activations_.att_sums));
*activations_.env,
RowPtrFromBatch(allocator_, activations_.att_sums));
}
public:
@ -533,7 +537,8 @@ class GemmaAttention {
layer_weights_(*layer_weights),
div_seq_len_(div_seq_len),
kv_caches_(kv_caches),
pool_(activations.env->parallel.Pools().Pool(0)) {
allocator_(activations.env->ctx.allocator),
pool_(activations.env->ctx.pools.Pool(0)) {
HWY_DASSERT(num_queries_ <= kv_caches_.size());
HWY_DASSERT_M((layer_config_.heads % layer_config_.kv_heads) == 0,
"query heads must be a multiple of key-value heads");
@ -562,6 +567,7 @@ class GemmaAttention {
const LayerWeightsPtrs<T>& layer_weights_;
const hwy::Divisor& div_seq_len_;
const KVCaches& kv_caches_;
const Allocator2& allocator_;
hwy::ThreadPool& pool_;
};
@ -606,8 +612,8 @@ class VitAttention {
HWY_ASSERT(qkv.Cols() == layer_config_.heads * 3 * layer_config_.qkv_dim);
MatMul(ConstMatFromBatch(num_tokens_, activations_.pre_att_rms_out),
ConstMatFromWeights(layer_weights_.vit.qkv_einsum_w),
layer_weights_.vit.qkv_einsum_b.data_scale1(), *activations_.env,
RowPtrFromBatch(qkv));
layer_weights_.vit.qkv_einsum_b.PackedScale1(), *activations_.env,
RowPtrFromBatch(allocator_, qkv));
}
// TODO(philculliton): transition fully to MatMul.
@ -621,10 +627,10 @@ class VitAttention {
// Shift Q, K, VT to RowVectorBatches with AllocateAlignedRows(extents)
RowVectorBatch<float> Q =
AllocateAlignedRows<float>(Extents2D(num_tokens_, qkv_dim));
AllocateAlignedRows<float>(allocator_, Extents2D(num_tokens_, qkv_dim));
RowVectorBatch<float> K =
AllocateAlignedRows<float>(Extents2D(seq_len, qkv_dim));
RowVectorBatch<float> C(Extents2D(num_tokens_, seq_len));
AllocateAlignedRows<float>(allocator_, Extents2D(seq_len, qkv_dim));
RowVectorBatch<float> C(allocator_, Extents2D(num_tokens_, seq_len));
// Initialize att_out to zero prior to head loop.
hwy::ZeroBytes(activations_.att_out.All(),
@ -650,7 +656,7 @@ class VitAttention {
// this produces C, a (num_tokens_, seq_len) matrix of dot products
MatMul(ConstMatFromBatch(Q.BatchSize(), Q),
ConstMatFromBatch(K.BatchSize(), K), nullptr, *activations_.env,
RowPtrFromBatch(C));
RowPtrFromBatch(allocator_, C));
pool_.Run(0, num_tokens_, [&](uint64_t task, size_t /*thread*/) HWY_ATTR {
float* HWY_RESTRICT c = C.Batch(task);
@ -712,13 +718,13 @@ class VitAttention {
// head_dim (`qkv_dim`) into output (`att_sums`).
HWY_NOINLINE void SumHeads() {
PROFILER_ZONE("Gen.VitAttention.SumHeads");
auto* bias = layer_weights_.vit.attn_out_b.data_scale1();
auto* bias = layer_weights_.vit.attn_out_b.PackedScale1();
// att_weights and att_out are concatenated heads, each of length
// qkv_dim. Thus the [num_tokens_, layer_config_.model_dim]
// matmul output is the sum over heads.
auto att_out = ConstMatFromBatch(num_tokens_, activations_.att_out);
auto att_weights = ConstMatFromWeights(layer_weights_.vit.attn_out_w);
auto att_sums = RowPtrFromBatch(activations_.att_sums);
auto att_sums = RowPtrFromBatch(allocator_, activations_.att_sums);
MatMul(att_out, att_weights, bias, *activations_.env, att_sums);
}
@ -730,7 +736,8 @@ class VitAttention {
activations_(activations),
layer_weights_(*layer_weights),
layer_config_(layer_weights->layer_config),
pool_(activations.env->parallel.Pools().Pool(0)) {}
allocator_(activations.env->ctx.allocator),
pool_(activations.env->ctx.pools.Pool(0)) {}
HWY_INLINE void operator()() {
ComputeQKV();
@ -748,6 +755,7 @@ class VitAttention {
Activations& activations_;
const LayerWeightsPtrs<T>& layer_weights_;
const LayerConfig& layer_config_;
const Allocator2& allocator_;
hwy::ThreadPool& pool_;
};
@ -779,32 +787,35 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
const bool add_bias = layer_weights->layer_config.ff_biases;
const float* bias1 =
add_bias ? layer_weights->ffw_gating_biases.data_scale1() : nullptr;
add_bias ? layer_weights->ffw_gating_biases.PackedScale1() : nullptr;
const float* bias2 = add_bias ? bias1 + ffh_hidden_dim : nullptr;
const float* output_bias =
add_bias ? layer_weights->ffw_output_biases.data_scale1() : nullptr;
add_bias ? layer_weights->ffw_output_biases.PackedScale1() : nullptr;
// Define slightly more readable names for the weights and activations.
const auto x =
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
auto hidden_activations = RowPtrFromBatch(activations.C1);
auto multiplier = RowPtrFromBatch(activations.C2);
auto ffw_out = RowPtrFromBatch(activations.ffw_out);
const Allocator2& allocator = activations.env->ctx.allocator;
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
auto multiplier = RowPtrFromBatch(allocator, activations.C2);
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
// gating_einsum_w holds two half-matrices. We plan to change the importer to
// avoid this confusion by splitting into gating_einsum_w1 and
// gating_einsum_w2.
const bool split = !!layer_weights->gating_einsum_w.data();
const bool split = layer_weights->gating_einsum_w.HasPtr();
auto w1 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w)
: ConstMatFromWeights(layer_weights->gating_einsum_w1);
auto w2 = split ? ConstMatFromWeights(layer_weights->gating_einsum_w,
model_dim * ffh_hidden_dim)
: ConstMatFromWeights(layer_weights->gating_einsum_w2);
decltype(w1) w2;
if (split) {
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w);
w2.ofs = w2.Row(ffh_hidden_dim);
// Ensure that B.Extents().rows matches C.Cols() because MatMul checks that.
w1.ShrinkRows(ffh_hidden_dim);
w2.ShrinkRows(ffh_hidden_dim);
} else {
w2 = ConstMatFromWeights(layer_weights->gating_einsum_w2);
}
auto w_output = ConstMatFromWeights(layer_weights->linear_w);
@ -835,16 +846,17 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
const bool add_bias = layer_weights->layer_config.ff_biases;
const float* bias1 =
add_bias ? layer_weights->vit.linear_0_b.data_scale1() : nullptr;
add_bias ? layer_weights->vit.linear_0_b.PackedScale1() : nullptr;
const float* output_bias =
add_bias ? layer_weights->vit.linear_1_b.data_scale1() : nullptr;
add_bias ? layer_weights->vit.linear_1_b.PackedScale1() : nullptr;
// Define slightly more readable names for the weights and activations.
const auto x =
ConstMatFromBatch(num_interleaved, activations.bf_pre_ffw_rms_out);
auto hidden_activations = RowPtrFromBatch(activations.C1);
auto ffw_out = RowPtrFromBatch(activations.ffw_out);
const Allocator2& allocator = activations.env->ctx.allocator;
auto hidden_activations = RowPtrFromBatch(allocator, activations.C1);
auto ffw_out = RowPtrFromBatch(allocator, activations.ffw_out);
auto w1 = ConstMatFromWeights(layer_weights->vit.linear_0_w);
auto w_output = ConstMatFromWeights(layer_weights->vit.linear_1_w);
@ -853,7 +865,7 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
MatMul(x, w1, bias1, *activations.env, hidden_activations);
// Activation (Gelu), store in act.
RowPtrF multiplier = RowPtrF(nullptr, 0);
RowPtrF multiplier = RowPtrF(allocator, nullptr, 0);
Activation(layer_weights->layer_config.activation, hidden_activations.Row(0),
multiplier.Row(0), ff_hidden_dim * num_interleaved);
@ -905,11 +917,9 @@ HWY_NOINLINE void EmbedMMToken(int token, size_t batch_idx, size_t pos,
HWY_DASSERT(token < static_cast<int>(vocab_size));
const hn::ScalableTag<float> df;
DecompressAndZeroPad(
df,
MakeSpan(weights.embedder_input_embedding.data(), vocab_size * model_dim),
token * model_dim, x.Batch(batch_idx), model_dim);
MulByConst(emb_scaling * weights.embedder_input_embedding.scale(),
DecompressAndZeroPad(df, weights.embedder_input_embedding.Span(),
token * model_dim, x.Batch(batch_idx), model_dim);
MulByConst(emb_scaling * weights.embedder_input_embedding.Scale(),
x.Batch(batch_idx), model_dim);
if (weights.weights_config.absolute_pe) {
AddAbsolutePositionalEmbeddings(x.Batch(batch_idx), model_dim, pos);
@ -943,9 +953,10 @@ HWY_NOINLINE void ResidualConnection(
template <typename WeightT, typename InOutT>
void PostNorm(PostNormType post_norm, size_t num_interleaved,
const WeightT& weights, InOutT* inout) {
HWY_DASSERT(weights.Rows() == 1);
if (post_norm == PostNormType::Scale) {
RMSNormInplaceBatched(num_interleaved, weights.data_scale1(), inout,
weights.NumElements());
RMSNormInplaceBatched(num_interleaved, weights.PackedScale1(), inout,
weights.Cols());
}
}
@ -962,7 +973,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
auto type = layer_weights->layer_config.type;
RMSNormBatched(num_interleaved, activations.x.All(),
layer_weights->pre_attention_norm_scale.data_scale1(),
layer_weights->pre_attention_norm_scale.PackedScale1(),
activations.pre_att_rms_out.All(), model_dim);
Attention(type, queries_pos, queries_prefix_end, num_tokens, cache_layer_idx,
@ -976,7 +987,7 @@ HWY_NOINLINE void TransformerLayer(const QueriesPos& queries_pos,
activations.x.All(), layer_weights, /*is_attention=*/true);
RMSNormBatched(num_interleaved, activations.x.All(),
layer_weights->pre_ffw_norm_scale.data_scale1(),
layer_weights->pre_ffw_norm_scale.PackedScale1(),
activations.bf_pre_ffw_rms_out.All(), model_dim);
if (layer_weights->layer_config.type == LayerAttentionType::kVit) {
@ -1014,8 +1025,8 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
// y = nn.LayerNorm()(x)
// y ~ pre_att_rms_out
LayerNormBatched(num_tokens, x.All(),
layer_weights->vit.layer_norm_0_scale.data_scale1(),
layer_weights->vit.layer_norm_0_bias.data_scale1(),
layer_weights->vit.layer_norm_0_scale.PackedScale1(),
layer_weights->vit.layer_norm_0_bias.PackedScale1(),
activations.pre_att_rms_out.All(), model_dim);
// y = out["sa"] = nn.MultiHeadDotProductAttention(...)(y, y)
@ -1028,8 +1039,8 @@ HWY_NOINLINE void VitTransformerLayer(size_t num_tokens, size_t layer,
// y = nn.LayerNorm()(x)
// y ~ bf_pre_ffw_rms_out
LayerNormBatched(num_tokens, x.All(),
layer_weights->vit.layer_norm_1_scale.data_scale1(),
layer_weights->vit.layer_norm_1_bias.data_scale1(),
layer_weights->vit.layer_norm_1_scale.PackedScale1(),
layer_weights->vit.layer_norm_1_bias.PackedScale1(),
activations.bf_pre_ffw_rms_out.All(), model_dim);
// y = out["mlp"] = MlpBlock(...)(y)
@ -1161,8 +1172,8 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
const size_t patch_width = weights.weights_config.vit_config.patch_width;
const size_t seq_len = weights.weights_config.vit_config.seq_len;
const size_t patch_size = patch_width * patch_width * 3;
HWY_DASSERT(weights.vit_img_embedding_kernel.NumElements() ==
patch_size * model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Rows() == model_dim);
HWY_DASSERT(weights.vit_img_embedding_kernel.Cols() == patch_size);
HWY_DASSERT(activations.x.Cols() == model_dim);
std::vector<hwy::AlignedFreeUniquePtr<float[]>> image_patches(seq_len);
for (size_t i = 0; i < seq_len; ++i) {
@ -1178,20 +1189,20 @@ HWY_NOINLINE void EmbedImagePatches(const Image& image,
// MatMul(
// MatFromBatch(kVitSeqLen, image_patches),
// MatFromWeights(weights.vit_img_embedding_kernel),
// weights.vit_img_embedding_bias.data_scale1(), *activations.env,
// weights.vit_img_embedding_bias.PackedScale1(), *activations.env,
// RowPtrF(activations.x.All(), kVitModelDim));
// However, MatMul currently requires that
// A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0
// which is not the case here. We should relax that requirement on MatMul and
// then use the above. For now, we rely on MatVecAdd instead.
for (size_t i = 0; i < seq_len; ++i) {
MatVecAdd(
weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
image_patches[i].get(), weights.vit_img_embedding_bias.data_scale1(),
activations.x.Batch(i), activations.env->parallel.Pools().Pool(0));
MatVecAdd(weights.vit_img_embedding_kernel, 0, model_dim, patch_size,
image_patches[i].get(),
weights.vit_img_embedding_bias.PackedScale1(),
activations.x.Batch(i), activations.env->ctx.pools.Pool(0));
}
// Add position embeddings.
AddFrom(weights.vit_img_pos_embedding.data_scale1(), activations.x.All(),
AddFrom(weights.vit_img_pos_embedding.PackedScale1(), activations.x.All(),
seq_len * model_dim);
}
@ -1216,23 +1227,23 @@ HWY_NOINLINE void PrefillVit(const ModelWeightsPtrs<T>& weights,
}
// Final Layernorm.
LayerNormBatched(num_tokens, activations.x.All(),
weights.vit_encoder_norm_scale.data_scale1(),
weights.vit_encoder_norm_bias.data_scale1(),
weights.vit_encoder_norm_scale.PackedScale1(),
weights.vit_encoder_norm_bias.PackedScale1(),
activations.x.All(), vit_model_dim);
if (weights.weights_config.wrapping == PromptWrapping::GEMMA_VLM) {
activations.x = AvgPool4x4(activations.x);
// Apply soft embedding norm before input projection.
RMSNormInplace(weights.mm_embed_norm.data_scale1(), activations.x.All(),
RMSNormInplace(weights.mm_embed_norm.PackedScale1(), activations.x.All(),
vit_model_dim);
}
// Apply head embedding into image_tokens of size of the LLM kModelDim.
MatMul(ConstMatFromBatch(activations.x.BatchSize(), activations.x),
ConstMatFromWeights(weights.vit_img_head_kernel),
weights.vit_img_head_bias.data_scale1(), *activations.env,
RowPtrFromBatch(image_tokens));
weights.vit_img_head_bias.PackedScale1(), *activations.env,
RowPtrFromBatch(activations.env->ctx.allocator, image_tokens));
}
// Generates one token for each query. `queries_token` is the previous token
@ -1274,7 +1285,7 @@ HWY_NOINLINE void Transformer(
}
}
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.data_scale1(),
RMSNormInplaceBatched(num_queries, weights.final_norm_scale.PackedScale1(),
activations.x.All(), model_dim);
if (activations_observer) {
@ -1374,7 +1385,7 @@ bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
MatMul(ConstMatFromBatch(num_queries, activations.x),
ConstMatFromWeights(weights.embedder_input_embedding),
/*add=*/nullptr, *activations.env,
RowPtrFromBatch(activations.logits));
RowPtrFromBatch(activations.env->ctx.allocator, activations.logits));
}
PROFILER_ZONE("Gen.Softcap+Sample+Stream");
for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {

View File

@ -27,22 +27,33 @@
#include <utility> // std::move
#include <vector>
#include "compression/io.h" // Path
// Placeholder for internal header, do not modify.
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/tokenizer.h"
#include "gemma/weights.h"
#include "ops/ops-inl.h"
#include "ops/matmul.h"
#include "paligemma/image.h"
#include "util/threading.h"
#include "hwy/highway.h"
#include "util/threading_context.h"
#include "hwy/base.h"
namespace gcpp {
// Internal init must run before I/O; calling it from `GemmaEnv()` is too late.
// This helper function takes care of the internal init plus calling `SetArgs`.
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args) {
// Placeholder for internal init, do not modify.
ThreadingContext2::SetArgs(threading_args);
return MatMulEnv(ThreadingContext2::Get());
}
Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
const ModelInfo& info, MatMulEnv& env)
: env_(env), tokenizer_(tokenizer_path) {
model_.Load(weights, info.model, info.weight, info.wrapping,
env_.parallel.Pools().Pool(0),
env_.ctx.pools.Pool(0),
/*tokenizer_proto=*/nullptr);
chat_template_.Init(tokenizer_, model_.Config().model);
}
@ -50,7 +61,7 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights,
Gemma::Gemma(const Path& weights, MatMulEnv& env) : env_(env) {
std::string tokenizer_proto;
model_.Load(weights, Model::UNKNOWN, Type::kUnknown, PromptWrapping::GEMMA_IT,
env_.parallel.Pools().Pool(0), &tokenizer_proto);
env_.ctx.pools.Pool(0), &tokenizer_proto);
tokenizer_.Deserialize(tokenizer_proto);
chat_template_.Init(tokenizer_, model_.Config().model);
}
@ -60,7 +71,7 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env)
tokenizer_(std::move(tokenizer)),
chat_template_(tokenizer_, info.model) {
HWY_ASSERT(info.weight == Type::kF32);
model_.Allocate(info.model, info.weight, env_.parallel.Pools().Pool(0));
model_.Allocate(info.model, info.weight, env_.ctx.pools.Pool(0));
}
Gemma::~Gemma() {
@ -130,12 +141,12 @@ struct GenerateImageTokensT {
void Gemma::Generate(const RuntimeConfig& runtime_config,
const PromptTokens& prompt, size_t pos, size_t prefix_end,
KVCache& kv_cache, TimingInfo& timing_info) {
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
model_.CallForModelWeight<GenerateSingleT>(
runtime_config, prompt, pos, prefix_end, kv_cache, &env_, timing_info);
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
@ -152,23 +163,23 @@ void Gemma::GenerateBatch(const RuntimeConfig& runtime_config,
QueriesPos(prefix_end_vec.data(), prefix_end_vec.size());
}
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
model_.CallForModelWeight<GenerateBatchT>(
runtime_config, queries_prompt, queries_pos, mutable_queries_prefix_end,
kv_caches, &env_, timing_info);
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
void Gemma::GenerateImageTokens(const RuntimeConfig& runtime_config,
const Image& image, ImageTokens& image_tokens) {
env_.parallel.Pools().MaybeStartSpinning(runtime_config.use_spinning);
env_.ctx.pools.MaybeStartSpinning(runtime_config.use_spinning);
model_.CallForModelWeight<GenerateImageTokensT>(runtime_config, image,
image_tokens, &env_);
env_.parallel.Pools().MaybeStopSpinning(runtime_config.use_spinning);
env_.ctx.pools.MaybeStopSpinning(runtime_config.use_spinning);
}
// Non-template functions moved from gemma-inl.h to avoid ODR violations.

View File

@ -16,6 +16,8 @@
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_GEMMA_H_
#include <stdio.h>
#include <functional>
#include <random>
#include <string>
@ -31,8 +33,9 @@
#include "gemma/weights.h"
#include "ops/matmul.h" // MatMulEnv
#include "paligemma/image.h"
#include "util/allocator.h" // RowVectorBatch
#include "util/basics.h" // TokenAndProb
#include "util/basics.h" // TokenAndProb
#include "util/mat.h" // RowVectorBatch
#include "util/threading_context.h"
#include "hwy/timer.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h" // Span
@ -193,6 +196,10 @@ struct TimingInfo {
size_t tokens_generated = 0;
};
// Internal init must run before I/O; calling it from GemmaEnv() is too late.
// This helper function takes care of the internal init plus calling `SetArgs`.
MatMulEnv MakeMatMulEnv(const ThreadingArgs& threading_args);
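
An illustrative call order for a frontend (it mirrors the hello_world changes in this commit; includes omitted for brevity):

```cpp
int main(int argc, char** argv) {
  gcpp::ThreadingArgs threading(argc, argv);
  gcpp::LoaderArgs loader(argc, argv);
  gcpp::InferenceArgs inference(argc, argv);
  // Internal init and ThreadingContext2::SetArgs happen before any weight I/O.
  gcpp::MatMulEnv env = gcpp::MakeMatMulEnv(threading);
  gcpp::Gemma model = gcpp::CreateGemma(loader, env);
  gcpp::KVCache kv_cache = gcpp::KVCache::Create(model.GetModelConfig(),
                                                 inference.prefill_tbatch_size);
  return 0;
}
```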
class Gemma {
public:
// Reads old format weights file and tokenizer file.
@ -206,7 +213,9 @@ class Gemma {
Gemma(GemmaTokenizer&& tokenizer, const ModelInfo& info, MatMulEnv& env);
~Gemma();
MatMulEnv& Env() const { return env_; }
const ModelConfig& GetModelConfig() const { return model_.Config(); }
// DEPRECATED
ModelInfo Info() const {
return ModelInfo({.model = model_.Config().model,
.wrapping = model_.Config().wrapping,

View File

@ -15,8 +15,8 @@
// Shared between various frontends.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#ifndef THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#define THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_
#include <stddef.h>
#include <stdio.h>
@ -31,103 +31,10 @@
#include "ops/matmul.h"
#include "util/args.h"
#include "util/basics.h" // Tristate
#include "util/threading.h"
#include "hwy/base.h" // HWY_IS_ASAN
#include "hwy/base.h" // HWY_ABORT
namespace gcpp {
static inline const char* CompiledConfig() {
if (HWY_IS_ASAN) {
return "asan";
} else if (HWY_IS_MSAN) {
return "msan";
} else if (HWY_IS_TSAN) {
return "tsan";
} else if (HWY_IS_HWASAN) {
return "hwasan";
} else if (HWY_IS_UBSAN) {
return "ubsan";
} else if (HWY_IS_DEBUG_BUILD) {
return "dbg";
} else {
return "opt";
}
}
class AppArgs : public ArgsBase<AppArgs> {
public:
AppArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
AppArgs() { Init(); };
int verbosity;
size_t max_threads; // divided among the detected clusters
Tristate pin; // pin threads?
Tristate spin; // use spin waits?
// For BoundedSlice:
size_t skip_packages;
size_t max_packages;
size_t skip_clusters;
size_t max_clusters;
size_t skip_lps;
size_t max_lps;
std::string eot_line;
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(verbosity, "verbosity", 1,
"Show verbose developer information\n 0 = only print generation "
"output\n 1 = standard user-facing terminal ui\n 2 = show "
"developer/debug info).\n Default = 1.",
2);
// The exact meaning is more subtle: see the comment at NestedPools ctor.
visitor(max_threads, "num_threads", size_t{0},
"Maximum number of threads to use; default 0 = unlimited.", 2);
visitor(pin, "pin", Tristate::kDefault,
"Pin threads? -1 = auto, 0 = no, 1 = yes.", 2);
visitor(spin, "spin", Tristate::kDefault,
"Use spin waits? -1 = auto, 0 = no, 1 = yes.", 2);
// These can be used to partition CPU sockets/packages and their
// clusters/CCXs across several program instances. The default is to use
// all available resources.
visitor(skip_packages, "skip_packages", size_t{0},
"Index of the first socket to use; default 0 = unlimited.", 2);
visitor(max_packages, "max_packages", size_t{0},
"Maximum number of sockets to use; default 0 = unlimited.", 2);
visitor(skip_clusters, "skip_clusters", size_t{0},
"Index of the first CCX to use; default 0 = unlimited.", 2);
visitor(max_clusters, "max_clusters", size_t{0},
"Maximum number of CCXs to use; default 0 = unlimited.", 2);
// These are only used when CPU topology is unknown.
visitor(skip_lps, "skip_lps", size_t{0},
"Index of the first LP to use; default 0 = unlimited.", 2);
visitor(max_lps, "max_lps", size_t{0},
"Maximum number of LPs to use; default 0 = unlimited.", 2);
visitor(
eot_line, "eot_line", std::string(""),
"End of turn line. "
"When you specify this, the prompt will be all lines "
"before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.",
2);
}
};
static inline BoundedTopology CreateTopology(const AppArgs& app) {
return BoundedTopology(BoundedSlice(app.skip_packages, app.max_packages),
BoundedSlice(app.skip_clusters, app.max_clusters),
BoundedSlice(app.skip_lps, app.max_lps));
}
static inline NestedPools CreatePools(const BoundedTopology& topology,
const AppArgs& app) {
Allocator::Init(topology);
return NestedPools(topology, app.max_threads, app.pin);
}
struct LoaderArgs : public ArgsBase<LoaderArgs> {
LoaderArgs(int argc, char* argv[], bool validate = true) {
InitAndParse(argc, argv);
@ -154,15 +61,6 @@ struct LoaderArgs : public ArgsBase<LoaderArgs> {
// Returns error string or nullptr if OK.
const char* Validate() {
if (!compressed_weights.path.empty()) {
if (weights.path.empty()) {
weights = compressed_weights;
} else {
return "Only one of --weights and --compressed_weights can be "
"specified. To create compressed weights use the "
"compress_weights tool.";
}
}
if (weights.path.empty()) {
return "Missing --weights flag, a file for the model weights.";
}
@ -250,6 +148,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
InferenceArgs(int argc, char* argv[]) { InitAndParse(argc, argv); }
InferenceArgs() { Init(); };
int verbosity;
size_t max_generated_tokens;
size_t prefill_tbatch_size;
@ -261,6 +161,8 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
bool multiturn;
Path image_file;
std::string eot_line;
// Returns error string or nullptr if OK.
const char* Validate() const {
if (max_generated_tokens > gcpp::kSeqLen) {
@ -272,6 +174,12 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
template <class Visitor>
void ForEach(const Visitor& visitor) {
visitor(verbosity, "verbosity", 1,
"Show verbose developer information\n 0 = only print generation "
"output\n 1 = standard user-facing terminal ui\n 2 = show "
"developer/debug info).\n Default = 1.",
2);
visitor(max_generated_tokens, "max_generated_tokens", size_t{2048},
"Maximum number of tokens to generate.");
@ -291,6 +199,14 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
" Default : 0 (conversation "
"resets every turn)");
visitor(image_file, "image_file", Path(), "Image file to load.");
visitor(
eot_line, "eot_line", std::string(""),
"End of turn line. "
"When you specify this, the prompt will be all lines "
"before the line where only the given string appears.\n Default = "
"When a newline is encountered, that signals the end of the turn.",
2);
}
void CopyTo(RuntimeConfig& runtime_config) const {
@ -317,4 +233,4 @@ struct InferenceArgs : public ArgsBase<InferenceArgs> {
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_APP_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_ARGS_H_

View File

@ -15,23 +15,23 @@
// Command line text interface to gemma.
#include <stdio.h>
#include <iostream>
#include <random>
#include <string>
#include <string_view>
#include <vector>
// Placeholder for internal header, do not modify.
#include "compression/shared.h" // PromptWrapping
#include "evals/benchmark_helper.h"
#include "gemma/common.h"
#include "gemma/gemma.h" // Gemma
#include "ops/matmul.h" // MatMulEnv
#include "gemma/gemma_args.h" // LoaderArgs
#include "ops/matmul.h" // MatMulEnv
#include "paligemma/image.h"
#include "util/app.h"
#include "util/args.h" // HasHelp
#include "util/threading.h"
#include "hwy/base.h"
#include "util/threading_context.h"
#include "hwy/highway.h"
#include "hwy/profiler.h"
@ -78,35 +78,37 @@ std::string GetPrompt(std::istream& input, int verbosity,
}
// The main Read-Eval-Print Loop.
void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
const InferenceArgs& args, const AcceptFunc& accept_token,
std::string& eot_line) {
void ReplGemma(const ThreadingArgs& threading, const InferenceArgs& inference,
Gemma& model, KVCache& kv_cache) {
PROFILER_ZONE("Gen.misc");
size_t abs_pos = 0; // across turns
size_t tokens_generated_this_turn = 0; // differentiates prefill from reply
size_t prompt_size = 0;
std::mt19937 gen;
InitGenerator(args, gen);
InitGenerator(inference, gen);
const bool have_image = !args.image_file.path.empty();
const bool have_image = !inference.image_file.path.empty();
Image image;
ImageTokens image_tokens;
if (have_image) {
size_t pool_dim = model.GetModelConfig().vit_config.pool_dim;
image_tokens = ImageTokens(Extents2D(
model.GetModelConfig().vit_config.seq_len / (pool_dim * pool_dim),
model.GetModelConfig().model_dim));
image_tokens =
ImageTokens(model.Env().ctx.allocator,
Extents2D(model.GetModelConfig().vit_config.seq_len /
(pool_dim * pool_dim),
model.GetModelConfig().model_dim));
HWY_ASSERT(model.Info().wrapping == PromptWrapping::PALIGEMMA ||
model.Info().wrapping == PromptWrapping::GEMMA_VLM);
HWY_ASSERT(image.ReadPPM(args.image_file.path));
HWY_ASSERT(image.ReadPPM(inference.image_file.path));
const size_t image_size = model.GetModelConfig().vit_config.image_size;
image.Resize(image_size, image_size);
RuntimeConfig runtime_config = {
.gen = &gen, .verbosity = app.verbosity, .use_spinning = app.spin};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = inference.verbosity,
.use_spinning = threading.spin};
double image_tokens_start = hwy::platform::Now();
model.GenerateImageTokens(runtime_config, image, image_tokens);
if (app.verbosity >= 1) {
if (inference.verbosity >= 1) {
double image_tokens_duration = hwy::platform::Now() - image_tokens_start;
fprintf(stderr,
"\n\n[ Timing info ] Image token generation took: %d ms\n",
@ -121,12 +123,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
const bool first_response_token = tokens_generated_this_turn == prompt_size;
++tokens_generated_this_turn;
if (in_prompt) {
if (app.verbosity >= 1) {
if (inference.verbosity >= 1) {
std::cerr << "." << std::flush;
}
return true;
} else if (model.GetModelConfig().IsEOS(token)) {
if (app.verbosity >= 2) {
if (inference.verbosity >= 2) {
std::cout << "\n[ End ]\n";
}
return true;
@ -135,7 +137,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
HWY_ASSERT(model.Tokenizer().Decode(std::vector<int>{token}, &token_text));
if (first_response_token) {
token_text.erase(0, token_text.find_first_not_of(" \t\n"));
if (app.verbosity >= 1) {
if (inference.verbosity >= 1) {
std::cout << "\n\n";
}
}
@ -147,7 +149,8 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
tokens_generated_this_turn = 0;
// Read prompt and handle special commands.
std::string prompt_string = GetPrompt(std::cin, app.verbosity, eot_line);
std::string prompt_string =
GetPrompt(std::cin, inference.verbosity, inference.eot_line);
if (!std::cin) return;
// If !eot_line.empty(), we append \n, so only look at the first 2 chars.
if (prompt_string.size() >= 2 && prompt_string[0] == '%') {
@ -163,13 +166,12 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
}
// Set up runtime config.
TimingInfo timing_info = {.verbosity = app.verbosity};
TimingInfo timing_info = {.verbosity = inference.verbosity};
RuntimeConfig runtime_config = {.gen = &gen,
.verbosity = app.verbosity,
.verbosity = inference.verbosity,
.stream_token = stream_token,
.accept_token = accept_token,
.use_spinning = app.spin};
args.CopyTo(runtime_config);
.use_spinning = threading.spin};
inference.CopyTo(runtime_config);
size_t prefix_end = 0;
std::vector<int> prompt;
@ -197,7 +199,7 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
}
// Generate until EOS or max_generated_tokens.
if (app.verbosity >= 1) {
if (inference.verbosity >= 1) {
std::cerr << "\n[ Reading prompt ] " << std::flush;
}
model.Generate(runtime_config, prompt, abs_pos, prefix_end, kv_cache,
@ -205,9 +207,10 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
std::cout << "\n\n";
// Prepare for the next turn. Works only for PaliGemma.
if (!args.multiturn || model.Info().wrapping == PromptWrapping::PALIGEMMA) {
if (!inference.multiturn ||
model.Info().wrapping == PromptWrapping::PALIGEMMA) {
abs_pos = 0; // Start a new turn at position 0.
InitGenerator(args, gen);
InitGenerator(inference, gen);
} else {
// The last token was either EOS, then it should be ignored because it is
// never part of the dialog, see Table 5 in the Gemma-2 paper:
@ -223,20 +226,19 @@ void ReplGemma(Gemma& model, KVCache& kv_cache, const AppArgs& app,
}
}
void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
void Run(ThreadingArgs& threading, LoaderArgs& loader,
InferenceArgs& inference) {
PROFILER_ZONE("Run.misc");
// Note that num_threads is an upper bound; we also limit to the number of
// detected and enabled cores.
const BoundedTopology topology = CreateTopology(app);
NestedPools pools = CreatePools(topology, app);
MatMulEnv env(topology, pools);
if (app.verbosity >= 2) env.print_best = true;
MatMulEnv env(MakeMatMulEnv(threading));
if (inference.verbosity >= 2) env.print_best = true;
Gemma model = CreateGemma(loader, env);
KVCache kv_cache =
KVCache::Create(model.GetModelConfig(), inference.prefill_tbatch_size);
if (app.verbosity >= 1) {
if (inference.verbosity >= 1) {
std::string instructions =
"*Usage*\n"
" Enter an instruction and press enter (%C resets conversation, "
@ -259,11 +261,11 @@ void Run(LoaderArgs& loader, InferenceArgs& inference, AppArgs& app) {
std::cout << "\033[2J\033[1;1H" // clear screen
<< kAsciiArtBanner << "\n\n";
ShowConfig(loader, inference, app, topology, pools);
ShowConfig(threading, loader, inference);
std::cout << "\n" << instructions << "\n";
}
ReplGemma(model, kv_cache, app, inference, AcceptFunc(), app.eot_line);
ReplGemma(threading, inference, model, kv_cache);
}
} // namespace gcpp
@ -272,31 +274,29 @@ int main(int argc, char** argv) {
{
PROFILER_ZONE("Startup.misc");
// Placeholder for internal init, do not modify.
gcpp::ThreadingArgs threading(argc, argv);
gcpp::LoaderArgs loader(argc, argv);
gcpp::InferenceArgs inference(argc, argv);
gcpp::AppArgs app(argc, argv);
if (gcpp::HasHelp(argc, argv)) {
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
gcpp::ShowHelp(threading, loader, inference);
return 0;
}
if (const char* error = loader.Validate()) {
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
gcpp::ShowHelp(threading, loader, inference);
HWY_ABORT("\nInvalid args: %s", error);
}
if (const char* error = inference.Validate()) {
std::cerr << gcpp::kAsciiArtBanner;
gcpp::ShowHelp(loader, inference, app);
gcpp::ShowHelp(threading, loader, inference);
HWY_ABORT("\nInvalid args: %s", error);
}
gcpp::Run(loader, inference, app);
gcpp::Run(threading, loader, inference);
}
PROFILER_PRINT_RESULTS(); // Must call outside the zone above.
return 0;

View File

@ -562,11 +562,12 @@ TensorIndex::TensorIndex(const ModelConfig& config, int llm_layer_idx,
if (llm_layer_idx < 0 && img_layer_idx < 0) {
tensors_ = ModelTensors(config);
} else if (llm_layer_idx_ < 0 && 0 <= img_layer_idx &&
img_layer_idx < config.vit_config.layer_configs.size()) {
img_layer_idx <
static_cast<int>(config.vit_config.layer_configs.size())) {
const auto& layer_config = config.vit_config.layer_configs[img_layer_idx];
tensors_ = ImageLayerTensors(config, layer_config, img_layer_idx);
} else if (0 <= llm_layer_idx &&
llm_layer_idx < config.layer_configs.size()) {
llm_layer_idx < static_cast<int>(config.layer_configs.size())) {
const auto& layer_config = config.layer_configs[llm_layer_idx];
tensors_ = LLMLayerTensors(config, layer_config, reshape_att);
}

View File

@ -54,6 +54,28 @@ struct TensorInfo {
bool cols_take_extra_dims = false;
};
// Collapses/expands the tensor dims into 2D extents, which may be 0, 0 for
// not-present tensors such as ViT in a text-only model.
static inline Extents2D ExtentsFromInfo(const TensorInfo* tensor) {
if (tensor == nullptr) return Extents2D(0, 0);
size_t cols = tensor->shape.back();
size_t rows = 1;
if (tensor->cols_take_extra_dims) {
rows = tensor->shape[0];
for (size_t i = 1; i < tensor->shape.size() - 1; ++i) {
cols *= tensor->shape[i];
}
} else { // rows take extra dims
for (size_t i = 0; i < tensor->shape.size() - 1; ++i) {
rows *= tensor->shape[i];
}
}
// Sometimes only one of rows or cols is zero; set both for consistency.
if (rows == 0 || cols == 0) rows = cols = 0;
return Extents2D(rows, cols);
}
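// Worked example (illustrative): for shape {8, 256, 64}, rows take the extra
// dims by default, giving rows = 8*256 = 2048, cols = 64; with
// cols_take_extra_dims set, it is instead rows = 8, cols = 256*64 = 16384.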
// Universal index of tensor information, which can be built for a specific
// layer_idx.
class TensorIndex {
@ -96,6 +118,16 @@ class TensorIndex {
std::unordered_map<std::string, size_t> name_map_;
};
static inline TensorIndex TensorIndexLLM(const ModelConfig& config,
size_t llm_layer_idx) {
return TensorIndex(config, static_cast<int>(llm_layer_idx), -1, false);
}
static inline TensorIndex TensorIndexImg(const ModelConfig& config,
size_t img_layer_idx) {
return TensorIndex(config, -1, static_cast<int>(img_layer_idx), false);
}
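// Usage sketch (illustrative): build per-layer views of the tensor metadata.
//   const TensorIndex llm_tensors = TensorIndexLLM(config, /*llm_layer_idx=*/0);
//   const TensorIndex img_tensors = TensorIndexImg(config, /*img_layer_idx=*/0);
// Combined with ExtentsFromInfo above, these yield the 2D allocation extents
// for each tensor of that layer.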
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_TENSOR_INDEX_H_

View File

@ -29,6 +29,7 @@
#include "compression/shared.h"
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h" // HWY_ABORT
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -118,7 +119,7 @@ struct TensorSaver {
weights.ForEachTensor(
{&weights}, fet,
[&writer](const char* name, hwy::Span<MatPtr*> tensors) {
tensors[0]->CallUpcasted(writer, name);
CallUpcasted(tensors[0]->GetType(), tensors[0], writer, name);
});
}
};
@ -155,11 +156,11 @@ class WeightInitializer {
WeightInitializer(std::mt19937& gen) : dist_(0.0f, 1.0f), gen_(gen) {}
void operator()(const char* name, hwy::Span<MatPtr*> tensors) {
float* data = tensors[0]->data<float>();
for (size_t i = 0; i < tensors[0]->NumElements(); ++i) {
float* data = tensors[0]->RowT<float>(0);
for (size_t i = 0; i < tensors[0]->Extents().Area(); ++i) {
data[i] = dist_(gen_);
}
tensors[0]->set_scale(1.0f);
tensors[0]->SetScale(1.0f);
}
private:
@ -226,11 +227,11 @@ void ModelWeightsStorage::LogWeightStats() {
{float_weights_.get()}, ForEachType::kInitNoToc,
[&total_weights](const char* name, hwy::Span<MatPtr*> tensors) {
const MatPtr& tensor = *tensors[0];
if (tensor.scale() != 1.0f) {
printf("[scale=%f] ", tensor.scale());
if (tensor.Scale() != 1.0f) {
printf("[scale=%f] ", tensor.Scale());
}
LogVec(name, tensor.data<float>(), tensor.NumElements());
total_weights += tensor.NumElements();
LogVec(name, tensor.RowT<float>(0), tensor.Extents().Area());
total_weights += tensor.Extents().Area();
});
printf("%-20s %12zu\n", "Total", total_weights);
}
@ -258,8 +259,8 @@ void ModelWeightsStorage::CreateForType(Type weight_type,
}
template <>
void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
if (attn_vec_einsum_w.data() == nullptr) return;
void LayerWeightsPtrs<NuqStream>::Reshape(MatOwner* storage) {
if (!attn_vec_einsum_w.HasPtr()) return;
const size_t model_dim = layer_config.model_dim;
const size_t heads = layer_config.heads;
@ -267,8 +268,7 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
if (storage != nullptr) {
storage->Allocate();
att_weights.SetPtr(*storage);
storage->AllocateFor(att_weights, MatPadding::kPacked);
}
const hwy::HWY_NAMESPACE::ScalableTag<float> df;
@ -279,7 +279,7 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
hwy::AllocateAligned<float>(model_dim * heads * qkv_dim);
HWY_NAMESPACE::DecompressAndZeroPad(
df, MakeSpan(attn_vec_einsum_w.data(), model_dim * heads * qkv_dim), 0,
df, MakeSpan(attn_vec_einsum_w.Packed(), model_dim * heads * qkv_dim), 0,
attn_vec_einsum_w_tmp.get(), model_dim * heads * qkv_dim);
for (size_t m = 0; m < model_dim; ++m) {
@ -296,10 +296,10 @@ void LayerWeightsPtrs<NuqStream>::Reshape(MatStorage* storage) {
HWY_NAMESPACE::Compress(
att_weights_tmp.get(), model_dim * heads * qkv_dim, work,
MakeSpan(att_weights.data(), model_dim * heads * qkv_dim),
MakeSpan(att_weights.Packed(), model_dim * heads * qkv_dim),
/*packed_ofs=*/0, pool);
att_weights.set_scale(attn_vec_einsum_w.scale());
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
} // namespace gcpp

View File

@ -31,12 +31,32 @@
#include "gemma/common.h"
#include "gemma/configs.h"
#include "gemma/tensor_index.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace gcpp {
static inline std::string CacheName(const MatPtr& mat, int layer = -1,
char separator = ' ', int index = -1) {
// Already used/retired: s, S, n, 1
const char prefix = mat.GetType() == Type::kF32 ? 'F'
: mat.GetType() == Type::kBF16 ? 'B'
: mat.GetType() == Type::kSFP ? '$'
: mat.GetType() == Type::kNUQ ? '2'
: '?';
std::string name = std::string(1, prefix) + mat.Name();
if (layer >= 0 || index >= 0) {
name += '_';
if (layer >= 0) name += std::to_string(layer);
if (index >= 0) {
name += separator + std::to_string(index);
}
}
return name;
}
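// Example (illustrative, hypothetical tensor name): a kSFP tensor named
// "qkv_ein" in layer 3 maps to "$qkv_ein_3"; the same tensor stored as kBF16
// would map to "Bqkv_ein_3".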
// Different tensors need to appear in a ForEachTensor, according to what is
// happening.
enum class ForEachType {
@ -181,10 +201,10 @@ struct LayerWeightsPtrs {
// Initializes att_weights from attn_vec_einsum_w, hence this must be called
// after loading weights via ForEachTensor.
// TODO: update compression/convert_weights to bake this in.
void Reshape(MatStorage* storage) {
void Reshape(MatOwner* storage) {
static_assert(!hwy::IsSame<Weight, NuqStream>());
if (attn_vec_einsum_w.data() == nullptr) return;
if (!attn_vec_einsum_w.HasPtr()) return;
const size_t model_dim = layer_config.model_dim;
const size_t heads = layer_config.heads;
@ -192,33 +212,33 @@ struct LayerWeightsPtrs {
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
if (storage != nullptr) {
storage->Allocate();
att_weights.SetPtr(*storage);
storage->AllocateFor(att_weights, MatPadding::kPacked);
}
for (size_t m = 0; m < model_dim; ++m) {
Weight* HWY_RESTRICT out_row = att_weights.data() + m * heads * qkv_dim;
Weight* HWY_RESTRICT out_row =
att_weights.template RowT<Weight>(0) + m * heads * qkv_dim;
for (size_t h = 0; h < heads; ++h) {
hwy::CopyBytes(
attn_vec_einsum_w.data() + h * model_dim * qkv_dim + m * qkv_dim,
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
hwy::CopyBytes(attn_vec_einsum_w.template RowT<Weight>(0) +
h * model_dim * qkv_dim + m * qkv_dim,
out_row + h * qkv_dim, qkv_dim * sizeof(Weight));
}
}
att_weights.set_scale(attn_vec_einsum_w.scale());
att_weights.SetScale(attn_vec_einsum_w.Scale());
}
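  // Index-mapping sketch (illustrative): with heads=2, model_dim=3, qkv_dim=4,
  // source element (h=1, m=0, q=0) sits at flat offset h*model_dim*qkv_dim = 12
  // in attn_vec_einsum_w and is copied to flat offset m*heads*qkv_dim +
  // h*qkv_dim = 4 in att_weights, i.e. row m=0, column h*qkv_dim + q = 4.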
ArrayT<WeightF32OrBF16> key_norm_scale;
ArrayT<WeightF32OrBF16> query_norm_scale;
// Used by ForEachTensor for per-layer tensors.
#define GEMMA_CALL_FUNC(member) \
{ \
for (int i = 0; i < ptrs.size(); ++i) { \
tensors[i] = &ptrs[i]->member; \
} \
if (tensors[0]->Ptr() != nullptr || fet != ForEachType::kIgnoreNulls) { \
func(ptrs[0]->member.CacheName(layer_idx, sep, sep_index).c_str(), \
hwy::Span<MatPtr*>(tensors.data(), ptrs.size())); \
} \
#define GEMMA_CALL_FUNC(member) \
{ \
for (int i = 0; i < ptrs.size(); ++i) { \
tensors[i] = &ptrs[i]->member; \
} \
if (tensors[0]->HasPtr() || fet != ForEachType::kIgnoreNulls) { \
func(CacheName(ptrs[0]->member, layer_idx, sep, sep_index).c_str(), \
hwy::Span<MatPtr*>(tensors.data(), ptrs.size())); \
} \
}
template <class Func>
@ -307,19 +327,18 @@ struct LayerWeightsPtrs {
void ZeroInit(int layer_idx) {
ForEachTensor({this}, layer_idx, ForEachType::kIgnoreNulls,
[](const char*, hwy::Span<MatPtr*> tensors) {
tensors[0]->ZeroInit();
gcpp::ZeroInit(*tensors[0]);
});
}
// Allocates memory for all the tensors in the layer.
// Note that this is slow and only used for a stand-alone layer.
void Allocate(std::vector<MatStorage>& layer_storage) {
void Allocate(std::vector<MatOwner>& layer_storage) {
ForEachTensor(
{this}, /*layer_idx=*/0, ForEachType::kInitNoToc,
[&layer_storage](const char* name, hwy::Span<MatPtr*> tensors) {
layer_storage.emplace_back(*tensors[0]);
layer_storage.back().Allocate();
tensors[0]->SetPtr(layer_storage.back());
layer_storage.push_back(MatOwner());
layer_storage.back().AllocateFor(*tensors[0], MatPadding::kPacked);
});
}
};
@ -393,11 +412,9 @@ struct ModelWeightsPtrs {
// Called by weights.cc after Loading, before att_w has been allocated.
void AllocAndCopyWithTranspose(hwy::ThreadPool& pool,
std::vector<MatStorage>& model_storage) {
std::vector<MatOwner>& model_storage) {
size_t storage_index = model_storage.size();
for (auto& layer : c_layers) {
model_storage.emplace_back(layer.att_weights);
}
model_storage.resize(model_storage.size() + c_layers.size());
pool.Run(0, c_layers.size(),
[this, &model_storage, storage_index](uint64_t layer,
size_t /*thread*/) {
@ -412,8 +429,8 @@ struct ModelWeightsPtrs {
}
void ZeroInit() {
embedder_input_embedding.ZeroInit();
final_norm_scale.ZeroInit();
gcpp::ZeroInit(embedder_input_embedding);
gcpp::ZeroInit(final_norm_scale);
for (size_t i = 0; i < c_layers.size(); ++i) {
c_layers[i].ZeroInit(i);
}
@ -430,21 +447,21 @@ struct ModelWeightsPtrs {
return &vit_layers[layer];
}
void Allocate(std::vector<MatStorage>& model_storage, hwy::ThreadPool& pool) {
void Allocate(std::vector<MatOwner>& model_storage, hwy::ThreadPool& pool) {
std::vector<MatPtr*> model_toc;
ForEachTensor(
{this}, ForEachType::kInitNoToc,
[&model_toc, &model_storage](const char*, hwy::Span<MatPtr*> tensors) {
model_toc.push_back(tensors[0]);
model_storage.emplace_back(*tensors[0]);
model_storage.push_back(MatOwner());
});
// Allocate in parallel using the pool.
pool.Run(0, model_toc.size(),
[&model_toc, &model_storage](uint64_t task, size_t /*thread*/) {
// model_storage may have had content before we started.
size_t idx = task + model_storage.size() - model_toc.size();
model_storage[idx].Allocate();
model_toc[task]->SetPtr(model_storage[idx]);
model_storage[idx].AllocateFor(*model_toc[task],
MatPadding::kPacked);
});
}
@ -453,8 +470,7 @@ struct ModelWeightsPtrs {
ForEachTensor({this, const_cast<ModelWeightsPtrs<Weight>*>(&other)},
ForEachType::kIgnoreNulls,
[](const char*, hwy::Span<MatPtr*> tensors) {
hwy::CopyBytes(tensors[1]->Ptr(), tensors[0]->Ptr(),
tensors[1]->SizeBytes());
CopyMat(*tensors[1], *tensors[0]);
});
}
@ -467,10 +483,10 @@ struct ModelWeightsPtrs {
[&scales, &scale_pos, this](const char*, hwy::Span<MatPtr*> tensors) {
if (this->scale_names.count(tensors[0]->Name())) {
if (scale_pos < scales.size()) {
tensors[0]->set_scale(scales[scale_pos]);
tensors[0]->SetScale(scales[scale_pos]);
} else {
float scale = ScaleWeights(tensors[0]->data<float>(),
tensors[0]->NumElements());
float scale = ScaleWeights(tensors[0]->RowT<float>(0),
tensors[0]->Extents().Area());
scales.push_back(scale);
}
++scale_pos;
@ -615,9 +631,9 @@ class ModelWeightsStorage {
std::unique_ptr<ModelWeightsPtrs<SfpStream>> sfp_weights_;
std::unique_ptr<ModelWeightsPtrs<NuqStream>> nuq_weights_;
// Storage for all the matrices and vectors.
std::vector<MatStorage> model_storage_;
std::vector<MatOwner> model_storage_;
};
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_
#endif // THIRD_PARTY_GEMMA_CPP_GEMMA_WEIGHTS_H_

View File

@ -31,15 +31,12 @@
#include <stdio.h>
#include <algorithm>
#include <memory>
#include <vector>
#include "compression/compress.h"
#include "compression/shared.h"
#include "ops/matmul.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/nanobenchmark.h"
#include "hwy/profiler.h"
@ -53,8 +50,8 @@
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "compression/test_util-inl.h"
#include "ops/matmul-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@ -63,59 +60,6 @@ extern int64_t first_target;
namespace HWY_NAMESPACE {
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
template <typename MatT>
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>
MatStoragePtr<MatT> GenerateMat(const Extents2D extents,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
auto mat =
std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
HWY_ASSERT(content);
const float scale =
SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(r * extents.cols + c) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
content[r * extents.cols + c] = f;
}
});
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
mat->set_scale(0.6f); // Arbitrary value, different from 1.
return mat;
}
// extents describes the transposed matrix.
template <typename MatT>
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
auto mat =
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
const float scale =
SfpStream::kMax / (mat->NumElements() + hwy::Unpredictable1() - 1);
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(c * extents.rows + r) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
content[r * extents.cols + c] = f;
}
});
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
// Arbitrary value, different from 1, must match GenerateMat.
mat->set_scale(0.6f);
return mat;
}
void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
std::vector<double>& times, MMPerKey* per_key) {
std::sort(times.begin(), times.end());
@ -135,7 +79,8 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
// M = A rows, K = A cols, N = C cols.
template <typename TA, typename TB = TA, typename TC = float>
void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
const Allocator2& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool(0);
if (env.print_config || env.print_measurement) {
fprintf(stderr, "\n");
}
@ -147,24 +92,23 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
const Extents2D B_extents(N, K); // already transposed
const Extents2D C_extents(M, N);
RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
RowVectorBatch<TC> c_slow_batch =
AllocateAlignedRows<TC>(allocator, C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
std::unique_ptr<MatStorageT<float>> add_storage;
MatStorageT<float> add_storage("add", Extents2D(), MatPadding::kPacked);
if (add) {
add_storage = GenerateMat<float>(Extents2D(1, N), pool);
HWY_ASSERT(add_storage);
add_storage->set_scale(1.0f);
add_storage.SetScale(1.0f);
}
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
HWY_ASSERT(a && b_trans);
const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans);
MatStorageT<TA> a = GenerateMat<TA>(A_extents, pool);
MatStorageT<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
const auto A = ConstMatFromWeights(a);
const auto B = ConstMatFromWeights(b_trans);
const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtr<TC> C = RowPtrFromBatch(c_batch);
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
// Fewer reps for large batch sizes, which take longer.
const size_t num_samples = M < 32 ? 20 : 12;
@ -174,11 +118,11 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
// Ensure usage conditions are set before autotuning. Both binding and
// spinning may materially affect the choice of config. No harm in calling
// BindB/C if there is a single package: they will be a no-op.
BindB(B_extents.rows, sizeof(TC), B, env.parallel);
BindC(A_extents.rows, C, env.parallel);
BindB(allocator, B_extents.rows, sizeof(TC), B, env.parallel);
BindC(allocator, A_extents.rows, C, env.parallel);
Tristate use_spinning = Tristate::kDefault;
env.parallel.Pools().MaybeStartSpinning(use_spinning);
env.ctx.pools.MaybeStartSpinning(use_spinning);
// env.print_config = true;
// env.print_measurement = true;
@ -198,7 +142,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
if (per_key->autotune.Best()) times.push_back(elapsed);
}
hwy::PreventElision(keep);
env.parallel.Pools().MaybeStopSpinning(use_spinning);
env.ctx.pools.MaybeStopSpinning(use_spinning);
PrintSpeed(A_extents, B_extents, times, per_key);
}
@ -216,17 +160,11 @@ void BenchAllMatMul() {
return;
}
const size_t max_threads = 0; // no limit
const BoundedSlice package_slice; // all packages/sockets
const BoundedSlice cluster_slice; // all clusters/CCX
const BoundedSlice lp_slice; // default to all cores (per package).
const BoundedTopology topology(package_slice, cluster_slice, lp_slice);
Allocator::Init(topology, /*enable_bind=*/true);
NestedPools pools(topology, max_threads, Tristate::kDefault);
fprintf(stderr, "BenchAllMatMul %s %s\n", topology.TopologyString(),
pools.PinString());
ThreadingContext2& ctx = ThreadingContext2::Get();
fprintf(stderr, "BenchAllMatMul %s %s\n", ctx.topology.TopologyString(),
ctx.pools.PinString());
MatMulEnv env(topology, pools);
MatMulEnv env(ctx);
for (size_t batch_size : {1, 4, 128, 512}) {
constexpr bool kAdd = false;

View File

@ -16,6 +16,7 @@
#include <stddef.h>
#include "compression/compress.h"
#include "util/mat.h"
#include "hwy/base.h"
#include "hwy/profiler.h"
@ -379,10 +380,7 @@ template <typename MatT, typename VT>
HWY_INLINE float Dot(const MatPtrT<MatT>& w, size_t w_ofs,
const VT* vec_aligned, size_t num) {
const hn::ScalableTag<VT> d;
return w.scale() * Dot(d,
MakeConstSpan(reinterpret_cast<const MatT*>(w.Ptr()),
w.NumElements()),
w_ofs, vec_aligned, num);
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -28,9 +28,8 @@
#include "compression/shared.h"
#include "util/allocator.h"
#include "util/app.h"
#include "util/test_util.h"
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
@ -805,7 +804,7 @@ class DotStats {
ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_l1s[kAddTwoProd].Mean(), 8E-4);
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.1E-3f);
// Updating Kahan's FastTwoSums to TwoSums does help a bit.
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.2E-4);
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.8E-4);
ASSERT_INSIDE(kPairwise, 4.5E-4, s_l1s[kPairwise].Mean(), 4E-3);
ASSERT_INSIDE(kPairwise, 1.1E-3f, s_l1s[kPairwise].Max(), 1E-2f);
@ -1000,9 +999,7 @@ struct TestShortDotsT {
const size_t N = hn::Lanes(d);
const hn::ScalableTag<float> df; // for CallDot
const AppArgs app;
BoundedTopology topology(CreateTopology(app));
NestedPools pools = CreatePools(topology, app);
const Allocator2& allocator = gcpp::ThreadingContext2::Get().allocator;
CompressWorkingSet work;
std::mt19937 rng;
rng.seed(12345);
@ -1014,14 +1011,14 @@ struct TestShortDotsT {
// hence they require padding to one vector.
const size_t padded_num = hwy::RoundUpTo(num, N);
const size_t packed_num = CompressedArrayElements<Packed>(num);
RowVectorBatch<float> raw_w(Extents2D(1, padded_num));
RowVectorBatch<float> raw_v(Extents2D(1, padded_num));
RowVectorBatch<Packed> weights(Extents2D(1, packed_num));
RowVectorBatch<float> raw_w(allocator, Extents2D(1, padded_num));
RowVectorBatch<float> raw_v(allocator, Extents2D(1, padded_num));
RowVectorBatch<Packed> weights(allocator, Extents2D(1, packed_num));
const PackedSpan<Packed> w(weights.Batch(0), packed_num);
RowVectorBatch<T> vectors(Extents2D(1, num));
RowVectorBatch<T> vectors(allocator, Extents2D(1, num));
const PackedSpan<T> v(vectors.Batch(0), num);
RowVectorBatch<double> bufs(Extents2D(1, num));
RowVectorBatch<double> bufs(allocator, Extents2D(1, num));
double* HWY_RESTRICT buf = bufs.Batch(0);
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
@ -1099,10 +1096,21 @@ void TestAllDot() {
return;
}
constexpr size_t kMaxWorkers = 15;
// Reset with cap on workers because we only support `kMaxWorkers`.
ThreadingContext2::ThreadHostileInvalidate();
ThreadingArgs threading_args;
threading_args.max_packages = 1;
threading_args.max_clusters = 1;
threading_args.max_lps = kMaxWorkers - 1;
ThreadingContext2::SetArgs(threading_args);
ThreadingContext2& ctx = ThreadingContext2::Get();
const Allocator2& allocator = ctx.allocator;
{ // ensure no profiler zones are active
const hn::ScalableTag<float> df;
constexpr size_t kMaxWorkers = 15;
std::mt19937 rngs[kMaxWorkers];
for (size_t i = 0; i < kMaxWorkers; ++i) {
rngs[i].seed(12345 + 65537 * i);
@ -1110,44 +1118,43 @@ void TestAllDot() {
constexpr size_t kReps = hn::AdjustedReps(40);
const size_t num = 24 * 1024;
const BoundedTopology topology(BoundedSlice(0, 1), BoundedSlice(0, 1),
BoundedSlice());
NestedPools pools(topology, kMaxWorkers - 1, /*pin=*/Tristate::kDefault);
RowVectorBatch<float> a(Extents2D(kMaxWorkers, num));
RowVectorBatch<float> b(Extents2D(kMaxWorkers, num));
RowVectorBatch<double> bufs(Extents2D(kMaxWorkers, num));
RowVectorBatch<float> a(allocator, Extents2D(kMaxWorkers, num));
RowVectorBatch<float> b(allocator, Extents2D(kMaxWorkers, num));
RowVectorBatch<double> bufs(allocator, Extents2D(kMaxWorkers, num));
std::array<DotStats, kMaxWorkers> all_stats;
pools.Cluster(0, 0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Batch(thread);
float* HWY_RESTRICT pb = b.Batch(thread);
double* HWY_RESTRICT buf = bufs.Batch(thread);
const PackedSpan<const float> a_span(pa, num);
DotStats& stats = all_stats[thread];
const double cond =
GenerateIllConditionedInputs(num, pa, pb, rngs[thread]);
ctx.pools.Cluster(0, 0).Run(
0, kReps, [&](const uint32_t rep, size_t thread) {
float* HWY_RESTRICT pa = a.Batch(thread);
float* HWY_RESTRICT pb = b.Batch(thread);
double* HWY_RESTRICT buf = bufs.Batch(thread);
const PackedSpan<const float> a_span(pa, num);
DotStats& stats = all_stats[thread];
const double cond =
GenerateIllConditionedInputs(num, pa, pb, rngs[thread]);
const float dot_exact = ExactDot(pa, pb, num, buf);
const float dot_exact = ExactDot(pa, pb, num, buf);
float dots[kVariants] = {};
double times[kVariants] = {};
for (size_t variant = 0; variant < kVariants; ++variant) {
constexpr size_t kTimeReps = hn::AdjustedReps(10);
std::array<double, kTimeReps> elapsed;
for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
const double start = hwy::platform::Now();
dots[variant] += CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
hwy::PreventElision(*pa);
elapsed[time_rep] = hwy::platform::Now() - start;
}
dots[variant] /= kTimeReps;
times[variant] = TrimmedMean(elapsed.data(), kTimeReps);
}
float dots[kVariants] = {};
double times[kVariants] = {};
for (size_t variant = 0; variant < kVariants; ++variant) {
constexpr size_t kTimeReps = hn::AdjustedReps(10);
std::array<double, kTimeReps> elapsed;
for (int time_rep = 0; time_rep < kTimeReps; ++time_rep) {
const double start = hwy::platform::Now();
dots[variant] +=
CallDot(df, variant, a_span, /*w_ofs=*/0, pb, num);
hwy::PreventElision(*pa);
elapsed[time_rep] = hwy::platform::Now() - start;
}
dots[variant] /= kTimeReps;
times[variant] = TrimmedMean(elapsed.data(), kTimeReps);
}
stats.NotifyTimes(times);
stats.NotifyRep(num, cond, dot_exact, dots);
stats.NotifyRatios();
});
stats.NotifyTimes(times);
stats.NotifyRep(num, cond, dot_exact, dots);
stats.NotifyRatios();
});
DotStats& stats = all_stats[0];
for (size_t i = 1; i < kMaxWorkers; ++i) {

View File

@ -25,7 +25,7 @@
#include <cmath> // std::abs
#include <memory>
#include "compression/compress.h"
#include "util/mat.h"
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -37,6 +37,7 @@
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/matvec-inl.h"
#include "hwy/tests/test_util-inl.h"
@ -48,18 +49,18 @@ using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
FloatPtr SimpleMatVecAdd(const MatStorageT<float>& mat, const FloatPtr& vec,
const FloatPtr& add) {
FloatPtr raw_mat = hwy::AllocateAligned<float>(mat.NumElements());
const size_t num = mat.Rows() * mat.Cols();
FloatPtr raw_mat = hwy::AllocateAligned<float>(num);
FloatPtr out = hwy::AllocateAligned<float>(mat.Rows());
HWY_ASSERT(raw_mat && out);
const hn::ScalableTag<float> df;
DecompressAndZeroPad(df, MakeSpan(mat.data(), mat.NumElements()), 0,
raw_mat.get(), mat.NumElements());
DecompressAndZeroPad(df, mat.Span(), 0, raw_mat.get(), num);
for (size_t idx_row = 0; idx_row < mat.Rows(); idx_row++) {
out[idx_row] = 0.0f;
for (size_t idx_col = 0; idx_col < mat.Cols(); idx_col++) {
out[idx_row] += raw_mat[mat.Cols() * idx_row + idx_col] * vec[idx_col];
}
out[idx_row] *= mat.scale();
out[idx_row] *= mat.Scale();
out[idx_row] += add[idx_row];
}
return out;
@ -69,8 +70,10 @@ template <typename MatT, size_t kOuter, size_t kInner>
std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
auto mat = std::make_unique<MatStorageT<float>>("TestMat", kOuter, kInner);
FloatPtr raw_mat = hwy::AllocateAligned<float>(mat->NumElements());
const Extents2D extents(kOuter, kInner);
auto mat = std::make_unique<MatStorageT<float>>("TestMat", extents,
MatPadding::kPacked);
FloatPtr raw_mat = hwy::AllocateAligned<float>(extents.Area());
HWY_ASSERT(raw_mat);
const float scale = 1.0f / kInner;
pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
@ -80,8 +83,8 @@ std::unique_ptr<MatStorageT<float>> GenerateMat(size_t offset,
}
});
CompressScaled(raw_mat.get(), mat->NumElements(), ws, *mat, pool);
mat->set_scale(1.9f); // Arbitrary value, different from 1.
CompressScaled(raw_mat.get(), extents.Area(), ws, *mat, pool);
mat->SetScale(1.9f); // Arbitrary value, different from 1.
return mat;
}

View File

@ -15,6 +15,7 @@
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <vector>
@ -22,9 +23,8 @@
#include "ops/matmul.h" // IWYU pragma: export
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
#include "hwy/timer.h"
@ -866,6 +866,8 @@ class MMPerPackage {
const IndexRange& range_np)
: args_(args),
pkg_idx_(pkg_idx),
// May be overwritten with a view of A, if already BF16.
A_(args_.env->storage.A(args.env->ctx.allocator, pkg_idx, A.Extents())),
range_np_(range_np),
mr_(config.MR()),
ranges_mc_(config.RangesOfMC(A.Extents().rows)),
@ -873,14 +875,11 @@ class MMPerPackage {
ranges_nc_(config.RangesOfNC(range_np)),
order_(config.Order()),
inner_tasks_(config.InnerTasks()),
out_(config.Out()) {
// May be overwritten with a view of A, if already BF16.
A_ = args_.env->storage.A(pkg_idx, A.Extents());
{
MMZone zone;
zone.MaybeEnter("MM.DecompressA", args_);
A_ = DecompressA(A);
}
out_(config.Out()),
line_bytes_(args.env->ctx.allocator.LineBytes()) {
MMZone zone;
zone.MaybeEnter("MM.DecompressA", args_);
A_ = DecompressA(A);
}
// B is decompressed several call layers lower, but not all member functions
@ -909,14 +908,14 @@ class MMPerPackage {
// Compute size of per-worker storage for `kNR` row ranges of B. Stack
// allocation avoids passing a worker index.
static constexpr size_t B_stride_max_ =
StrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
MaxStrideForCyclicOffsets<BF16>(MMStorage::kMaxKC);
static constexpr size_t B_storage_max_ =
kNR * B_stride_max_ + Allocator::MaxQuantumBytes() / sizeof(BF16);
kNR * B_stride_max_ + Allocator2::MaxQuantum<BF16>();
// Granularity of `ForNP`. B rows produce C columns, so we
// want a multiple of the line size to prevent false sharing.
static size_t MultipleNP(size_t sizeof_TC) {
return HWY_MAX(kNR, Allocator::LineBytes() / sizeof_TC);
size_t MultipleNP(size_t sizeof_TC) const {
return HWY_MAX(kNR, line_bytes_ / sizeof_TC);
}
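  // Example (illustrative): with 64-byte cache lines and TC = float, one line
  // holds 64/4 = 16 C columns, so the granularity is HWY_MAX(kNR, 16) B rows.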
// Single M and K, parallel N. Fills all of C directly.
@ -931,14 +930,16 @@ class MMPerPackage {
const IndexRange& range_K = ranges_kc_.Range(0);
const size_t K = range_K.Num();
const RowPtrBF& A_view = A_.View(range_M.begin(), 0, K);
const size_t B_stride = StrideForCyclicOffsets<BF16>(K);
const size_t B_stride =
StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
// Similar to `loop_nc` below, but here we hoisted `A_view`.
args_.env->parallel.ForNP(
range_np_, MultipleNP(sizeof(TC)), inner_tasks_, pkg_idx_,
[&](const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_view(B_storage, K, B_stride);
const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
B_stride);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
@ -972,7 +973,9 @@ class MMPerPackage {
auto out_tag) HWY_ATTR {
const size_t kc = range_kc.Num();
const RowPtrBF& A_view = A_.View(range_mc.begin(), range_kc.begin(), kc);
const RowPtrBF B_view(B_storage, kc, StrideForCyclicOffsets<BF16>(kc));
const RowPtrBF B_view(
args_.env->ctx.allocator, B_storage, kc,
StrideForCyclicOffsets(kc, args_.env->ctx.allocator.Quantum<BF16>()));
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
@ -1027,7 +1030,8 @@ class MMPerPackage {
HWY_DASSERT(ranges_kc_.NumTasks() == 1);
const IndexRange& range_K = ranges_kc_.Range(0);
const size_t K = range_K.Num();
const size_t B_stride = StrideForCyclicOffsets<BF16>(K);
const size_t B_stride =
StrideForCyclicOffsets(K, args_.env->ctx.allocator.Quantum<BF16>());
// Sequential loop over NC/MC/KC, similar to `loop_nc` below
// except for the profiler strings and `out_tag`.
@ -1036,7 +1040,8 @@ class MMPerPackage {
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
const RowPtrBF& A_view = A_.View(range_mc.begin(), 0, K);
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_view(B_storage, K, B_stride);
const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, K,
B_stride);
for (size_t row_b = range_nc.begin(); row_b < range_nc.end();
row_b += kNR) {
@ -1062,7 +1067,8 @@ class MMPerPackage {
zone.MaybeEnter("MM.NT_MT_K", args_);
const size_t kc_max = ranges_kc_.TaskSize();
HWY_DASSERT(kc_max <= MMStorage::kMaxKC);
const size_t B_stride = StrideForCyclicOffsets<BF16>(kc_max);
const size_t B_stride = StrideForCyclicOffsets(
kc_max, args_.env->ctx.allocator.Quantum<BF16>());
// Sequential loop over NC/MC/KC, for when the M/N loops are
// already parallel. This is B3A2C0 in MOMMS terminology: we read
// `mc x kc` of A, `nc x kc` of B, update `mc x nc` of `partial`.
@ -1088,7 +1094,8 @@ class MMPerPackage {
ranges_mc_, ranges_nc_, pkg_idx_,
[&](const IndexRange& range_mc, const IndexRange& range_nc) HWY_ATTR {
HWY_ALIGN BF16 B_storage[B_storage_max_]; // TLS
const RowPtrBF B_view(B_storage, kc_max, B_stride);
const RowPtrBF B_view(args_.env->ctx.allocator, B_storage, kc_max,
B_stride);
// Peel off the first iteration of the kc loop: avoid
// zero-initializing `partial` by writing into it.
@ -1151,8 +1158,7 @@ class MMPerPackage {
// At least one vector, otherwise DecompressAndZeroPad will add
// padding, which might overwrite neighboring tasks. Also a whole cache
// line to avoid false sharing.
const size_t multiple_K =
HWY_MAX(NBF, Allocator::LineBytes() / sizeof(BF16));
const size_t multiple_K = HWY_MAX(NBF, line_bytes_ / sizeof(BF16));
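    // Example (illustrative): on a 512-bit target, NBF = 512/16 = 32 BF16
    // lanes, and a 64-byte line also holds 64/2 = 32 BF16, so multiple_K = 32.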
args_.env->parallel.ForNP(
all_K, multiple_K, inner_tasks, pkg_idx_,
@ -1170,6 +1176,7 @@ class MMPerPackage {
// Autotuning wrapper for `DoDecompressA`.
template <typename TA>
HWY_INLINE RowPtrBF DecompressA(const ConstMat<TA>& A) const {
const Allocator2& allocator = args_.env->ctx.allocator;
MMAutoTune<MMParA>& autotune = args_.per_key->autotune_par_a[pkg_idx_];
// If already BF16, maybe return a view:
if constexpr (hwy::IsSame<TA, BF16>()) {
@ -1177,7 +1184,8 @@ class MMPerPackage {
const size_t NBF = hn::Lanes(hn::ScalableTag<BF16>());
if (HWY_LIKELY(A.extents.cols % NBF == 0)) {
const BF16* pos = A.ptr + A.Row(0);
return RowPtrBF(const_cast<BF16*>(pos), A.extents.cols, A.Stride());
return RowPtrBF(allocator, const_cast<BF16*>(pos), A.extents.cols,
A.Stride());
}
}
@ -1251,6 +1259,7 @@ class MMPerPackage {
const MMOrder order_;
const size_t inner_tasks_;
const MMOut out_;
const size_t line_bytes_;
}; // MMPerPackage
// Stateless, wraps member functions.
@ -1308,6 +1317,7 @@ template <typename TA, typename TB, typename TC>
HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
const float* HWY_RESTRICT add, MatMulEnv& env,
const RowPtr<TC>& C) {
const Allocator2& allocator = env.ctx.allocator;
const size_t M = A.Extents().rows;
const size_t K = A.Extents().cols;
const size_t N = B.Extents().rows;
@ -1315,11 +1325,11 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
intptr_t index = MMImpl::IndexOfKey(key, env.keys);
// First time we see this shape/key.
if (HWY_UNLIKELY(index < 0)) {
env.keys.Append(key);
env.keys.Append(key, allocator);
size_t max_packages = MMParallel::kMaxPackages;
// For low-batch, multiple sockets only help if binding is enabled.
if (!Allocator::ShouldBind() && M <= 4) {
if (!allocator.ShouldBind() && M <= 4) {
max_packages = 1;
}
@ -1351,8 +1361,9 @@ HWY_NOINLINE MMPerKey* MatMul(const ConstMat<TA>& A, const ConstMat<TB>& B,
HWY_ASSERT(N % kNR == 0);
// Negligible CPU time.
tuner.SetCandidates(MMCandidates(M, K, N, sizeof(TC), MMKernel::kMaxMR, kNR,
per_key.ranges_np, env.print_config));
tuner.SetCandidates(MMCandidates(allocator, M, K, N, sizeof(TC),
MMKernel::kMaxMR, kNR, per_key.ranges_np,
env.print_config));
}
const MMConfig& cfg = tuner.NextConfig();

View File

@ -60,10 +60,11 @@ size_t PrevDivisor(const size_t begin, const size_t end, const size_t dim,
// and holds most of their arguments in member variables.
class GenerateCandidates {
public:
GenerateCandidates(size_t M, size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
GenerateCandidates(const Allocator2& allocator, size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np, bool print_config)
: M_(M),
: allocator_(allocator),
M_(M),
K_(K),
N_(N),
sizeof_TC_(sizeof_TC),
@ -73,8 +74,8 @@ class GenerateCandidates {
// `RangesOf*`. Must be a vector multiple. The previous/next cache line
// is likely still in L1, but we expect K > 1000 and might as well round
// up to the line size.
kc_multiple_(HWY_MIN(K, Allocator::LineBytes() / sizeof(BF16))),
nc_multiple_(Allocator::StepBytes() / sizeof_TC),
kc_multiple_(HWY_MIN(K, allocator.LineBytes() / sizeof(BF16))),
nc_multiple_(allocator.StepBytes() / sizeof_TC),
ranges_np_(ranges_np),
print_config_(print_config) {}
@ -172,7 +173,7 @@ class GenerateCandidates {
// subtract the output and buf, and allow using more than the actual L1
// size. This results in an overestimate, and the loop below will propose
// the next few smaller values for the autotuner to evaluate.
const size_t bytes_ab = Allocator::L1Bytes() * 3;
const size_t bytes_ab = allocator_.L1Bytes() * 3;
const size_t col_bytes = rows_a * sizeof(BF16) + nr_ * sizeof(BF16);
size_t kc_max = hwy::DivCeil(bytes_ab, col_bytes);
kc_max =
@ -220,8 +221,8 @@ class GenerateCandidates {
// packed B. We want `mc * kc` elements of A to fit in L2, alongside
// `bytes_b` plus `mc` cache lines because resident-A updates `mc` rows of
// partial.
const size_t bytes_per_mc = kc * sizeof(BF16) + Allocator::LineBytes();
size_t mc_max = hwy::DivCeil(Allocator::L2Bytes() - bytes_b, bytes_per_mc);
const size_t bytes_per_mc = kc * sizeof(BF16) + allocator_.LineBytes();
size_t mc_max = hwy::DivCeil(allocator_.L2Bytes() - bytes_b, bytes_per_mc);
mc_max = HWY_MIN(mc_max, MMStorage::kMaxM);
HWY_DASSERT(mc_max != 0);
mc_max = HWY_MIN(mc_max, M_);
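    // Worked example (illustrative, assumed sizes): with L2 = 1 MiB,
    // bytes_b = 256 KiB, kc = 2048 and 64-byte lines, bytes_per_mc =
    // 2048*2 + 64 = 4160, so mc_max = DivCeil(786432, 4160) = 190 rows of A
    // before clamping to kMaxM and M_.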
@ -264,7 +265,7 @@ class GenerateCandidates {
// Otherwise, leave it unbounded.
if (M_ > mr) {
const size_t bytes_per_nc = (kc * sizeof(BF16) + mc * out_bytes);
nc_max = hwy::DivCeil(Allocator::L3Bytes(), bytes_per_nc);
nc_max = hwy::DivCeil(allocator_.L3Bytes(), bytes_per_nc);
nc_max = HWY_MIN(HWY_MIN(nc_max, MMStorage::kMaxN), np_max);
}
HWY_DASSERT(nc_max != 0);
@ -351,6 +352,7 @@ class GenerateCandidates {
}
}
const Allocator2& allocator_;
const size_t M_;
const size_t K_;
const size_t N_;
@ -370,25 +372,26 @@ class GenerateCandidates {
} // namespace
// Facade to avoid exposing `GenerateCandidates` in the header.
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config) {
return GenerateCandidates(M, K, N, sizeof_TC, max_mr, nr, ranges_np,
print_config)();
return GenerateCandidates(allocator, M, K, N, sizeof_TC, max_mr, nr,
ranges_np, print_config)();
}
// Returns the granularity of B rows for `RangesOfNP`. Aims to avoid remote
// memory accesses or false sharing, unless there are insufficient per-package
// rows for that.
static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,
size_t num_packages) {
size_t np_multiple = Allocator::QuantumBytes() / sizeof_TC;
static size_t NPMultiple(const Allocator2& allocator, size_t N,
size_t sizeof_TC, size_t nr, size_t num_packages) {
size_t np_multiple = allocator.QuantumBytes() / sizeof_TC;
// If binding, `np_multiple` is typically 1024 and `num_packages` > 1. For
// `N` < 4096, this can cause significant load imbalance. If split unevenly,
// choose a smaller multiple.
if (N % (np_multiple * num_packages)) {
const size_t min_multiple = Allocator::LineBytes() / sizeof_TC;
const size_t min_multiple = allocator.LineBytes() / sizeof_TC;
np_multiple =
PrevDivisor(min_multiple, np_multiple, N / num_packages, min_multiple);
if (HWY_UNLIKELY(np_multiple == 0)) {
@ -408,16 +411,14 @@ static size_t NPMultiple(size_t N, size_t sizeof_TC, size_t nr,
IndexRangePartition MMParallel::RangesOfNP(size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr) const {
const size_t num_packages = HWY_MIN(max_packages, pools_.NumPackages());
return StaticPartition(IndexRange(0, N), num_packages,
NPMultiple(N, sizeof_TC, nr, num_packages));
const size_t num_packages = HWY_MIN(max_packages, ctx_.pools.NumPackages());
return StaticPartition(
IndexRange(0, N), num_packages,
NPMultiple(ctx_.allocator, N, sizeof_TC, nr, num_packages));
}
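// Worked example for NPMultiple (illustrative, assuming a 4 KiB quantum and
// TC = float): the initial np_multiple is 4096/4 = 1024 columns. With N = 3072
// and two packages, 3072 is not a multiple of 1024*2, so a smaller multiple
// (at least LineBytes()/sizeof_TC = 16 for 64-byte lines) is chosen via
// PrevDivisor to reduce load imbalance.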
MatMulEnv::MatMulEnv(const BoundedTopology& topology, NestedPools& pools)
: parallel(topology, pools), storage(parallel) {
// Ensure Allocator:Init was called.
HWY_ASSERT(Allocator::LineBytes() != 0 && Allocator::VectorBytes() != 0);
MatMulEnv::MatMulEnv(ThreadingContext2& ctx)
: ctx(ctx), parallel(ctx), storage(ctx.allocator, parallel) {
char cpu100[100];
have_timer_stop = hwy::platform::HaveTimerStop(cpu100);
}

View File

@ -24,11 +24,9 @@
#include <vector>
// IWYU pragma: begin_exports
#include "compression/compress.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "util/topology.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/base.h"
#include "hwy/bit_set.h"
@ -51,33 +49,30 @@ class MMParallel {
public:
static constexpr size_t kMaxPackages = 4;
// Both references must outlive this object.
MMParallel(const BoundedTopology& topology, NestedPools& pools)
: topology_(topology), pools_(pools) {
HWY_DASSERT(pools_.NumPackages() <= kMaxPackages);
// `ctx` must outlive this object.
MMParallel(ThreadingContext2& ctx) : ctx_(ctx) {
HWY_DASSERT(ctx_.pools.NumPackages() <= kMaxPackages);
}
// Used by tests.
NestedPools& Pools() { return pools_; }
// Initial static partitioning of B rows across packages.
IndexRangePartition RangesOfNP(size_t max_packages, size_t N,
size_t sizeof_TC, size_t nr) const;
// For `BindB` and `BindC`.
size_t Node(size_t pkg_idx) const {
return topology_.GetCluster(pkg_idx, 0).Node();
return ctx_.topology.GetCluster(pkg_idx, 0).Node();
}
// Calls `func(pkg_idx)` for each package in parallel.
template <class Func>
void ForPkg(const size_t max_packages, const Func& func) {
pools_.AllPackages().Run(0, HWY_MIN(max_packages, pools_.NumPackages()),
[&](uint64_t task, size_t pkg_idx) {
HWY_DASSERT(task == pkg_idx);
(void)task;
func(pkg_idx);
});
ctx_.pools.AllPackages().Run(
0, HWY_MIN(max_packages, ctx_.pools.NumPackages()),
[&](uint64_t task, size_t pkg_idx) {
HWY_DASSERT(task == pkg_idx);
(void)task;
func(pkg_idx);
});
}
// Cluster/CCX-aware parallel-for over B rows in `range_np`. `nx_multiple` is
@ -87,10 +82,10 @@ class MMParallel {
size_t pkg_idx, const Func& func) {
HWY_DASSERT(1 <= inner_tasks && inner_tasks <= 4);
// Single cluster: parallel-for over static partition of `range_np`.
hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
const size_t num_clusters = all_clusters.NumWorkers();
if (num_clusters == 1) {
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, 0);
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, 0);
const IndexRangePartition worker_ranges = StaticPartition(
range_np, cluster.NumWorkers() * inner_tasks, nx_multiple);
return ParallelizeOneRange(
@ -106,7 +101,7 @@ class MMParallel {
ParallelizeOneRange(
nx_ranges, all_clusters,
[&](const IndexRange& nx_range, const size_t cluster_idx) {
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
// Parallel-for over sub-ranges of `cluster_range` within the cluster.
const IndexRangePartition worker_ranges = StaticPartition(
nx_range, cluster.NumWorkers() * inner_tasks, nx_multiple);
@ -122,14 +117,14 @@ class MMParallel {
void ForRangesMC_NC(const IndexRangePartition& ranges_mc,
const IndexRangePartition& ranges_nc, size_t pkg_idx,
const Func& func) {
hwy::ThreadPool& all_clusters = pools_.AllClusters(pkg_idx);
hwy::ThreadPool& all_clusters = ctx_.pools.AllClusters(pkg_idx);
// `all_clusters` is a pool with one worker per cluster in a package.
const size_t num_clusters = all_clusters.NumWorkers();
// Single (big) cluster: collapse two range indices into one parallel-for
// to reduce the number of fork-joins.
if (num_clusters == 1) {
const size_t cluster_idx = 0;
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
// Low-batch: avoid Divide/Remainder.
if (HWY_UNLIKELY(ranges_mc.NumTasks() == 1)) {
return ParallelizeOneRange(
@ -150,7 +145,7 @@ class MMParallel {
ParallelizeOneRange(
ranges_nc, all_clusters,
[&](const IndexRange range_nc, size_t cluster_idx) {
hwy::ThreadPool& cluster = pools_.Cluster(pkg_idx, cluster_idx);
hwy::ThreadPool& cluster = ctx_.pools.Cluster(pkg_idx, cluster_idx);
ParallelizeOneRange(
ranges_mc, cluster,
[&](const IndexRange& range_mc, size_t /*thread*/) {
@ -163,33 +158,33 @@ class MMParallel {
template <class Func>
void ForRangeMC(const IndexRange& range_mc, size_t pkg_idx,
const Func& func) {
pools_.Pool(pkg_idx).Run(
ctx_.pools.Pool(pkg_idx).Run(
range_mc.begin(), range_mc.end(),
[&](uint64_t row_a, size_t /*thread*/) { func(row_a); });
}
private:
const BoundedTopology& topology_;
NestedPools& pools_;
ThreadingContext2& ctx_;
};
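// A minimal usage sketch (illustrative only, not part of this change):
// construct from an existing context, then fan out one task per package.
static inline void ForEachPackageSketch(ThreadingContext2& ctx) {
  MMParallel parallel(ctx);  // `ctx` must outlive `parallel`.
  parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
    (void)pkg_idx;  // per-package work, e.g. binding or decompressing A
  });
}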
template <typename TC> // BF16/float for C, double for partial
void BindC(size_t M, const RowPtr<TC>& C, MMParallel& parallel) {
if (!Allocator::ShouldBind()) return;
void BindC(const Allocator2& allocator, size_t M, const RowPtr<TC>& C,
MMParallel& parallel) {
if (!allocator.ShouldBind()) return;
const IndexRangePartition ranges_np =
parallel.RangesOfNP(MMParallel::kMaxPackages, C.Cols(), sizeof(TC), kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(TC);
const size_t quantum = allocator.Quantum<TC>();
bool ok = true;
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& cols_c = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx);
for (size_t im = 0; im < M; ++im) {
      // The package's column range may not be page-aligned, but
      // `BindMemory` requires page alignment.
const size_t begin = hwy::RoundUpTo(cols_c.begin(), quantum);
const size_t end = hwy::RoundDownTo(cols_c.end(), quantum);
ok &= Allocator::BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
node);
ok &= allocator.BindMemory(C.Row(im) + begin, (end - begin) * sizeof(TC),
node);
}
}
if (HWY_UNLIKELY(!ok)) {
@ -212,38 +207,42 @@ class MMStorage {
// of BF16 A and B fit in 32 KiB L1, but there may be `kMaxMR` and `kNR`.
static constexpr size_t kMaxKC = 8 * 1024;
explicit MMStorage(MMParallel& parallel) {
MMStorage(const Allocator2& allocator, MMParallel& parallel)
// Per-worker copies of `partial` would be wasteful. We instead allocate
// one instance of the maximum matrix extents because threads write at
// false-sharing-free granularity.
: partial_storage_(
AllocateAlignedRows<double>(allocator, Extents2D(kMaxM, kMaxN))),
// Same stride independent of the actual C.Cols() so we can pre-bind.
partial_(allocator, partial_storage_.All(), kMaxN,
StrideForCyclicOffsets(kMaxN, allocator.Quantum<double>())) {
// Per-package allocation so each can decompress A into its own copy.
parallel.ForPkg(MMParallel::kMaxPackages, [&](size_t pkg_idx) {
pkg_A_[pkg_idx] = AllocateAlignedRows<BF16>(Extents2D(kMaxM, kMaxK));
pkg_A_[pkg_idx] =
AllocateAlignedRows<BF16>(allocator, Extents2D(kMaxM, kMaxK));
if (Allocator::ShouldBind()) {
if (allocator.ShouldBind()) {
const size_t node = parallel.Node(pkg_idx);
if (!Allocator::BindMemory(pkg_A_[pkg_idx].All(),
pkg_A_[pkg_idx].NumBytes(), node)) {
if (!allocator.BindMemory(pkg_A_[pkg_idx].All(),
pkg_A_[pkg_idx].NumBytes(), node)) {
HWY_WARN("Failed to bind memory for package %zu", pkg_idx);
}
}
});
// Per-worker copies of `partial` would be wasteful. We instead allocate
// one instance of the maximum matrix extents because threads write at
// false-sharing-free granularity.
partial_storage_ = AllocateAlignedRows<double>(Extents2D(kMaxM, kMaxN));
// Same stride independent of the actual C.Cols() so we can pre-bind.
partial_ = RowPtrD(partial_storage_.All(), kMaxN,
StrideForCyclicOffsets<double>(kMaxN));
// Avoid cross-package accesses.
BindC(kMaxM, partial_, parallel);
BindC(allocator, kMaxM, partial_, parallel);
}
// Returns per-package matrix view. Non-const so that `RowVectorBatch` is
// non-const, because `RowPtr` requires a non-const pointer.
RowPtrBF A(size_t pkg_idx, const Extents2D& extents) {
RowPtrBF A(const Allocator2& allocator, size_t pkg_idx,
const Extents2D& extents) {
HWY_DASSERT(extents.rows <= kMaxM);
HWY_DASSERT(extents.cols <= kMaxK);
const size_t stride = StrideForCyclicOffsets<BF16>(extents.cols);
return RowPtrBF(pkg_A_[pkg_idx].All(), extents.cols, stride);
const size_t stride =
StrideForCyclicOffsets(extents.cols, allocator.Quantum<BF16>());
return RowPtrBF(allocator, pkg_A_[pkg_idx].All(), extents.cols, stride);
}
RowPtrD Partial() const { return partial_; }
@ -431,13 +430,15 @@ class MMConfig {
static_assert(sizeof(MMConfig) == 32); // for faster indexing
#pragma pack(pop)
std::vector<MMConfig> MMCandidates(size_t M, size_t K, size_t N,
size_t sizeof_TC, size_t max_mr, size_t nr,
std::vector<MMConfig> MMCandidates(const Allocator2& allocator, size_t M,
size_t K, size_t N, size_t sizeof_TC,
size_t max_mr, size_t nr,
const IndexRangePartition& ranges_np,
bool print_config);
// State machine for choosing the best `TConfig`, which is `MMConfig` for the
// main MatMul autotuner.
// TODO: replace with hwy/auto_tune.h.
template <typename TConfig>
class MMAutoTune {
public:
@ -560,11 +561,11 @@ class MMKeys {
}
// Must only be called if not already present in `Keys()`.
void Append(Key key) {
void Append(Key key, const Allocator2& allocator) {
// Dynamic allocation because the test checks many more dimensions than
// would be reasonable to pre-allocate. DIY for alignment and padding.
if (HWY_UNLIKELY(num_unique_ >= capacity_)) {
const size_t NU64 = Allocator::VectorBytes() / sizeof(Key);
const size_t NU64 = allocator.VectorBytes() / sizeof(Key);
// Start at one vector so the size is always a multiple of N.
if (HWY_UNLIKELY(capacity_ == 0)) {
capacity_ = hwy::DivCeil(NU64, 2); // will be doubled below
@ -604,10 +605,12 @@ struct MMPerKey {
MMAutoTune<MMParA> autotune_par_a[MMParallel::kMaxPackages];
};
// Stores state shared across MatMul calls. Non-copyable.
// Stores state shared across MatMul calls. Non-copyable. `ctx` must outlive
// `MatMulEnv`.
struct MatMulEnv {
explicit MatMulEnv(const BoundedTopology& topology, NestedPools& pools);
explicit MatMulEnv(ThreadingContext2& ctx);
ThreadingContext2& ctx;
bool have_timer_stop = false;
// Enable binding: disabled in Gemma until tensors support it, enabled in
@ -684,8 +687,9 @@ struct MMZone {
// `ofs` required for compressed T.
template <typename T>
struct ConstMat {
ConstMat(const T* ptr, Extents2D extents, size_t stride, size_t ofs = 0)
: ptr(ptr), extents(extents), stride(stride), ofs(ofs) {
ConstMat() = default;
ConstMat(const T* ptr, Extents2D extents, size_t stride)
: ptr(ptr), extents(extents), stride(stride), ofs(0) {
HWY_DASSERT(ptr != nullptr);
HWY_DASSERT(stride >= extents.cols);
}
@ -717,15 +721,17 @@ struct ConstMat {
float scale = 1.0f;
// Offset to add to `ptr`; separate because T=NuqStream does not support
// pointer arithmetic.
// pointer arithmetic. This is in units of weights, and does not have anything
// to do with the interleaved NUQ tables. It should be computed via `Row()`
// to take into account the stride.
size_t ofs;
};
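// A minimal sketch of how `ofs` and `stride` combine; the real accessor is a
// `ConstMat` member whose exact shape is not shown in this hunk.
template <typename T>
size_t ConstMatRowOffsetSketch(const ConstMat<T>& m, size_t r) {
  // Element index of the first element of row r: add it to `ptr`, or pass it
  // to decompression for types without pointer arithmetic such as NuqStream.
  return m.ofs + m.stride * r;
}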
// For deducing T.
template <typename T>
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents, size_t stride,
size_t ofs = 0) {
return ConstMat<T>(ptr, extents, stride, ofs);
ConstMat<T> MakeConstMat(T* HWY_RESTRICT ptr, Extents2D extents,
size_t stride) {
return ConstMat<T>(ptr, extents, stride);
}
// For A argument to MatMul (activations).
@ -739,21 +745,21 @@ ConstMat<T> ConstMatFromBatch(size_t batch_size,
}
template <typename T>
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m, size_t ofs = 0) {
ConstMat<T> ConstMatFromWeights(const MatPtrT<T>& m) {
ConstMat<T> mat =
MakeConstMat(const_cast<T*>(m.data()), m.Extents(), m.Stride(), ofs);
mat.scale = m.scale();
MakeConstMat(const_cast<T*>(m.Row(0)), m.Extents(), m.Stride());
mat.scale = m.Scale();
return mat;
}
template <typename TB>
void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
MMParallel& parallel) {
if (!Allocator::ShouldBind()) return;
void BindB(const Allocator2& allocator, size_t N, size_t sizeof_TC,
const ConstMat<TB>& B, MMParallel& parallel) {
if (!allocator.ShouldBind()) return;
const IndexRangePartition ranges_np =
parallel.RangesOfNP(MMParallel::kMaxPackages, N, sizeof_TC, kNR);
const size_t quantum = Allocator::QuantumBytes() / sizeof(TB);
const size_t quantum = allocator.Quantum<TB>();
for (size_t pkg_idx = 0; pkg_idx < ranges_np.NumTasks(); ++pkg_idx) {
const IndexRange& rows_b = ranges_np.Range(pkg_idx);
const size_t node = parallel.Node(pkg_idx);
@ -765,7 +771,7 @@ void BindB(size_t N, size_t sizeof_TC, const ConstMat<TB>& B,
begin = hwy::RoundUpTo(begin, quantum);
end = hwy::RoundDownTo(end, quantum);
if (HWY_LIKELY(begin != end)) {
Allocator::BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
allocator.BindMemory(reinterpret_cast<void*>(begin), end - begin, node);
}
}
}
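// Worked example of the rounding above (assumes a 4 KiB binding quantum;
// illustrative only): with TB = float, `quantum` is 1024 elements, so a
// package range covering elements [1000, 5000) shrinks to [1024, 4096) before
// `BindMemory`; the unaligned fringe stays on whichever node touched it first.
static_assert(4096 / sizeof(float) == 1024, "quantum in elements");
static_assert((1000 + 1023) / 1024 * 1024 == 1024, "begin rounded up");
static_assert(5000 / 1024 * 1024 == 4096, "end rounded down");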

View File

@ -15,29 +15,25 @@
// End to end test of MatMul, comparing against a reference implementation.
#include "hwy/detect_compiler_arch.h"
#include "hwy/detect_compiler_arch.h" // IWYU pragma: keep
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
// double-precision support.
// double-precision support, and older x86 to speed up builds.
#if HWY_ARCH_ARM_V7
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
#else
#define HWY_DISABLED_TARGETS HWY_SCALAR
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_SSSE3 | HWY_SSE4)
#endif
#endif
#include <stddef.h>
#include <stdio.h>
#include <memory>
#include "compression/compress.h"
#include "compression/shared.h"
#include "ops/matmul.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// clang-format off
@ -48,9 +44,9 @@
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "compression/test_util-inl.h"
#include "ops/dot-inl.h"
#include "ops/matmul-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@ -60,57 +56,6 @@ extern int64_t first_target;
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
template <typename MatT>
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;
// Generates inputs: deterministic, within max SfpStream range.
template <typename MatT>
MatStoragePtr<MatT> GenerateMat(const Extents2D& extents,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
auto mat =
std::make_unique<MatStorageT<MatT>>("mat", extents.rows, extents.cols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
HWY_ASSERT(content);
const float scale = SfpStream::kMax / (mat->NumElements());
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(r * extents.cols + c) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
content[r * extents.cols + c] = f;
}
});
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
mat->set_scale(0.6f); // Arbitrary value, different from 1.
return mat;
}
// extents describes the transposed matrix.
template <typename MatT>
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D extents,
hwy::ThreadPool& pool) {
gcpp::CompressWorkingSet ws;
auto mat =
std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
const float scale = SfpStream::kMax / (mat->NumElements());
pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
for (size_t c = 0; c < extents.cols; c++) {
float f = static_cast<float>(c * extents.rows + r) * scale;
if ((r + c) & 1) f = -f; // Also generate some negative values.
content[r * extents.cols + c] = f;
}
});
CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
// Arbitrary value, different from 1, must match GenerateMat.
mat->set_scale(0.6f);
return mat;
}
// Returns 1-norm, used for estimating tolerable numerical differences.
double MaxRowAbsSum(const RowVectorBatch<float>& a) {
double max_row_abs_sum = 0.0;
@ -141,16 +86,19 @@ float MaxAbs(const RowVectorBatch<float>& a) {
template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
const Allocator2& allocator = ThreadingContext2::Get().allocator;
const hn::ScalableTag<float> df;
const size_t cols = A.extents.cols;
const size_t B_rows = B.extents.rows;
// Round up for DecompressAndZeroPad.
RowVectorBatch<float> a_batch = AllocateAlignedRows<float>(A.extents);
RowVectorBatch<float> b_trans_batch = AllocateAlignedRows<float>(B.extents);
RowVectorBatch<float> a_batch =
AllocateAlignedRows<float>(allocator, A.extents);
RowVectorBatch<float> b_trans_batch =
AllocateAlignedRows<float>(allocator, B.extents);
RowVectorBatch<float> c_batch =
AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
RowVectorBatch<float> c_slow_batch =
AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
AllocateAlignedRows<float>(allocator, Extents2D(A.extents.rows, B_rows));
HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
for (size_t m = 0; m < A.extents.rows; ++m) {
DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
@ -224,7 +172,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
const IndexRange all_rows_c(0, A.Extents().rows);
const IndexRange all_cols_c(0, C.Cols());
NestedPools& pools = env.parallel.Pools();
NestedPools& pools = env.ctx.pools;
hwy::ThreadPool& all_packages = pools.AllPackages();
const IndexRangePartition get_row_c =
StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);
@ -232,7 +180,7 @@ HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
get_row_c, all_packages,
[&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
const size_t multiple = Allocator::QuantumBytes() / sizeof(TB);
const size_t multiple = env.ctx.allocator.QuantumBytes() / sizeof(TB);
const IndexRangePartition get_col_c =
StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
ParallelizeOneRange(
@ -262,7 +210,8 @@ void PrintSpeed(const char* algo, const Extents2D& A_extents,
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env, int line) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool();
const Allocator2& allocator = env.ctx.allocator;
hwy::ThreadPool& pool = env.ctx.pools.Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
TypeName<TC>());
@ -274,24 +223,22 @@ void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc);
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
HWY_ASSERT(a && b_trans);
MatStorageT<TA> a(GenerateMat<TA>(A_extents, pool));
MatStorageT<TB> b_trans(GenerateTransposedMat<TB>(B_extents, pool));
RowVectorBatch<TC> c_slow_batch =
AllocateAlignedRows<TC>(allocator, C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(allocator, C_extents);
std::unique_ptr<MatStorageT<float>> add_storage;
if (add) {
add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
HWY_ASSERT(add_storage);
add_storage->set_scale(1.0f);
}
MatStorageT<float> add_storage =
add ? GenerateMat<float>(Extents2D(1, cols_bc), pool)
: MatStorageT<float>("add", Extents2D(), MatPadding::kPacked);
add_storage.SetScale(1.0f);
const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans);
const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtr<TC> C_slow = RowPtrFromBatch(c_slow_batch);
const RowPtr<TC> C = RowPtrFromBatch(c_batch);
const auto A = ConstMatFromWeights(a);
const auto B = ConstMatFromWeights(b_trans);
const float* add_row = add ? add_storage.PackedScale1() : nullptr;
const RowPtr<TC> C_slow = RowPtrFromBatch(allocator, c_slow_batch);
const RowPtr<TC> C = RowPtrFromBatch(allocator, c_batch);
MatMulSlow(A, B, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths.
@ -312,22 +259,24 @@ void TestTiny() {
if (HWY_TARGET != first_target) return;
for (size_t max_packages : {1, 2}) {
const BoundedTopology topology(BoundedSlice(0, max_packages));
Allocator::Init(topology, /*enable_bind=*/true);
const size_t max_threads = 0; // no limit
NestedPools pools(topology, max_threads, Tristate::kDefault);
ThreadingContext2::ThreadHostileInvalidate();
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
threading_args.max_packages = max_packages;
ThreadingContext2::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get());
NestedPools& pools = env.ctx.pools;
#if GEMMA_DISABLE_TOPOLOGY
if (max_packages == 2) break; // we only have one package
#else
// If less than the limit, we have already tested all num_packages.
if (topology.FullTopology().packages.size() < max_packages) break;
if (env.ctx.topology.FullTopology().packages.size() < max_packages) break;
#endif
fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
topology.TopologyString(), pools.PinString());
env.ctx.topology.TopologyString(), pools.PinString());
Tristate use_spinning = Tristate::kDefault;
pools.MaybeStartSpinning(use_spinning);
MatMulEnv env(topology, pools);
pools.MaybeStartSpinning(threading_args.spin);
for (size_t M = 1; M <= 12; ++M) {
for (size_t K = 1; K <= 64; K *= 2) {
@ -336,7 +285,7 @@ void TestTiny() {
}
}
}
pools.MaybeStopSpinning(use_spinning);
pools.MaybeStopSpinning(threading_args.spin);
}
}
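// A sketch of the equivalent setup outside of tests, assuming the process has
// not yet created a context (so no ThreadHostileInvalidate is required):
void DefaultEnvSketch() {
  ThreadingArgs threading_args;                // replaces AppArgs
  ThreadingContext2::SetArgs(threading_args);  // assumed to apply at first Get()
  MatMulEnv env(ThreadingContext2::Get());
  (void)env;
}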
@ -347,12 +296,13 @@ void TestAllMatMul() {
return;
}
const BoundedTopology topology;
Allocator::Init(topology, /*enable_bind=*/true);
NestedPools pools(topology);
Tristate use_spinning = Tristate::kDefault;
pools.MaybeStartSpinning(use_spinning);
MatMulEnv env(topology, pools);
ThreadingContext2::ThreadHostileInvalidate();
ThreadingArgs threading_args;
threading_args.bind = Tristate::kTrue;
ThreadingContext2::SetArgs(threading_args);
MatMulEnv env(ThreadingContext2::Get());
NestedPools& pools = env.ctx.pools;
pools.MaybeStartSpinning(threading_args.spin);
// Sizes seen in gemma_test 2B. Too slow for CI, enable on-demand.
TestMatMul<F32>(1, 2048, 512, /*add=*/false, env, __LINE__);
@ -417,6 +367,8 @@ void TestAllMatMul() {
TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env, __LINE__);
TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env, __LINE__);
TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env, __LINE__);
pools.MaybeStopSpinning(threading_args.spin);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -50,8 +50,7 @@ template <class ArrayT, typename VT>
HWY_INLINE float Dot(const ArrayT& w, size_t w_ofs, const VT* vec_aligned,
size_t num) {
const hn::ScalableTag<VT> d;
return w.scale() * Dot(d, MakeConstSpan(w.data(), w.NumElements()), w_ofs,
vec_aligned, num);
return w.Scale() * Dot(d, w.Span(), w_ofs, vec_aligned, num);
}
// Simple version without tiling nor threading, but two offsets/outputs and

View File

@ -27,12 +27,13 @@
#include <type_traits> // std::enable_if_t
#include <vector>
#include "compression/compress.h"
#include "util/allocator.h"
#include "util/basics.h" // TokenAndProb
#include "util/mat.h"
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/contrib/sort/order.h"
#include "hwy/contrib/sort/vqsort.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/detect_targets.h"
#include "hwy/profiler.h"
#endif // THIRD_PARTY_GEMMA_CPP_OPS_OPS_INL_H_
@ -807,12 +808,13 @@ HWY_NOINLINE HWY_MAYBE_UNUSED TokenAndProb FusedSoftmaxAndSampleTopK(
// Each output row is the average of a 4x4 block of input rows
template <typename T>
RowVectorBatch<T> AvgPool4x4(RowVectorBatch<T>& input) {
Extents2D extents = input.Extents();
const Allocator2& allocator = ThreadingContext2::Get().allocator;
const Extents2D extents = input.Extents();
// Input validation
HWY_DASSERT(extents.rows == 4096); // 64 * 64 = 4096 input rows
// Create output with 256 rows and same number of columns
const size_t out_rows = 256; // 16 * 16 = 256 output rows
RowVectorBatch<T> result(Extents2D{out_rows, extents.cols});
RowVectorBatch<T> result(allocator, Extents2D(out_rows, extents.cols));
const size_t input_dim = 64; // Input is 64×64
const size_t output_dim = 16; // Output is 16×16
for (size_t out_row_idx = 0; out_row_idx < output_dim; ++out_row_idx) {

View File

@ -21,14 +21,16 @@
#include <cmath>
#include "util/allocator.h"
#include "util/mat.h"
#include "hwy/base.h"
namespace gcpp {
static inline HWY_MAYBE_UNUSED RowVectorBatch<float> CreateInvTimescale(
size_t qkv_dim, bool half_rope, double base_frequency = 10000.0) {
const Allocator2& allocator, size_t qkv_dim, bool half_rope,
double base_frequency = 10000.0) {
const size_t rope_dim = half_rope ? qkv_dim / 2 : qkv_dim;
RowVectorBatch<float> inv_timescale(Extents2D(1, rope_dim / 2));
RowVectorBatch<float> inv_timescale(allocator, Extents2D(1, rope_dim / 2));
for (size_t dim = 0; dim < rope_dim / 2; ++dim) {
const double freq_exponents =
static_cast<double>(2 * dim) / static_cast<double>(rope_dim);

View File

@ -31,14 +31,12 @@
#include <random>
#include <vector>
#include "compression/compress.h" // BF16
#include "gemma/common.h"
#include "gemma/configs.h"
#include "util/allocator.h"
#include "util/app.h"
#include "util/basics.h" // BF16
#include "util/mat.h" // RowVectorBatch
#include "util/test_util.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "util/threading_context.h"
#include "hwy/tests/hwy_gtest.h"
// clang-format off
@ -388,13 +386,11 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void ScalarRopeAndMulBy(
}
void TestRopeAndMulBy() {
AppArgs app;
BoundedTopology topology = CreateTopology(app);
NestedPools pools = CreatePools(topology, app);
const Allocator2& allocator = ThreadingContext2::Get().allocator;
ModelConfig config = ConfigFromModel(Model::GEMMA2_9B);
int dim_qkv = config.layer_configs[0].qkv_dim;
RowVectorBatch<float> x(Extents2D(1, dim_qkv));
RowVectorBatch<float> x(allocator, Extents2D(1, dim_qkv));
std::mt19937 gen;
gen.seed(0x12345678);
@ -412,8 +408,8 @@ void TestRopeAndMulBy() {
std::vector<float> qactual(dim_qkv);
std::vector<float> kexpected(dim_qkv);
std::vector<float> kactual(dim_qkv);
RowVectorBatch<float> inv_timescale = gcpp::CreateInvTimescale(
config.layer_configs[0].qkv_dim,
RowVectorBatch<float> inv_timescale = CreateInvTimescale(
allocator, config.layer_configs[0].qkv_dim,
config.layer_configs[0].post_qk == PostQKType::HalfRope);
// Assert VectorizedRope computation is same as regular rope at different pos.
for (int pos = 1; pos < 500; pos++) {

View File

@ -49,8 +49,8 @@ class PaliGemmaTest : public ::testing::Test {
};
void PaliGemmaTest::InitVit(const std::string& path) {
ASSERT_NE(s_env->GetModel(), nullptr);
Gemma& model = *(s_env->GetModel());
ASSERT_NE(s_env->GetGemma(), nullptr);
Gemma& model = *(s_env->GetGemma());
image_tokens_ =
ImageTokens(Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
@ -64,7 +64,7 @@ void PaliGemmaTest::InitVit(const std::string& path) {
}
std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const {
Gemma& model = *(s_env->GetModel());
Gemma& model = *(s_env->GetGemma());
s_env->MutableGen().seed(0x12345678);
RuntimeConfig runtime_config = {.max_generated_tokens = 512,
.gen = &s_env->MutableGen(),
@ -92,7 +92,7 @@ std::string PaliGemmaTest::GemmaReply(const std::string& prompt_text) const{
}
void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
ASSERT_NE(s_env->GetModel(), nullptr);
ASSERT_NE(s_env->GetGemma(), nullptr);
std::string path = "paligemma/testdata/image.ppm";
InitVit(path);
for (size_t i = 0; i < num_questions; ++i) {
@ -104,7 +104,7 @@ void PaliGemmaTest::TestQuestions(const char* kQA[][2], size_t num_questions) {
}
TEST_F(PaliGemmaTest, General) {
ASSERT_NE(s_env->GetModel(), nullptr);
ASSERT_NE(s_env->GetGemma(), nullptr);
static const char* kQA_3B_mix_224[][2] = {
{"describe this image",
"A large building with two towers stands tall on the water's edge."},
@ -124,7 +124,7 @@ TEST_F(PaliGemmaTest, General) {
};
const char* (*qa)[2];
size_t num;
switch (s_env->GetModel()->Info().model) {
switch (s_env->GetGemma()->Info().model) {
case Model::PALIGEMMA_224:
qa = kQA_3B_mix_224;
num = sizeof(kQA_3B_mix_224) / sizeof(kQA_3B_mix_224[0]);
@ -135,7 +135,7 @@ TEST_F(PaliGemmaTest, General) {
break;
default:
FAIL() << "Unsupported model: "
<< s_env->GetModel()->GetModelConfig().model_name;
<< s_env->GetGemma()->GetModelConfig().model_name;
break;
}
TestQuestions(qa, num);

View File

@ -21,12 +21,12 @@ pybind_extension(
name = "gemma",
srcs = ["gemma_py.cc"],
deps = [
"//:app",
"//:allocator",
"//:benchmark_helper",
"//:gemma_args",
"//:gemma_lib",
"//compression:sfp",
"@highway//:hwy",
"@highway//:thread_pool",
],
)

View File

@ -32,9 +32,9 @@
#include "compression/shared.h"
#include "evals/benchmark_helper.h"
#include "gemma/gemma.h"
#include "util/app.h"
#include "gemma/gemma_args.h"
#include "util/allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
namespace py = pybind11;
@ -48,8 +48,9 @@ static void RemoveTrailingZeros(std::vector<int> &vec) {
class GemmaModel {
public:
GemmaModel(const gcpp::LoaderArgs& loader,
const gcpp::InferenceArgs& inference, const gcpp::AppArgs& app)
: gemma_(loader, inference, app), last_prob_(0.0f) {}
const gcpp::InferenceArgs& inference,
const gcpp::ThreadingArgs& threading)
: gemma_(threading, loader, inference), last_prob_(0.0f) {}
// Generates a single example, given a prompt and a callback to stream the
// generated tokens.
@ -168,7 +169,8 @@ class GemmaModel {
// Generate* will use this image. Throws an error for other models.
void SetImage(const py::array_t<float, py::array::c_style |
py::array::forcecast>& image) {
gcpp::Gemma& model = *(gemma_.GetModel());
const gcpp::Allocator2& allocator = gemma_.Env().ctx.allocator;
gcpp::Gemma& model = *(gemma_.GetGemma());
if (model.Info().wrapping != gcpp::PromptWrapping::PALIGEMMA) {
throw std::invalid_argument("Not a PaliGemma model.");
}
@ -183,9 +185,9 @@ class GemmaModel {
c_image.Set(height, width, ptr);
const size_t image_size = model.GetModelConfig().vit_config.image_size;
c_image.Resize(image_size, image_size);
image_tokens_ = gcpp::ImageTokens(gcpp::Extents2D(
model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
image_tokens_ = gcpp::ImageTokens(
allocator, gcpp::Extents2D(model.GetModelConfig().vit_config.seq_len,
model.GetModelConfig().model_dim));
gcpp::RuntimeConfig runtime_config = {.gen = &gemma_.MutableGen(),
.verbosity = 0};
model.GenerateImageTokens(runtime_config, c_image, image_tokens_);
@ -199,7 +201,7 @@ class GemmaModel {
if (image_tokens_.Cols() == 0) {
throw std::invalid_argument("No image set.");
}
gcpp::Gemma& model = *(gemma_.GetModel());
gcpp::Gemma& model = *(gemma_.GetGemma());
gemma_.MutableGen().seed(seed);
gcpp::RuntimeConfig& config = gemma_.MutableConfig();
config.max_generated_tokens = max_generated_tokens;
@ -247,7 +249,7 @@ class GemmaModel {
return gemma_.StringFromTokens(token_ids);
}
bool ModelIsLoaded() const { return gemma_.GetModel() != nullptr; }
bool ModelIsLoaded() const { return gemma_.GetGemma() != nullptr; }
private:
gcpp::GemmaEnv gemma_;
@ -267,7 +269,7 @@ PYBIND11_MODULE(gemma, mod) {
loader.weight_type_str = weight_type;
gcpp::InferenceArgs inference;
inference.max_generated_tokens = 512;
gcpp::AppArgs app;
gcpp::ThreadingArgs app;
app.max_threads = max_threads;
auto gemma =
std::make_unique<GemmaModel>(loader, inference, app);

View File

@ -130,233 +130,6 @@ size_t DetectTotalMiB(size_t page_bytes) {
} // namespace
static size_t line_bytes_;
static size_t vector_bytes_;
static size_t step_bytes_;
static size_t quantum_bytes_;
static size_t quantum_steps_;
static size_t l1_bytes_;
static size_t l2_bytes_;
static size_t l3_bytes_;
static bool should_bind_ = false;
size_t Allocator::LineBytes() { return line_bytes_; }
size_t Allocator::VectorBytes() { return vector_bytes_; }
size_t Allocator::StepBytes() { return step_bytes_; }
size_t Allocator::QuantumBytes() { return quantum_bytes_; }
size_t Allocator::QuantumSteps() { return quantum_steps_; }
size_t Allocator::L1Bytes() { return l1_bytes_; }
size_t Allocator::L2Bytes() { return l2_bytes_; }
size_t Allocator::L3Bytes() { return l3_bytes_; }
bool Allocator::ShouldBind() { return should_bind_; }
void Allocator::Init(const BoundedTopology& topology, bool enable_bind) {
line_bytes_ = DetectLineBytes();
vector_bytes_ = hwy::VectorBytes();
step_bytes_ = HWY_MAX(line_bytes_, vector_bytes_);
quantum_bytes_ = step_bytes_; // may overwrite below
const BoundedTopology::Cluster& cluster = topology.GetCluster(0, 0);
if (const hwy::Cache* caches = hwy::DataCaches()) {
l1_bytes_ = caches[1].size_kib << 10;
l2_bytes_ = caches[2].size_kib << 10;
l3_bytes_ = (caches[3].size_kib << 10) * caches[3].cores_sharing;
} else { // Unknown, make reasonable assumptions.
l1_bytes_ = 32 << 10;
l2_bytes_ = (cluster.PrivateKiB() ? cluster.PrivateKiB() : 256) << 10;
}
if (l3_bytes_ == 0) {
l3_bytes_ = (cluster.SharedKiB() ? cluster.SharedKiB() : 1024) << 10;
}
// Prerequisites for binding:
// - supported by the OS (currently Linux only),
// - the page size is known and 'reasonably small', preferably less than
// a fraction of MatMul row/col sizes, which for 27B are up to 144 KiB.
// - we successfully detected topology and there are multiple nodes;
// - there are multiple packages, because we shard by package_idx.
if constexpr (GEMMA_BIND) {
const size_t page_bytes = DetectPageSize();
if ((page_bytes != 0 && page_bytes <= 16 * 1024) &&
topology.NumNodes() > 1 && topology.NumPackages() > 1) {
if (enable_bind) {
// Ensure pages meet the alignment requirements of `AllocBytes`.
HWY_ASSERT(page_bytes >= quantum_bytes_);
quantum_bytes_ = page_bytes;
// Ensure MaxQuantumBytes() is an upper bound.
HWY_ASSERT(MaxQuantumBytes() >= quantum_bytes_);
quantum_bytes_ = HWY_MIN(quantum_bytes_, MaxQuantumBytes());
should_bind_ = true;
} else {
HWY_WARN(
"Multiple sockets but binding disabled. This reduces speed; "
"set or remove enable_bind to avoid this warning.");
}
}
}
HWY_DASSERT(quantum_bytes_ % step_bytes_ == 0);
quantum_steps_ = quantum_bytes_ / step_bytes_;
}
Allocator::PtrAndDeleter Allocator::AllocBytes(size_t bytes) {
// If we are not binding, the Highway allocator is cheaper than `mmap`, and
// defends against 2K aliasing.
if (!should_bind_) {
// Perf warning if Highway's alignment is less than we want.
if (HWY_ALIGNMENT < QuantumBytes()) {
HWY_WARN(
"HWY_ALIGNMENT %d < QuantumBytes %zu: either vector or cache lines "
"are huge, enable GEMMA_BIND to avoid this warning.",
HWY_ALIGNMENT, QuantumBytes());
}
auto p = hwy::AllocateAligned<uint8_t>(bytes);
// The `hwy::AlignedFreeUniquePtr` deleter is unfortunately specific to the
// alignment scheme in aligned_allocator.cc and does not work for
// already-aligned pointers as returned by `mmap`, hence we wrap the Highway
// pointer in our own deleter.
auto call_free = [](void* ptr, size_t /*bytes*/) {
hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
};
return PtrAndDeleter{p.release(), DeleterFree(call_free, bytes)};
}
// Binding, or large vector/cache line size: use platform-specific allocator.
#if HWY_OS_LINUX && !defined(__ANDROID_API__)
// `move_pages` is documented to require an anonymous/private mapping or
// `MAP_SHARED`. A normal allocation might not suffice, so we use `mmap`.
// `Init` verified that the page size is a multiple of `QuantumBytes()`.
const int prot = PROT_READ | PROT_WRITE;
const int flags = MAP_ANONYMOUS | MAP_PRIVATE;
const int fd = -1;
// Encourage transparent hugepages by rounding up to a multiple of 2 MiB.
bytes = hwy::RoundUpTo(bytes, 2ull << 20);
void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
if (p == MAP_FAILED) p = nullptr;
const auto call_munmap = [](void* ptr, size_t bytes) {
const int ret = munmap(ptr, bytes);
HWY_ASSERT(ret == 0);
};
return PtrAndDeleter{p, DeleterFree(call_munmap, bytes)};
#elif HWY_OS_WIN
const auto call_free = [](void* ptr, size_t) { _aligned_free(ptr); };
const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
return PtrAndDeleter{_aligned_malloc(bytes, alignment),
DeleterFree(call_free, bytes)};
#else
return PtrAndDeleter{nullptr, DeleterFree(nullptr, 0)};
#endif
}
#if GEMMA_BIND && HWY_OS_LINUX
using Ret = long; // NOLINT(runtime/int)
using UL = unsigned long; // NOLINT(runtime/int)
static constexpr size_t ULBits = sizeof(UL) * 8;
// Calling via syscall avoids a dependency on libnuma.
struct SyscallWrappers {
static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes,
unsigned flags) {
MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL));
    return syscall(__NR_mbind, ptr, bytes, mode, nodes, max_nodes, flags);
};
static Ret move_pages(int pid, UL count, void** pages, const int* nodes,
int* status, int flags) {
MaybeCheckInitialized(pages, count * sizeof(void*));
MaybeCheckInitialized(nodes, count * sizeof(int));
MaybeCheckInitialized(status, count * sizeof(int));
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
}
static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr,
unsigned flags) {
return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags);
}
};
// Returns the number of pages that are currently busy (hence not yet moved),
// and warns if there are any other reasons for not moving a page. Note that
// `move_pages` can return 0 regardless of whether all pages were moved.
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
const int* status) {
size_t num_busy = 0;
for (size_t i = 0; i < num_pages; ++i) {
if (status[i] == -EBUSY) {
++num_busy;
} else if (status[i] != static_cast<int>(node)) {
static std::atomic_flag first = ATOMIC_FLAG_INIT;
if (!first.test_and_set()) {
HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).",
status[i], i, pages[i], node, errno);
}
}
}
return num_busy;
}
bool Allocator::BindMemory(void* ptr, size_t bytes, size_t node) {
HWY_DASSERT(should_bind_);
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
if constexpr (HWY_IS_DEBUG_BUILD) {
// Ensure the requested `node` is allowed.
UL nodes[kMaxNodes / 64] = {0};
const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED
HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes,
nullptr, flags) == 0);
HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64)));
}
// Avoid mbind because it does not report why it failed, which is most likely
// because pages are busy, in which case we want to know which.
// `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set.
const unsigned flags = 2; // MPOL_MF_MOVE
HWY_ASSERT(bytes % quantum_bytes_ == 0);
const size_t num_pages = bytes / quantum_bytes_;
std::vector<void*> pages;
pages.reserve(num_pages);
for (size_t i = 0; i < num_pages; ++i) {
pages.push_back(static_cast<uint8_t*>(ptr) + i * quantum_bytes_);
// Ensure the page is faulted in to prevent `move_pages` from failing,
// because freshly allocated pages may be mapped to a shared 'zero page'.
hwy::ZeroBytes(pages.back(), 8);
}
std::vector<int> nodes(num_pages, node);
std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));
Ret ret = SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
if (ret < 0) {
HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr,
bytes, node, errno, status[0]);
return false;
}
const size_t num_busy =
CountBusyPages(num_pages, node, pages.data(), status.data());
if (HWY_UNLIKELY(num_busy != 0)) {
// Trying again is usually enough to succeed.
hwy::NanoSleep(5000);
(void)SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
const size_t still_busy =
CountBusyPages(num_pages, node, pages.data(), status.data());
if (HWY_UNLIKELY(still_busy != 0)) {
HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.",
still_busy, num_busy);
}
}
return true;
}
#else
bool Allocator::BindMemory(void*, size_t, size_t) { return false; }
#endif // GEMMA_BIND && HWY_OS_LINUX
Allocator2::Allocator2(const BoundedTopology& topology, bool enable_bind) {
line_bytes_ = DetectLineBytes();
vector_bytes_ = hwy::VectorBytes();
@ -428,7 +201,7 @@ size_t Allocator2::FreeMiB() const {
#endif
}
Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
AlignedPtr2<uint8_t[]> Allocator2::AllocBytes(size_t bytes) const {
// If we are not binding, the Highway allocator is cheaper than `mmap`, and
// defends against 2K aliasing.
if (!should_bind_) {
@ -444,9 +217,10 @@ Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
// alignment scheme in aligned_allocator.cc and does not work for
// already-aligned pointers as returned by `mmap`, hence we wrap the Highway
// pointer in our own deleter.
return PtrAndDeleter{p.release(), DeleterFunc2([](void* ptr) {
hwy::FreeAlignedBytes(ptr, nullptr, nullptr);
})};
return AlignedPtr2<uint8_t[]>(p.release(), DeleterFunc2([](void* ptr) {
hwy::FreeAlignedBytes(ptr, nullptr,
nullptr);
}));
}
// Binding, or large vector/cache line size: use platform-specific allocator.
@ -460,20 +234,126 @@ Allocator2::PtrAndDeleter Allocator2::AllocBytes(size_t bytes) const {
const int fd = -1;
void* p = mmap(0, bytes, prot, flags, fd, off_t{0});
if (p == MAP_FAILED) p = nullptr;
return PtrAndDeleter{p, DeleterFunc2([bytes](void* ptr) {
HWY_ASSERT(munmap(ptr, bytes) == 0);
})};
return AlignedPtr2<uint8_t[]>(static_cast<uint8_t*>(p),
DeleterFunc2([bytes](void* ptr) {
HWY_ASSERT(munmap(ptr, bytes) == 0);
}));
#elif HWY_OS_WIN
const size_t alignment = HWY_MAX(vector_bytes_, line_bytes_);
return PtrAndDeleter{_aligned_malloc(bytes, alignment),
DeleterFunc2([](void* ptr) { _aligned_free(ptr); })};
return AlignedPtr2<uint8_t[]>(
static_cast<uint8_t*>(_aligned_malloc(bytes, alignment)),
DeleterFunc2([](void* ptr) { _aligned_free(ptr); }));
#else
return PtrAndDeleter{nullptr, DeleterFunc2()};
return AlignedPtr2<uint8_t[]>(nullptr, DeleterFunc2());
#endif
}
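// Callers can now hold the returned unique_ptr directly (illustrative):
//   AlignedPtr2<uint8_t[]> raw = allocator.AllocBytes(1 << 20);
//   AlignedPtr2<float[]> floats = allocator.Alloc<float>(4096);
// `MatOwner::AllocateFor` in util/mat.cc relies on the raw-byte form.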
bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
return Allocator::BindMemory(ptr, bytes, node);
#if GEMMA_BIND && HWY_OS_LINUX
using Ret = long; // NOLINT(runtime/int)
using UL = unsigned long; // NOLINT(runtime/int)
static constexpr size_t ULBits = sizeof(UL) * 8;
// Calling via syscall avoids a dependency on libnuma.
struct SyscallWrappers {
static Ret mbind(void* ptr, UL bytes, int mode, const UL* nodes, UL max_nodes,
unsigned flags) {
MaybeCheckInitialized(nodes, hwy::DivCeil(max_nodes, ULBits) * sizeof(UL));
    return syscall(__NR_mbind, ptr, bytes, mode, nodes, max_nodes, flags);
};
static Ret move_pages(int pid, UL count, void** pages, const int* nodes,
int* status, int flags) {
MaybeCheckInitialized(pages, count * sizeof(void*));
MaybeCheckInitialized(nodes, count * sizeof(int));
MaybeCheckInitialized(status, count * sizeof(int));
return syscall(__NR_move_pages, pid, count, pages, nodes, status, flags);
}
static Ret get_mempolicy(int* mode, UL* nodes, UL max_node, void* addr,
unsigned flags) {
return syscall(__NR_get_mempolicy, mode, nodes, max_node, addr, flags);
}
};
// Returns the number of pages that are currently busy (hence not yet moved),
// and warns if there are any other reasons for not moving a page. Note that
// `move_pages` can return 0 regardless of whether all pages were moved.
size_t CountBusyPages(size_t num_pages, size_t node, void** pages,
const int* status) {
size_t num_busy = 0;
for (size_t i = 0; i < num_pages; ++i) {
if (status[i] == -EBUSY) {
++num_busy;
} else if (status[i] != static_cast<int>(node)) {
static std::atomic_flag first = ATOMIC_FLAG_INIT;
if (!first.test_and_set()) {
HWY_WARN("Error %d moving pages[%zu]=%p to node %zu (errno %d).",
status[i], i, pages[i], node, errno);
}
}
}
return num_busy;
}
bool Allocator2::BindMemory(void* ptr, size_t bytes, size_t node) const {
HWY_DASSERT(should_bind_);
constexpr size_t kMaxNodes = 1024; // valid for x86/x64, and "enough"
if constexpr (HWY_IS_DEBUG_BUILD) {
// Ensure the requested `node` is allowed.
UL nodes[kMaxNodes / 64] = {0};
const unsigned flags = 4; // MPOL_F_MEMS_ALLOWED
HWY_ASSERT(SyscallWrappers::get_mempolicy(nullptr, nodes, kMaxNodes,
nullptr, flags) == 0);
HWY_ASSERT(nodes[node / 64] & (1ull << (node % 64)));
}
// Avoid mbind because it does not report why it failed, which is most likely
// because pages are busy, in which case we want to know which.
// `MPOL_MF_MOVE_ALL` requires cap sys_nice, which is not easy to set.
const unsigned flags = 2; // MPOL_MF_MOVE
HWY_ASSERT(bytes % quantum_bytes_ == 0);
const size_t num_pages = bytes / quantum_bytes_;
std::vector<void*> pages;
pages.reserve(num_pages);
for (size_t i = 0; i < num_pages; ++i) {
pages.push_back(static_cast<uint8_t*>(ptr) + i * quantum_bytes_);
// Ensure the page is faulted in to prevent `move_pages` from failing,
// because freshly allocated pages may be mapped to a shared 'zero page'.
hwy::ZeroBytes(pages.back(), 8);
}
std::vector<int> nodes(num_pages, node);
std::vector<int> status(num_pages, static_cast<int>(kMaxNodes));
Ret ret = SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
if (ret < 0) {
HWY_WARN("Failed to bind %p %zu to node %zu (errno %d) status %d.", ptr,
bytes, node, errno, status[0]);
return false;
}
const size_t num_busy =
CountBusyPages(num_pages, node, pages.data(), status.data());
if (HWY_UNLIKELY(num_busy != 0)) {
// Trying again is usually enough to succeed.
hwy::NanoSleep(5000);
(void)SyscallWrappers::move_pages(
/*pid=*/0, num_pages, pages.data(), nodes.data(), status.data(), flags);
const size_t still_busy =
CountBusyPages(num_pages, node, pages.data(), status.data());
if (HWY_UNLIKELY(still_busy != 0)) {
HWY_WARN("BindMemory: %zu pages still busy after retrying %zu.",
still_busy, num_busy);
}
}
return true;
}
#else
bool Allocator2::BindMemory(void*, size_t, size_t) const { return false; }
#endif // GEMMA_BIND && HWY_OS_LINUX
} // namespace gcpp

View File

@ -30,307 +30,8 @@
#include "hwy/base.h"
// IWYU pragma: end_exports
#include "hwy/aligned_allocator.h"
namespace gcpp {
// Points to an adapter lambda that calls `FreeAlignedBytes` or `munmap`. The
// `bytes` argument is required for the latter.
using FreeFunc = void (*)(void* mem, size_t bytes);
// Custom deleter for std::unique_ptr that calls `FreeFunc`. T is POD.
class DeleterFree {
public:
// `MatStorageT` requires this to be default-constructible.
DeleterFree() : free_func_(nullptr), bytes_(0) {}
DeleterFree(FreeFunc free_func, size_t bytes)
: free_func_(free_func), bytes_(bytes) {}
template <typename T>
void operator()(T* p) const {
free_func_(p, bytes_);
}
private:
FreeFunc free_func_;
size_t bytes_;
};
// Wrapper that also calls the destructor for non-POD T.
class DeleterDtor {
public:
DeleterDtor() {}
DeleterDtor(size_t num, DeleterFree free) : num_(num), free_(free) {}
template <typename T>
void operator()(T* p) const {
for (size_t i = 0; i < num_; ++i) {
p[i].~T();
}
free_(p);
}
private:
size_t num_; // not the same as free_.bytes_ / sizeof(T)!
DeleterFree free_;
};
// Unique (move-only) pointer to an aligned array of POD T.
template <typename T>
using AlignedPtr = std::unique_ptr<T[], DeleterFree>;
// Unique (move-only) pointer to an aligned array of non-POD T.
template <typename T>
using AlignedClassPtr = std::unique_ptr<T[], DeleterDtor>;
// Allocation, binding, and row accessors all depend on the sizes of memory
// pages and cache lines. To avoid having to pass `Allocator&` everywhere, we
// use `Monostate` (static members).
class Allocator {
public:
// Must be called at least once before any other function. Not thread-safe,
// hence only call this from the main thread.
// TODO: remove enable_bind once Gemma tensors support binding.
static void Init(const BoundedTopology& topology, bool enable_bind = false);
// Bytes per cache line, or a reasonable guess if unknown. Used to choose
// ranges such that there will be no false sharing.
static size_t LineBytes();
// Bytes per full vector. Used to compute loop steps.
static size_t VectorBytes();
// Work granularity that avoids false sharing and partial vectors.
static size_t StepBytes(); // = HWY_MAX(LineBytes(), VectorBytes())
// Granularity like `StepBytes()`, but when NUMA may be involved.
static size_t QuantumBytes();
// Upper bound on `QuantumBytes()`, for stack allocations.
static constexpr size_t MaxQuantumBytes() { return 4096; }
static size_t QuantumSteps(); // = QuantumBytes() / StepBytes()
// L1 and L2 are typically per core.
static size_t L1Bytes();
static size_t L2Bytes();
// Clusters often share an L3. We return the total size per package.
static size_t L3Bytes();
// Returns pointer aligned to `QuantumBytes()`.
template <typename T>
static AlignedPtr<T> Alloc(size_t num) {
constexpr size_t kSize = sizeof(T);
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
// Fail if the `bytes = num * kSize` computation overflowed.
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
if (check != num) return AlignedPtr<T>();
PtrAndDeleter pd = AllocBytes(bytes);
return AlignedPtr<T>(static_cast<T*>(pd.p), pd.deleter);
}
// Same as Alloc, but calls constructor(s) with `args`.
template <typename T, class... Args>
static AlignedClassPtr<T> AllocClasses(size_t num, Args&&... args) {
constexpr size_t kSize = sizeof(T);
constexpr bool kIsPow2 = (kSize & (kSize - 1)) == 0;
constexpr size_t kBits = hwy::detail::ShiftCount(kSize);
static_assert(!kIsPow2 || (1ull << kBits) == kSize, "ShiftCount has a bug");
const size_t bytes = kIsPow2 ? num << kBits : num * kSize;
// Fail if the `bytes = num * kSize` computation overflowed.
const size_t check = kIsPow2 ? bytes >> kBits : bytes / kSize;
if (check != num) return AlignedClassPtr<T>();
PtrAndDeleter pd = AllocBytes(bytes);
T* p = static_cast<T*>(pd.p);
for (size_t i = 0; i < num; ++i) {
new (p + i) T(std::forward<Args>(args)...);
}
return AlignedClassPtr<T>(p, DeleterDtor(num, pd.deleter));
}
// Returns whether `BindMemory` can/should be called, i.e. we have page-level
// control over memory placement and multiple packages and NUMA nodes.
static bool ShouldBind();
// Attempts to move(!) `[p, p + bytes)` to the given NUMA node, which is
// typically `BoundedTopology::GetCluster(package_idx, cluster_idx).node`.
// Writes zeros to SOME of the memory. Only call if `ShouldBind()`.
// `p` and `bytes` must be multiples of `QuantumBytes()`.
static bool BindMemory(void* p, size_t bytes, size_t node);
private:
// Type-erased so this can be implemented in allocator.cc.
struct PtrAndDeleter {
void* p;
DeleterFree deleter;
};
static PtrAndDeleter AllocBytes(size_t bytes);
};
// Value of `stride` to pass to `RowVectorBatch` to enable the "cyclic offsets"
// optimization. If `Allocator::ShouldBind()`, `Allocator::QuantumBytes()` is
// typically 4KiB. To avoid remote accesses, we would thus pad each row to that,
// which results in 4K aliasing and/or cache conflict misses. `RowPtr` is able
// to prevent that by pulling rows forward by a cyclic offset, which is still a
// multiple of the cache line size. This requires an additional
// `Allocator::QuantumBytes()` of padding after also rounding up to that.
template <typename T>
constexpr size_t StrideForCyclicOffsets(size_t cols) {
const size_t quantum = Allocator::MaxQuantumBytes() / sizeof(T);
return hwy::RoundUpTo(cols, quantum) + quantum;
}
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() = default;
// Main ctor, called from Activations::Allocate. If `stride` = 0, the default,
// we default to tightly packed rows (`stride = cols`).
// WARNING: not all call sites support `stride` != cols.
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
RowVectorBatch(Extents2D extents, size_t stride = 0) : extents_(extents) {
if (stride == 0) {
stride_ = extents_.cols;
} else {
HWY_ASSERT(stride >= extents_.cols);
stride_ = stride;
}
// Allow binding the entire matrix.
const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
Allocator::QuantumBytes() / sizeof(T));
mem_ = Allocator::Alloc<T>(padded);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return extents_.rows; }
size_t Cols() const { return extents_.cols; }
size_t Stride() const { return stride_; }
Extents2D Extents() const { return extents_; }
// Returns the given row vector of length `Cols()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
// For MatMul or other operations that process the entire batch at once.
// TODO: remove once we only use Mat.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
private:
AlignedPtr<T> mem_;
Extents2D extents_;
size_t stride_;
};
// Returns `num` rounded up to an odd number of cache lines. This is used to
// compute strides. An odd number of cache lines prevents 2K aliasing and is
// coprime with the cache associativity, which reduces conflict misses.
template <typename T>
static HWY_INLINE size_t RoundUpToOddLines(size_t num, size_t line_bytes) {
HWY_DASSERT(line_bytes >= 32);
HWY_DASSERT(line_bytes % sizeof(T) == 0);
const size_t lines = hwy::DivCeil(num * sizeof(T), line_bytes);
const size_t padded_num = (lines | 1) * line_bytes / sizeof(T);
HWY_DASSERT(padded_num >= num);
return padded_num;
}
template <typename T>
RowVectorBatch<T> AllocateAlignedRows(Extents2D extents) {
return RowVectorBatch<T>(extents, StrideForCyclicOffsets<T>(extents.cols));
}
// Lightweight version of `MatPtr` used for the C argument of `MatMul`, because
// it is always float and does not support compressed T, but does support an
// arbitrary stride >= cols.
#pragma pack(push, 1) // power of two size
template <typename T>
class RowPtr {
public:
RowPtr() = default; // for `MMPtrs`.
RowPtr(T* HWY_RESTRICT row0, size_t cols, size_t stride)
: row0_(row0),
stride_(stride),
step_(static_cast<uint32_t>(Allocator::StepBytes())),
cols_(static_cast<uint32_t>(cols)),
row_mask_(Allocator::QuantumSteps() - 1) {
HWY_DASSERT(stride >= cols);
HWY_DASSERT(row_mask_ != ~size_t{0});
if constexpr (HWY_IS_DEBUG_BUILD) {
if (stride < StrideForCyclicOffsets<T>(cols)) {
static bool once;
if (!once) {
once = true;
HWY_WARN(
"Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
"T=%zu; this forces us to disable cyclic offsets.",
stride, cols, sizeof(T));
}
row_mask_ = 0;
}
}
}
RowPtr(T* HWY_RESTRICT row0, size_t cols) : RowPtr(row0, cols, cols) {}
T* HWY_RESTRICT Row(size_t r) const {
// How much of the previous row's padding to consume.
const size_t pad_bytes = (r & row_mask_) * step_;
HWY_DASSERT(pad_bytes < Allocator::QuantumBytes());
return row0_ + stride_ * r - pad_bytes;
}
size_t Cols() const { return cols_; }
size_t Stride() const { return stride_; }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
// The caller might not have padded enough, so disable the padding in Row().
// Rows will now be exactly `stride` elements apart. This is used when
// writing to the KV cache via MatMul.
row_mask_ = 0;
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < cols_);
HWY_DASSERT(cols <= cols_ - c);
return RowPtr<T>(Row(r) + c, cols, stride_);
}
private:
T* HWY_RESTRICT row0_;
size_t stride_;
  uint32_t step_;  // Copy of Allocator::StepBytes() to improve locality.
uint32_t cols_;
size_t row_mask_;
};
#pragma pack(pop)
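// Worked example of the cyclic offsets above (assumes StepBytes() = 64 and
// QuantumBytes() = 4096, hence row_mask_ = 63): rows are pulled forward by
// 0, 64, 128, ..., 4032 bytes and then wrap, so page-padded strides no longer
// map every row onto the same cache sets.
static_assert((1 & 63) * 64 == 64, "row 1 pulled forward by one step");
static_assert((64 & 63) * 64 == 0, "offsets wrap every 64 rows");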
using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;
// For C argument to MatMul.
template <typename T>
RowPtr<T> RowPtrFromBatch(RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(row_vectors.All(), row_vectors.Cols(), row_vectors.Stride());
}
// Custom deleter for types without a dtor, but where the deallocation requires
// state, e.g. a lambda with *by-value* capture.
class DeleterFunc2 {
@ -420,15 +121,22 @@ class Allocator2 {
size_t TotalMiB() const { return total_mib_; }
size_t FreeMiB() const;
// Returns pointer aligned to `QuantumBytes()`.
// Returns byte pointer aligned to `QuantumBytes()`, without calling
// constructors nor destructors on deletion. Type-erased so this can be
// implemented in `allocator.cc` and called by `MatOwner`.
AlignedPtr2<uint8_t[]> AllocBytes(size_t bytes) const;
// Returns pointer aligned to `QuantumBytes()`, without calling constructors
// nor destructors on deletion.
template <typename T>
AlignedPtr2<T[]> Alloc(size_t num) const {
const size_t bytes = num * sizeof(T);
// Fail if the `bytes = num * sizeof(T)` computation overflowed.
HWY_ASSERT(bytes / sizeof(T) == num);
PtrAndDeleter pd = AllocBytes(bytes);
return AlignedPtr2<T[]>(static_cast<T*>(pd.p), pd.deleter);
AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
return AlignedPtr2<T[]>(HWY_RCAST_ALIGNED(T*, p8.release()),
p8.get_deleter());
}
// Same as Alloc, but calls constructor(s) with `args` and the deleter will
@ -439,12 +147,12 @@ class Allocator2 {
// Fail if the `bytes = num * sizeof(T)` computation overflowed.
HWY_ASSERT(bytes / sizeof(T) == num);
PtrAndDeleter pd = AllocBytes(bytes);
T* p = static_cast<T*>(pd.p);
AlignedPtr2<uint8_t[]> p8 = AllocBytes(bytes);
T* p = HWY_RCAST_ALIGNED(T*, p8.release());
for (size_t i = 0; i < num; ++i) {
new (p + i) T(std::forward<Args>(args)...);
}
return AlignedClassPtr2<T>(p, DeleterDtor2(num, pd.deleter));
return AlignedClassPtr2<T>(p, DeleterDtor2(num, p8.get_deleter()));
}
// Returns whether `BindMemory` can/should be called, i.e. we have page-level
@ -458,13 +166,6 @@ class Allocator2 {
bool BindMemory(void* p, size_t bytes, size_t node) const;
private:
// Type-erased so this can be implemented in allocator.cc.
struct PtrAndDeleter {
void* p;
DeleterFunc2 deleter;
};
PtrAndDeleter AllocBytes(size_t bytes) const;
size_t line_bytes_;
size_t vector_bytes_;
size_t step_bytes_;

View File

@ -23,7 +23,7 @@
#include <algorithm> // std::transform
#include <string>
#include "compression/io.h"
#include "compression/io.h" // Path
#include "util/basics.h" // Tristate
#include "hwy/base.h" // HWY_ABORT

100
util/mat.cc Normal file
View File

@ -0,0 +1,100 @@
// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util/mat.h"
#include <stddef.h>
#include <stdint.h>
#include "util/threading_context.h"
#include "hwy/base.h"
#include "hwy/per_target.h" // VectorBytes
#include "hwy/profiler.h"
namespace gcpp {
void CopyMat(const MatPtr& from, MatPtr& to) {
PROFILER_FUNC;
HWY_ASSERT(to.Rows() == from.Rows() && to.Cols() == from.Cols());
HWY_ASSERT(to.GetType() == from.GetType());
if (to.IsPacked() && from.IsPacked()) {
HWY_ASSERT(to.PackedBytes() == from.PackedBytes());
hwy::CopyBytes(from.Packed(), to.Packed(), to.PackedBytes());
return;
}
const size_t row_bytes = to.Cols() * to.ElementBytes();
for (size_t r = 0; r < to.Rows(); ++r) {
const uint8_t* from_row = from.RowT<uint8_t>(r);
uint8_t* to_row = to.RowT<uint8_t>(r);
hwy::CopyBytes(from_row, to_row, row_bytes);
}
}
void ZeroInit(MatPtr& mat) {
PROFILER_FUNC;
HWY_ASSERT_M(mat.HasPtr(), mat.Name());
if (mat.IsPacked()) {
hwy::ZeroBytes(mat.Packed(), mat.PackedBytes());
return;
}
const size_t row_bytes = mat.Cols() * mat.ElementBytes();
for (size_t r = 0; r < mat.Rows(); ++r) {
hwy::ZeroBytes(mat.RowT<uint8_t>(r), row_bytes);
}
}
// Returns `num` rounded up to an odd number of cache lines. An odd line count
// also avoids 4K aliasing and is coprime with the (power of two) cache
// associativity, which may reduce conflict misses, but we usually rely on
// `StrideForCyclicOffsets` for that instead.
static size_t RoundUpToOddLines(size_t num, size_t line_bytes,
size_t element_bytes) {
HWY_DASSERT(line_bytes >= 32);
HWY_DASSERT(line_bytes % element_bytes == 0);
const size_t lines = hwy::DivCeil(num * element_bytes, line_bytes);
const size_t padded_num = (lines | 1) * line_bytes / element_bytes;
HWY_DASSERT(padded_num >= num);
return padded_num;
}
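// Worked example (editor's illustration, not part of the original file):
// num = 2048 floats, line_bytes = 64, element_bytes = 4:
//   lines      = DivCeil(2048 * 4, 64) = 128 (even)
//   lines | 1  = 129 (odd)
//   padded_num = 129 * 64 / 4 = 2064, i.e. 16 extra floats of row padding.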
static size_t Stride(const Allocator2& allocator, const MatPtr& mat,
MatPadding padding) {
switch (padding) {
case MatPadding::kPacked:
default:
return mat.Cols();
case MatPadding::kOdd:
return RoundUpToOddLines(mat.Cols(), allocator.LineBytes(),
mat.ElementBytes());
case MatPadding::kCyclic:
return StrideForCyclicOffsets(
mat.Cols(), allocator.QuantumBytes() / mat.ElementBytes());
}
}
void MatOwner::AllocateFor(MatPtr& mat, MatPadding padding) {
const Allocator2& allocator = ThreadingContext2::Get().allocator;
const size_t stride = Stride(allocator, mat, padding);
const size_t num = mat.Rows() * stride;
// `compress-inl` requires up to 2 BF16 vectors of padding. `MatPadding`
// might not be enough, hence add extra. `MatT` is at least one byte, which
// is half of BF16, hence adding `VectorBytes` *elements* is enough.
const size_t bytes = (num + hwy::VectorBytes()) * mat.ElementBytes();
// Allow binding the entire matrix.
const size_t padded_bytes = hwy::RoundUpTo(bytes, allocator.QuantumBytes());
storage_ = allocator.AllocBytes(padded_bytes);
mat.SetPtr(storage_.get(), stride);
}
} // namespace gcpp
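
A short usage sketch (editor's example, not part of this commit) of the helpers above; `MatStorageT` and `MatPadding` are declared in the new util/mat.h shown below.

#include "util/mat.h"

void ExampleCopy() {
  gcpp::MatStorageT<float> src("src", gcpp::Extents2D(2, 128),
                               gcpp::MatPadding::kPacked);
  gcpp::MatStorageT<float> dst("dst", gcpp::Extents2D(2, 128),
                               gcpp::MatPadding::kOdd);
  gcpp::ZeroInit(src);
  gcpp::CopyMat(src, dst);  // row-by-row copy because dst has a larger stride
}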

util/mat.h (new file, 532 lines)

@ -0,0 +1,532 @@
// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Tensor metadata and in-memory representation.
#ifndef THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
#define THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_
#include <stddef.h>
#include <stdint.h>
#include <random>
#include <string>
#include <utility>  // std::forward
// IWYU pragma: begin_exports
#include "compression/fields.h"
#include "compression/shared.h" // Type
#include "gemma/tensor_index.h"
#include "util/allocator.h"
#include "util/basics.h" // Extents2D
// IWYU pragma: end_exports
#include "hwy/base.h"
namespace gcpp {
// Type-erased, non-owning pointer and metadata for rank-1 or 2 tensors (vector
// or matrix). Base class of the non-type-erased `MatPtrT`. Use this class
// to store heterogeneous tensor references in a vector.
//
// Copyable, (de)serializable via `fields.h` for `model_store.h`.
class MatPtr : public IFields {
public:
MatPtr() = default;
// `name`: see `SetName`. Note that `stride` is initially `cols` and only
// differs after deserializing, or calling `SetPtr`.
MatPtr(const char* name, Type type, Extents2D extents)
: rows_(static_cast<uint32_t>(extents.rows)),
cols_(static_cast<uint32_t>(extents.cols)) {
SetName(name);
SetType(type);
SetPtr(nullptr, cols_);
}
// Copying allowed because the metadata is small.
MatPtr(const MatPtr& other) = default;
MatPtr& operator=(const MatPtr& other) = default;
virtual ~MatPtr() = default;
// Only for use by ctor, `AllocateFor` and 'loading' memory-mapped tensors.
void SetPtr(void* ptr, size_t stride) {
HWY_ASSERT(stride >= Cols());
ptr_ = ptr;
stride_ = static_cast<uint32_t>(stride);
// NUQ streams must not be padded because that would change the position of
// the group tables.
if (type_ == Type::kNUQ) HWY_ASSERT(IsPacked());
}
bool HasPtr() const { return ptr_ != nullptr; }
bool IsPacked() const { return stride_ == cols_; }
const void* Packed() const {
HWY_DASSERT_M(IsPacked(), name_.c_str());
return ptr_;
}
void* Packed() {
HWY_DASSERT_M(IsPacked(), name_.c_str());
return ptr_;
}
// Returns size in bytes for purposes of copying/initializing or I/O. Must
// only be called if `IsPacked`.
size_t PackedBytes() const {
HWY_DASSERT_M(IsPacked(), name_.c_str());
// num_elements_ already includes the NUQ tables.
return num_elements_ * element_bytes_;
}
// Works for any kind of padding.
template <typename T>
T* MutableRowT(size_t row) const {
HWY_DASSERT(row < rows_);
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
}
template <typename T>
T* RowT(size_t row) {
HWY_DASSERT(row < rows_);
return HWY_RCAST_ALIGNED(T*, ptr_) + row * stride_;
}
template <typename T>
const T* RowT(size_t row) const {
HWY_DASSERT(row < rows_);
return HWY_RCAST_ALIGNED(const T*, ptr_) + row * stride_;
}
Type GetType() const { return type_; }
void SetType(Type type) {
type_ = type;
element_bytes_ = static_cast<uint32_t>(hwy::DivCeil(TypeBits(type), 8));
num_elements_ = static_cast<uint32_t>(ComputeNumElements(type, Extents()));
}
bool IsEmpty() const { return rows_ == 0 || cols_ == 0; }
size_t Rows() const { return rows_; }
size_t Cols() const { return cols_; }
Extents2D Extents() const { return Extents2D(rows_, cols_); }
// Offset by which to advance pointers to the next row.
size_t Stride() const { return stride_; }
// For use by `BlobStore`, `CopyMat` and `ZeroInit`.
size_t ElementBytes() const { return element_bytes_; }
// Decoded elements should be multiplied by this to restore their original
// range. This is required because `SfpStream` can only encode a limited range
// of magnitudes.
float Scale() const { return scale_; }
void SetScale(float scale) { scale_ = scale; }
// Name is a terse identifier. `MakeKey` in `blob_store.cc` requires that it
// be <= 16 bytes including prefixes/suffixes. The initial name set by the
// ctor is for the tensor, but `ForEachTensor` in `weights.h` adds a per-layer
// suffix, and when loading, we call `SetName` with that.
const char* Name() const override { return name_.c_str(); }
void SetName(const char* name) {
name_ = name;
HWY_ASSERT_M(name_.size() <= sizeof(hwy::uint128_t), name);
}
void VisitFields(IFieldsVisitor& visitor) override {
// Order determines the order of serialization and must not change.
visitor(name_);
visitor(type_);
visitor(element_bytes_);
visitor(num_elements_);
visitor(rows_);
visitor(cols_);
visitor(scale_);
visitor(stride_);
}
protected:
// For initializing `num_elements_`: "elements" are how many objects we
// actually store in order to represent rows * cols values. For NUQ, this is
// greater because it includes additional per-group tables. This is the only
// place where we compute this fixup. Note that elements are independent of
// padding, which is anyway not supported for NUQ because `compress-inl.h`
// assumes a contiguous stream for its group indexing.
static size_t ComputeNumElements(Type type, Extents2D extents) {
const size_t num_elements = extents.Area();
if (type == Type::kNUQ) {
// `CompressedArrayElements` is a wrapper function that has the same
// effect, but that requires a template argument, not `type`.
return NuqStream::PackedEnd(num_elements);
}
return num_elements;
}
std::string name_; // See `SetName`.
Type type_;
// Most members are u32 because that is the preferred type of fields.h.
// Bytes per element. This is fully determined by `type_`, but stored here
// for convenience and backward compatibility.
uint32_t element_bytes_ = 0;
// Number of elements to store (including NUQ tables but not padding).
// This is a function of `type_` and `Extents()`, stored for compatibility.
uint32_t num_elements_ = 0;
uint32_t rows_ = 0;
uint32_t cols_ = 0;
float scale_ = 1.0f; // multiplier for each value, for MatMul.
// Non-owning pointer, must not be freed. The underlying memory must outlive
// this object.
void* ptr_ = nullptr; // not serialized
// Offset by which to advance pointers to the next row, >= `cols_`.
uint32_t stride_;
};
// Non-type erased version of `MatPtr`. Use this when operating on the values.
template <typename MatT>
class MatPtrT : public MatPtr {
public:
// Runtime-specified shape.
MatPtrT(const char* name, Extents2D extents)
: MatPtr(name, TypeEnum<MatT>(), extents) {}
// Take shape from `TensorInfo` to avoid duplicating it in the caller.
MatPtrT(const char* name, const TensorInfo* tensor)
: MatPtrT<MatT>(name, ExtentsFromInfo(tensor)) {}
// Find `TensorInfo` by name in `TensorIndex`.
MatPtrT(const char* name, const TensorIndex& tensor_index)
: MatPtrT<MatT>(name, tensor_index.FindName(name)) {}
// Copying allowed because the metadata is small.
MatPtrT(const MatPtr& other) : MatPtr(other) {}
MatPtrT& operator=(const MatPtr& other) {
MatPtr::operator=(other);
return *this;
}
MatPtrT(const MatPtrT& other) = default;
MatPtrT& operator=(const MatPtrT& other) = default;
// Returns the entire tensor for use by `backprop/*`. Verifies layout is
// `kPacked`. Preferably call `Row` instead, which works for either layout.
MatT* Packed() {
HWY_DASSERT_M(IsPacked(), name_.c_str());
return HWY_RCAST_ALIGNED(MatT*, ptr_);
}
const MatT* Packed() const {
HWY_DASSERT_M(IsPacked(), name_.c_str());
return HWY_RCAST_ALIGNED(const MatT*, ptr_);
}
// As `Packed()`, plus checks the scale is 1.0 because callers will ignore it.
// This is typically used for `MatMul` bias vectors and norm weights.
const MatT* PackedScale1() const {
HWY_DASSERT(Scale() == 1.0f);
return Packed();
}
const MatT* Row(size_t row) const { return this->RowT<MatT>(row); }
MatT* Row(size_t row) { return this->RowT<MatT>(row); }
// For `compress-inl.h` functions, which assume contiguous streams and thus
// require packed layout.
PackedSpan<const MatT> Span() const {
HWY_ASSERT(IsPacked());
return MakeConstSpan(Row(0), num_elements_);
}
PackedSpan<MatT> Span() {
HWY_ASSERT(IsPacked());
return MakeSpan(Row(0), num_elements_);
}
};
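// Editor's usage sketch (illustration, not part of the original header):
// prefer row-based access, which is valid for both packed and padded layouts.
inline void ExampleScaleRows(MatPtrT<float>& mat, float factor) {
  for (size_t r = 0; r < mat.Rows(); ++r) {
    float* row = mat.Row(r);  // honors Stride(), unlike raw pointer arithmetic
    for (size_t c = 0; c < mat.Cols(); ++c) row[c] *= factor;
  }
  // mat.Packed() would also work, but asserts IsPacked(), i.e. no row padding.
}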
// Calls `func` with a dynamic_cast of `MatPtr` to `MatPtrT<T>`, plus the
// optional `args`.
template <class Func, typename... Args>
decltype(auto) CallUpcasted(Type type, MatPtr* base, const Func& func,
Args&&... args) {
HWY_ASSERT(base != nullptr);
if (type == Type::kF32) {
return func(dynamic_cast<MatPtrT<float>*>(base),
std::forward<Args>(args)...);
} else if (type == Type::kBF16) {
return func(dynamic_cast<MatPtrT<BF16>*>(base),
std::forward<Args>(args)...);
} else if (type == Type::kSFP) {
return func(dynamic_cast<MatPtrT<SfpStream>*>(base),
std::forward<Args>(args)...);
} else if (type == Type::kNUQ) {
return func(dynamic_cast<MatPtrT<NuqStream>*>(base),
std::forward<Args>(args)...);
} else {
HWY_ABORT("Type %d unknown.", static_cast<int>(type));
}
}
void CopyMat(const MatPtr& from, MatPtr& to);
void ZeroInit(MatPtr& mat);
template <typename T>
void RandInit(MatPtrT<T>& x, T stddev, std::mt19937& gen) {
std::normal_distribution<T> dist(0.0, stddev);
for (size_t r = 0; r < x.Rows(); ++r) {
T* row = x.Row(r);
for (size_t c = 0; c < x.Cols(); ++c) {
row[c] = dist(gen);
}
}
}
// Sufficient value of `stride` to enable the "cyclic offsets" optimization. If
// `Allocator2::ShouldBind()`, `Allocator2::QuantumBytes()` is typically 4KiB.
// To avoid remote accesses, we would thus pad each row to that, which results
// in 4K aliasing and/or cache conflict misses. `RowPtr` is able to prevent that
// by pulling rows forward by a cyclic offset, which is still a multiple of the
// cache line size. This requires an additional `Allocator2::QuantumBytes()` of
// padding after also rounding up to that, which considerably increases size for
// tall and skinny tensors.
static inline size_t StrideForCyclicOffsets(size_t cols, size_t quantum) {
return hwy::RoundUpTo(cols, quantum) + quantum;
}
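// Worked example (editor's illustration): for cols = 2000 floats and
// quantum = QuantumBytes() / sizeof(float) = 4096 / 4 = 1024 elements, the
// stride is RoundUpTo(2000, 1024) + 1024 = 2048 + 1024 = 3072 elements, i.e.
// 12 KiB rows with at least one quantum of padding for cyclic offsets.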
// Constexpr version (upper bound) for allocating storage in MatMul.
template <typename T>
constexpr size_t MaxStrideForCyclicOffsets(size_t cols) {
constexpr size_t quantum = Allocator2::MaxQuantum<T>();
return hwy::RoundUpTo(cols, quantum) + quantum;
}
// Our tensors are always row-major. This enum indicates how much (if any)
// padding comes after each row.
enum class MatPadding {
// None, stride == cols. `compress-inl.h` requires this layout because its
// interface assumes a continuous 1D array, without awareness of rows. Note
// that tensors which were written via `compress-inl.h` (i.e. most in
// `BlobStore`) are not padded, which also extends to memory-mapped tensors.
// However, `BlobStore` is able to insert padding via row-wise I/O when
// reading from disk via `Mode::kRead`.
//
// `backprop/*` also requires this layout because it indexes directly into
// the storage instead of calling `Row()`.
kPacked,
// Enough to round up to an odd number of cache lines, which can reduce
// cache conflict misses or 4K aliasing.
kOdd,
// Enough to enable the "cyclic offsets" optimization for `MatMul`.
kCyclic,
};
// Type-erased, allows storing `AlignedPtr2<T[]>` for various T in the same
// vector.
class MatOwner {
public:
MatOwner() = default;
// Allow move for `MatStorageT`.
MatOwner(MatOwner&&) = default;
MatOwner& operator=(MatOwner&&) = default;
// Allocates the type/extents indicated by `mat` and sets its pointer.
void AllocateFor(MatPtr& mat, MatPadding padding);
private:
AlignedPtr2<uint8_t[]> storage_;
};
// `MatStorageT` IS-A `MatPtrT` and HAS-A `MatOwner`. Used by `backprop/` and
// tests to allocate and access tensors of a known type. By contrast, the
// heterogeneous model weights are owned by vectors of `MatOwner`.
template <typename MatT>
class MatStorageT : public MatPtrT<MatT> {
public:
MatStorageT(const char* name, Extents2D extents, MatPadding padding)
: MatPtrT<MatT>(name, extents) {
owner_.AllocateFor(*this, padding);
}
~MatStorageT() = default;
// Allow move for backprop/activations.
MatStorageT(MatStorageT&&) = default;
MatStorageT& operator=(MatStorageT&&) = default;
private:
MatOwner owner_;
};
// Helper factory function for use by `backprop/` to avoid specifying the
// `MatPadding` argument everywhere.
template <typename T>
MatStorageT<T> MakePacked(const char* name, size_t rows, size_t cols) {
return MatStorageT<T>(name, Extents2D(rows, cols), MatPadding::kPacked);
}
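// Editor's usage sketch (illustration, not part of the original header):
// `MatStorageT` owns its memory; the `MatPadding` mode only affects the stride
// chosen at allocation time, and `Row()` hides it either way.
inline void ExampleStorage() {
  MatStorageT<float> packed = MakePacked<float>("x", 4, 2000);  // stride == cols
  MatStorageT<float> padded("y", Extents2D(4, 2000), MatPadding::kCyclic);
  padded.Row(1)[0] = packed.Row(1)[0];
}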
// Lightweight version of `MatPtr` used by matmul-inl.h for padded tensors with
// seekable (non-NUQ) T. It has less metadata, but adds support for cyclic
// offsets.
#pragma pack(push, 1) // power of two size
template <typename T>
class RowPtr {
public:
RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols,
size_t stride)
: row0_(row0),
stride_(stride),
row_mask_(
static_cast<uint32_t>(allocator.QuantumStepMask() & 0xFFFFFFFFu)),
cols_(static_cast<uint32_t>(cols)),
step_bytes_(static_cast<uint32_t>(allocator.StepBytes())),
quantum_bytes_(allocator.QuantumBytes()) {
HWY_DASSERT(stride >= cols);
HWY_DASSERT(row_mask_ != ~uint32_t{0});
if (stride < StrideForCyclicOffsets(cols, quantum_bytes_ / sizeof(T))) {
row_mask_ = 0;
if constexpr (HWY_IS_DEBUG_BUILD) {
static bool once;
if (stride != cols && !once) {
once = true;
HWY_WARN(
"Check why RowPtr stride=%zu < StrideForCyclicOffsets(cols=%zu), "
"T=%zu; this forces us to disable cyclic offsets.",
stride, cols, sizeof(T));
}
}
}
}
RowPtr(const Allocator2& allocator, T* HWY_RESTRICT row0, size_t cols)
: RowPtr(allocator, row0, cols, cols) {}
T* HWY_RESTRICT Row(size_t r) const {
// How much of the previous row's padding to consume. `pad_bytes` is in
// bytes; convert to elements before adjusting the typed pointer.
const size_t pad_bytes = (r & row_mask_) * step_bytes_;
HWY_DASSERT(pad_bytes < static_cast<size_t>(quantum_bytes_));
return row0_ + stride_ * r - pad_bytes / sizeof(T);
}
size_t Cols() const { return static_cast<size_t>(cols_); }
size_t Stride() const { return stride_; }
void SetStride(size_t stride) {
HWY_DASSERT(stride >= Cols());
stride_ = stride;
// The caller might not have padded enough, so disable the cyclic offsets in
// `Row()`. Rows will now be exactly `stride` elements apart. This is used
// when writing to the KV cache via MatMul.
row_mask_ = 0;
}
// Returns 2D subrange whose top-left is `r, c` and width is `cols`.
RowPtr<T> View(size_t r, size_t c, size_t cols) const {
HWY_DASSERT(c < Cols());
HWY_DASSERT(cols <= Cols() - c);
return RowPtr<T>(Row(r) + c, cols, stride_, row_mask_, step_bytes_,
quantum_bytes_);
}
private:
// For `View()`.
RowPtr(T* new_row0, size_t new_cols, size_t stride, uint32_t row_mask,
uint32_t step_bytes, uint32_t quantum_bytes)
: row0_(new_row0),
stride_(stride),
row_mask_(row_mask),
cols_(new_cols),
step_bytes_(step_bytes),
quantum_bytes_(quantum_bytes) {}
T* HWY_RESTRICT row0_;
size_t stride_;
uint32_t row_mask_;
uint32_t cols_;
uint32_t step_bytes_;
uint32_t quantum_bytes_;
};
#pragma pack(pop)
using RowPtrBF = RowPtr<BF16>;
using RowPtrF = RowPtr<float>;
using RowPtrD = RowPtr<double>;
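// Worked example (editor's illustration), assuming QuantumStepMask() is
// QuantumBytes() / StepBytes() - 1: with QuantumBytes() = 4096, StepBytes() =
// 64 and T = float, row_mask_ = 63. For r = 5, pad_bytes = (5 & 63) * 64 =
// 320 bytes = 80 floats, so Row(5) starts 80 elements before
// row0_ + 5 * stride_, still a multiple of the cache line size.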
// Owns dynamically-allocated aligned memory for a batch of row vectors.
// This can be seen as a (batch_size x cols) matrix. Unlike `RowPtr`, this owns
// the memory. Unlike `MatPtr`, this lacks metadata.
// TODO: replace with `MatStorageT`.
template <typename T>
class RowVectorBatch {
public:
// Default ctor for Activations ctor.
RowVectorBatch() = default;
// Main ctor, called from Activations::Allocate. If `stride` is 0 (the
// default), rows are tightly packed (`stride = cols`).
// WARNING: not all call sites support `stride` != cols.
// TODO: once they do, remove stride and behave like AllocateAlignedRows here.
RowVectorBatch(const Allocator2& allocator, Extents2D extents,
size_t stride = 0)
: extents_(extents) {
if (stride == 0) {
stride_ = extents_.cols;
} else {
HWY_ASSERT(stride >= extents_.cols);
stride_ = stride;
}
// Allow binding the entire matrix.
const size_t padded = hwy::RoundUpTo(extents_.rows * stride_,
allocator.QuantumBytes() / sizeof(T));
mem_ = allocator.Alloc<T>(padded);
}
// Move-only
RowVectorBatch(RowVectorBatch&) noexcept = delete;
RowVectorBatch& operator=(RowVectorBatch&) noexcept = delete;
RowVectorBatch(RowVectorBatch&&) noexcept = default;
RowVectorBatch& operator=(RowVectorBatch&&) noexcept = default;
size_t BatchSize() const { return extents_.rows; }
size_t Cols() const { return extents_.cols; }
size_t Stride() const { return stride_; }
Extents2D Extents() const { return extents_; }
// Returns the given row vector of length `Cols()`.
T* Batch(size_t batch_idx) {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
const T* Batch(size_t batch_idx) const {
HWY_DASSERT(batch_idx < BatchSize());
return mem_.get() + batch_idx * stride_;
}
// For MatMul or other operations that process the entire batch at once.
// TODO: remove once we only use Mat.
T* All() { return mem_.get(); }
const T* Const() const { return mem_.get(); }
size_t NumBytes() const { return BatchSize() * stride_ * sizeof(T); }
private:
AlignedPtr2<T[]> mem_;
Extents2D extents_;
size_t stride_;
};
template <typename T>
RowPtr<T> RowPtrFromBatch(const Allocator2& allocator,
RowVectorBatch<T>& row_vectors) {
return RowPtr<T>(allocator, row_vectors.All(), row_vectors.Cols(),
row_vectors.Stride());
}
template <typename T>
RowVectorBatch<T> AllocateAlignedRows(const Allocator2& allocator,
Extents2D extents) {
return RowVectorBatch<T>(
allocator, extents,
StrideForCyclicOffsets(extents.cols,
allocator.QuantumBytes() / sizeof(T)));
}
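// Editor's usage sketch (illustration, not part of the original header):
// allocate a padded batch of rows and wrap it for MatMul-style row access.
inline void ExampleRowBatch(const Allocator2& allocator) {
  RowVectorBatch<float> batch =
      AllocateAlignedRows<float>(allocator, Extents2D(4, 2000));
  RowPtrF rows = RowPtrFromBatch(allocator, batch);
  rows.Row(2)[0] = 1.0f;  // may be pulled forward within the row padding
}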
} // namespace gcpp
#endif // THIRD_PARTY_GEMMA_CPP_UTIL_MAT_H_


@ -13,12 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util/threading.h"
#include "util/threading.h" // NOT threading_context..
// to ensure there is no deadlock.
#include <stdio.h>
#include <algorithm> // std::sort
#include <atomic>
#include <memory>
#include <optional>
#include <vector>
@ -69,7 +71,7 @@ class Pinning {
const int bytes_written =
snprintf(buf, sizeof(buf), "P%zu X%02zu C%03d", pkg_idx, cluster_idx,
static_cast<int>(task));
HWY_ASSERT(bytes_written < sizeof(buf));
HWY_ASSERT(bytes_written < static_cast<int>(sizeof(buf)));
hwy::SetThreadName(buf, 0); // does not support varargs
if (HWY_LIKELY(want_pin_)) {
@ -107,16 +109,16 @@ static Pinning& GetPinning() {
return pinning;
}
static PoolPtr MakePool(size_t num_workers,
static PoolPtr MakePool(const Allocator2& allocator, size_t num_workers,
std::optional<size_t> node = std::nullopt) {
// `ThreadPool` expects the number of threads to create, which is one less
// than the number of workers, but avoid underflow if zero.
const size_t num_threads = num_workers == 0 ? 0 : num_workers - 1;
PoolPtr ptr = Allocator::AllocClasses<hwy::ThreadPool>(1, num_threads);
PoolPtr ptr = allocator.AllocClasses<hwy::ThreadPool>(1, num_threads);
const size_t bytes =
hwy::RoundUpTo(sizeof(hwy::ThreadPool), Allocator::QuantumBytes());
if (node.has_value() && Allocator::ShouldBind()) {
Allocator::BindMemory(ptr.get(), bytes, node.value());
hwy::RoundUpTo(sizeof(hwy::ThreadPool), allocator.QuantumBytes());
if (node.has_value() && allocator.ShouldBind()) {
allocator.BindMemory(ptr.get(), bytes, node.value());
}
return ptr;
}
@ -133,22 +135,22 @@ static size_t DivideMaxAcross(const size_t max, const size_t instances) {
return max;
}
NestedPools::NestedPools(const BoundedTopology& topology, size_t max_threads,
NestedPools::NestedPools(const BoundedTopology& topology,
const Allocator2& allocator, size_t max_threads,
Tristate pin) {
GetPinning().SetPolicy(pin);
packages_.resize(topology.NumPackages());
all_packages_ = MakePool(packages_.size());
all_packages_ = MakePool(allocator, packages_.size());
const size_t max_workers_per_package =
DivideMaxAcross(max_threads, packages_.size());
// Each worker in all_packages_, including the main thread, will be the
// calling thread of an all_clusters[0].Run, and hence pinned to one of the
// calling thread of an all_clusters->Run, and hence pinned to one of the
// `cluster.lps` if `pin`.
all_packages_[0].Run(
0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) {
HWY_ASSERT(pkg_idx == thread); // each thread has one task
packages_[pkg_idx] =
Package(topology, pkg_idx, max_workers_per_package);
});
all_packages_->Run(0, packages_.size(), [&](uint64_t pkg_idx, size_t thread) {
HWY_ASSERT(pkg_idx == thread); // each thread has one task
packages_[pkg_idx] =
Package(topology, allocator, pkg_idx, max_workers_per_package);
});
all_pinned_ = GetPinning().AllPinned(&pin_string_);
@ -172,28 +174,29 @@ static inline size_t CapIfNonZero(size_t num, size_t max_or_zero) {
return (max_or_zero == 0) ? num : HWY_MIN(num, max_or_zero);
}
NestedPools::Package::Package(const BoundedTopology& topology, size_t pkg_idx,
NestedPools::Package::Package(const BoundedTopology& topology,
const Allocator2& allocator, size_t pkg_idx,
size_t max_workers_per_package) {
// Pre-allocate because elements are set concurrently.
clusters_.resize(topology.NumClusters(pkg_idx));
const size_t max_workers_per_cluster =
DivideMaxAcross(max_workers_per_package, clusters_.size());
all_clusters_ =
MakePool(clusters_.size(), topology.GetCluster(pkg_idx, 0).Node());
all_clusters_ = MakePool(allocator, clusters_.size(),
topology.GetCluster(pkg_idx, 0).Node());
// Parallel so we also pin the calling worker in `all_clusters` to
// `cluster.lps`.
all_clusters_[0].Run(
0, all_clusters_[0].NumWorkers(), [&](size_t cluster_idx, size_t thread) {
all_clusters_->Run(
0, all_clusters_->NumWorkers(), [&](size_t cluster_idx, size_t thread) {
HWY_ASSERT(cluster_idx == thread); // each thread has one task
const BoundedTopology::Cluster& cluster =
topology.GetCluster(pkg_idx, cluster_idx);
clusters_[cluster_idx] =
MakePool(CapIfNonZero(cluster.Size(), max_workers_per_cluster),
cluster.Node());
clusters_[cluster_idx] = MakePool(
allocator, CapIfNonZero(cluster.Size(), max_workers_per_cluster),
cluster.Node());
// Pin workers AND the calling thread from `all_clusters`.
GetPinning().MaybePin(pkg_idx, cluster_idx, cluster,
clusters_[cluster_idx][0]);
*clusters_[cluster_idx]);
});
}


@ -23,6 +23,7 @@
// IWYU pragma: begin_exports
#include "util/allocator.h"
#include "util/args.h"
#include "util/basics.h" // Tristate
#include "util/topology.h"
#include "hwy/base.h" // HWY_ASSERT
@ -37,7 +38,7 @@ namespace gcpp {
// Page-aligned on NUMA systems so we can bind to a NUMA node. This also allows
// moving because it is a typedef to `std::unique_ptr`.
using PoolPtr = AlignedClassPtr<hwy::ThreadPool>;
using PoolPtr = AlignedClassPtr2<hwy::ThreadPool>;
// Creates a hierarchy of thread pools according to `BoundedTopology`: one with
// a thread per enabled package; for each of those, one with a thread per
@ -73,10 +74,8 @@ class NestedPools {
// would cause huge slowdowns when spinning, the `BoundedSlice` arguments
// only impose upper bounds on the number of detected packages and clusters
// rather than defining the actual number of threads.
//
// Caller must have called `Allocator::Init` before this.
NestedPools(const BoundedTopology& topology, size_t max_threads = 0,
Tristate pin = Tristate::kDefault);
NestedPools(const BoundedTopology& topology, const Allocator2& allocator,
size_t max_threads = 0, Tristate pin = Tristate::kDefault);
bool AllPinned() const { return all_pinned_; }
@ -103,7 +102,7 @@ class NestedPools {
}
size_t NumPackages() const { return packages_.size(); }
hwy::ThreadPool& AllPackages() { return all_packages_[0]; }
hwy::ThreadPool& AllPackages() { return *all_packages_; }
hwy::ThreadPool& AllClusters(size_t pkg_idx) {
HWY_DASSERT(pkg_idx < NumPackages());
return packages_[pkg_idx].AllClusters();
@ -149,36 +148,36 @@ class NestedPools {
class Package {
public:
Package() = default; // for vector
Package(const BoundedTopology& topology, size_t pkg_idx,
size_t max_workers_per_package);
Package(const BoundedTopology& topology, const Allocator2& allocator,
size_t pkg_idx, size_t max_workers_per_package);
size_t NumClusters() const { return clusters_.size(); }
size_t MaxWorkersPerCluster() const {
size_t max_workers_per_cluster = 0;
for (const PoolPtr& cluster : clusters_) {
max_workers_per_cluster =
HWY_MAX(max_workers_per_cluster, cluster[0].NumWorkers());
HWY_MAX(max_workers_per_cluster, cluster->NumWorkers());
}
return max_workers_per_cluster;
}
size_t TotalWorkers() const {
size_t total_workers = 0;
for (const PoolPtr& cluster : clusters_) {
total_workers += cluster[0].NumWorkers();
total_workers += cluster->NumWorkers();
}
return total_workers;
}
hwy::ThreadPool& AllClusters() { return all_clusters_[0]; }
hwy::ThreadPool& AllClusters() { return *all_clusters_; }
hwy::ThreadPool& Cluster(size_t cluster_idx) {
HWY_DASSERT(cluster_idx < clusters_.size());
return clusters_[cluster_idx][0];
return *clusters_[cluster_idx];
}
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_clusters_[0].SetWaitMode(wait_mode);
all_clusters_->SetWaitMode(wait_mode);
for (PoolPtr& cluster : clusters_) {
cluster[0].SetWaitMode(wait_mode);
cluster->SetWaitMode(wait_mode);
}
}
@ -188,7 +187,7 @@ class NestedPools {
}; // Package
void SetWaitMode(hwy::PoolWaitMode wait_mode) {
all_packages_[0].SetWaitMode(wait_mode);
all_packages_->SetWaitMode(wait_mode);
for (Package& package : packages_) {
package.SetWaitMode(wait_mode);
}
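
A usage sketch (editor's example, not part of this commit) of the reworked accessors, which now dereference `PoolPtr` with `*`/`->` instead of indexing; it assumes a `ThreadingContext2` has been (or will be lazily) initialized.

#include "util/threading.h"
#include "util/threading_context.h"
#include "hwy/contrib/thread_pool/thread_pool.h"

void ExamplePools() {
  gcpp::NestedPools& pools = gcpp::ThreadingContext2::Get().pools;
  // Outer pool: one worker per package.
  pools.AllPackages().Run(
      0, pools.NumPackages(), [&](uint64_t pkg_idx, size_t /*thread*/) {
        hwy::ThreadPool& clusters = pools.AllClusters(pkg_idx);
        // Fan out to the clusters of this package.
        clusters.Run(0, clusters.NumWorkers(),
                     [&](uint64_t /*cluster_idx*/, size_t /*thread*/) {});
      });
}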


@ -33,6 +33,13 @@ static std::mutex s_ctx_mutex;
s_ctx_mutex.unlock();
}
/*static*/ bool ThreadingContext2::IsInitialized() {
s_ctx_mutex.lock();
const bool initialized = !!s_ctx;
s_ctx_mutex.unlock();
return initialized;
}
/*static*/ ThreadingContext2& ThreadingContext2::Get() {
// We do not bother with double-checked locking because it requires an
// atomic pointer, but we prefer to use unique_ptr for simplicity. Also,


@ -98,6 +98,10 @@ class ThreadingContext2 {
// is expected to be called early in the program, before threading starts.
static void SetArgs(const ThreadingArgs& args);
// Returns whether `Get()` has already been called, typically used to avoid
// calling `SetArgs` after that, because it would assert.
static bool IsInitialized();
// Returns a reference to the singleton after initializing it if necessary.
// When initializing, uses the args passed to `SetArgs`, or defaults.
//
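
A minimal initialization sketch (editor's example, not part of this commit); it assumes `ThreadingArgs` is default-constructible and uses the `pools` member referenced elsewhere in this change.

#include "util/threading_context.h"

void InitThreading() {
  gcpp::ThreadingArgs args;  // typically populated from command-line flags
  if (!gcpp::ThreadingContext2::IsInitialized()) {
    gcpp::ThreadingContext2::SetArgs(args);  // must precede the first Get()
  }
  gcpp::NestedPools& pools = gcpp::ThreadingContext2::Get().pools;
  (void)pools;
}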


@ -13,8 +13,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "util/threading.h"
#include <stddef.h>
#include <stdio.h>
@ -22,9 +20,9 @@
#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "hwy/aligned_allocator.h"
#include "util/threading_context.h"
#include "hwy/aligned_allocator.h" // Span
#include "hwy/auto_tune.h"
#include "hwy/base.h" // HWY_ASSERT
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -385,9 +383,7 @@ TEST(ThreadingTest, BenchJoin) {
}
};
BoundedTopology topology;
Allocator::Init(topology, true);
NestedPools pools(topology);
NestedPools& pools = ThreadingContext2::Get().pools;
// Use last package because the main thread has been pinned to it.
const size_t pkg_idx = pools.NumPackages() - 1;