Pass the entire RuntimeConfig into the AttentionActivations constructor, rather than only the AttentionImpl enum.

PiperOrigin-RevId: 845153671
Author: Balazs Racz, 2025-12-16 01:56:18 -08:00
Committed by: Copybara-Service
Parent: 44dfd69b9b
Commit: baa69dfb78
4 changed files with 28 additions and 33 deletions
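The net effect: AttentionActivations now receives the whole RuntimeConfig instead of a bare AttentionImpl enum, presumably so that additional per-run options can later reach the constructor without another signature change. A minimal before/after sketch of the call site, using only names that appear in the diff below:

    // Before: only the attention implementation enum was threaded through.
    AttentionActivations storage(config, layer_config, batch_size, seq_len,
                                 runtime_config.attention_impl,
                                 ctx.allocator, row_ptrs);

    // After: the full runtime configuration travels with the call.
    AttentionActivations storage(config, layer_config, batch_size, seq_len,
                                 runtime_config, ctx.allocator, row_ptrs);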

File 1 of 4

@@ -141,7 +141,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul",
-        ":query",
         ":test_util",
         ":threading_context",
         ":weights",
@@ -643,7 +642,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul_env",
-        ":query",
         ":test_util",
         ":threading_context",
         ":weights",

File 2 of 4

@@ -48,7 +48,7 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
       const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
@@ -279,8 +279,7 @@ struct Activations {
         s_w_linear_w(config.num_layers, max_workers),
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          runtime_config.attention_impl, ctx.allocator,
-                          row_ptrs),
+                          runtime_config, ctx.allocator, row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);

File 3 of 4

@@ -83,8 +83,8 @@ struct TestModelState {
                           state.mat_owners, 43);
     AllocateAndFillRandom(layer.gating_einsum_w, state.ctx.allocator,
                           state.mat_owners, 44);
-    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator,
-                          state.mat_owners, 45);
+    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator, state.mat_owners,
+                          45);
     layer.Fixup(state.mat_owners, state.ctx);
   }
@@ -101,9 +101,10 @@ struct TestAttentionState {
       : num_tokens(num_tokens),
         qbatch_size(qbatch_size),
         batch_size(qbatch_size * num_tokens),
+        runtime_config{.attention_impl = attention_impl},
         tokens(num_tokens),
         attention_storage_(model_state.config, model_state.layer_config,
-                           batch_size, num_tokens, attention_impl,
+                           batch_size, num_tokens, runtime_config,
                            state.ctx.allocator, row_ptrs_),
         attention(model_state.config, num_tokens, attention_storage_) {
     for (size_t i = 0; i < qbatch_size; ++i) {
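Note that the new runtime_config member is built with a designated initializer (standard since C++20), setting only the field the test cares about and leaving everything else defaulted. A self-contained sketch of the same pattern; this RuntimeConfig stand-in is hypothetical and far smaller than the real one:

    enum class AttentionImpl { kFlash /* other implementations omitted */ };

    // Hypothetical stand-in: the real RuntimeConfig in gemma.cpp has many
    // more fields; the diff only shows that it is aggregate-initializable
    // and has an attention_impl member.
    struct RuntimeConfig {
      AttentionImpl attention_impl{};
      // ... other per-run options, all defaulted ...
    };

    // Set one field, default-initialize the rest.
    RuntimeConfig runtime_config{.attention_impl = AttentionImpl::kFlash};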
@@ -276,8 +277,8 @@ const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
       -66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
     {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
       -46, -22, -19.375, -16.125, -148, 20.875},
-    {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5,
-     33, 10.9375, -52.5, 23.25, 75}},
+    {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
+     10.9375, -52.5, 23.25, 75}},
     {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
       -34.75, 18, -52, 100, -186, -75.5, 50.75},
     {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
@@ -366,10 +367,9 @@ const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = {
      -4.42512083, 1.78077614, -3.25167561, 0.864362717, 0.474019766,
      -7.92327404, -2.27795148, -0.436354101, -3.15722394, 0.415780187,
      2.60931611}},
-    {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216,
-      -16.3434925, -4.75156116, -1.99114823, 3.99918842, -5.95400572,
-      10.8700314, 1.07596064, 0.30389142, 8.39548779, -5.11913681, 5.45641088,
-      -5.63240337},
+    {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216, -16.3434925,
+      -4.75156116, -1.99114823, 3.99918842, -5.95400572, 10.8700314, 1.07596064,
+      0.30389142, 8.39548779, -5.11913681, 5.45641088, -5.63240337},
     {-1.22347319, 9.57339382, -1.31736016, -5.02770805, -4.81617355,
      -1.96618557, -0.456317186, 12.6451035, -1.50221801, 6.7991147,
      -5.97842169, 1.85410941, -8.44729, 0.378282309, 0.0442156792, 17.6773052,
@@ -381,14 +381,12 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
    {{2.77553034, -7.67514181, -1.60433948, 4.67795134, -1.75084186, 8.57896423,
      -1.15065813, -3.75088787, -4.7442131, -1.68890858, -10.0202332,
      -4.20167446, 9.36844635, 13.7364845, 11.5634, 2.95288706, 2.89380026},
-    {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325,
-     3.52626801, -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271,
-     1.42195463, 0.301399827, -4.40214968, -2.12298298, 9.27825642,
-     -0.690600872}},
+    {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325, 3.52626801,
+     -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271, 1.42195463,
+     0.301399827, -4.40214968, -2.12298298, 9.27825642, -0.690600872}},
    {{-10.6566734, 4.12785721, 4.54053593, -1.39667869, -1.55028772, 0.20508635,
-     -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249,
-     9.5985508, -10.6630878, -11.9006901, 0.851743698, 0.581826329,
-     5.21927929},
+     -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249, 9.5985508,
+     -10.6630878, -11.9006901, 0.851743698, 0.581826329, 5.21927929},
    {-0.322291255, 2.63848567, -2.30808377, -13.0153809, 2.74378228,
     3.21460533, 0.688529968, 2.37544608, 6.06825066, 4.57566404, 1.17124248,
     -7.96587658, -2.65279341, 4.75271225, -4.09937954, -10.3570251,
@@ -411,13 +409,11 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
     -7.11484337, 2.53943753, -0.652261257, 9.77392, 3.53345847, -9.62052822,
     16.0471916},
    {6.89768124, 2.36394405, -2.08569574, -0.682706833, 3.38872, -6.28313875,
-    4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208,
-    2.44881392, 1.99794042, -9.19855404, -4.02383137, -3.63013959,
-    -5.65853405}},
-  {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828,
-    11.8203125, 1.81672478, -1.42535269, -5.26496315, -5.31612349,
-    -4.19499826, 7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617,
-    3.5296216},
+    4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208, 2.44881392,
+    1.99794042, -9.19855404, -4.02383137, -3.63013959, -5.65853405}},
+  {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828, 11.8203125,
+    1.81672478, -1.42535269, -5.26496315, -5.31612349, -4.19499826,
+    7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617, 3.5296216},
   {7.52353811, 3.56836724, 0.414305687, 0.340799928, 2.44263697, 7.52111912,
    0.246491909, -11.1172791, -3.82061529, 3.24794388, 0.751524329,
    3.14019632, 6.33881855, -0.169233799, 7.82640171, 1.5389179, 8.15851307}},

File 4 of 4

@@ -112,7 +112,9 @@ void TestFlashAttention(size_t target_parallelism) {
   const LayerConfig& layer_config = config.layer_configs[0];
   const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
   InferenceArgs inference_args;
+  inference_args.attention_impl = "flash";
   RuntimeConfig runtime_config;
+  inference_args.CopyTo(runtime_config);
   KVCache kv_cache(config, inference_args, ctx.allocator);
   MatMulEnv env(ctx);
   Activations activations(runtime_config, config,
@@ -127,8 +129,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, AttentionImpl::kFlash,
-                                         ctx.allocator, row_ptrs);
+                                         kOuter, runtime_config, ctx.allocator,
+                                         row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);
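The flash-attention test now selects its implementation the same way production callers do: set the string flag on InferenceArgs, copy it into a RuntimeConfig, and pass that object down. A condensed sketch of the flow, using only the calls visible in this diff:

    InferenceArgs inference_args;
    inference_args.attention_impl = "flash";  // string-valued flag, per the diff

    RuntimeConfig runtime_config;
    inference_args.CopyTo(runtime_config);    // translates args into runtime options

    // The config object, not a bare enum, now reaches the activations.
    AttentionActivations attention_storage(config, layer_config, batch_size,
                                           kOuter, runtime_config,
                                           ctx.allocator, row_ptrs);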