diff --git a/BUILD.bazel b/BUILD.bazel
index 606a2fb..a9631dc 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -141,7 +141,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul",
-        ":query",
         ":test_util",
         ":threading_context",
         ":weights",
@@ -643,7 +642,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul_env",
-        ":query",
         ":test_util",
         ":threading_context",
         ":weights",
diff --git a/gemma/activations.h b/gemma/activations.h
index 704c3ee..adb6d02 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -48,7 +48,7 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
       const Allocator& allocator, std::vector>& row_ptrs)
       :  // `vocab_size == 0` means it is for Vit part, VitAttention is still
@@ -129,7 +129,7 @@ struct AttentionActivations {
     // `inv_timescale*` are not batched.
   }

-  MatStorageT q;  // query
+  MatStorageT q;  // query
   MatStorageT q_bf;
   MatStorageT q_T;  // Transposed to maximize attention speed.
@@ -138,8 +138,8 @@ struct AttentionActivations {
   MatStorageT vit_C;
   MatStorageT pre_att_rms_out;
-  MatStorageT att;      // attention vector
-  MatStorageT att_out;  // attention output
+  MatStorageT att;      // attention vector
+  MatStorageT att_out;  // attention output
   MatStorageT softmax_max;  // see OnlineSoftmaxState
   MatStorageT softmax_d;    // see OnlineSoftmaxState
   // Accumulation of attention outputs over heads
@@ -279,8 +279,7 @@ struct Activations {
         s_w_linear_w(config.num_layers, max_workers),
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          runtime_config.attention_impl, ctx.allocator,
-                          row_ptrs),
+                          runtime_config, ctx.allocator, row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
diff --git a/gemma/attention_test.cc b/gemma/attention_test.cc
index 19e1f06..53f1d01 100644
--- a/gemma/attention_test.cc
+++ b/gemma/attention_test.cc
@@ -83,8 +83,8 @@ struct TestModelState {
                           state.mat_owners, 43);
     AllocateAndFillRandom(layer.gating_einsum_w, state.ctx.allocator,
                           state.mat_owners, 44);
-    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator,
-                          state.mat_owners, 45);
+    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator, state.mat_owners,
+                          45);
     layer.Fixup(state.mat_owners, state.ctx);
   }
@@ -101,9 +101,10 @@ struct TestAttentionState {
       : num_tokens(num_tokens),
         qbatch_size(qbatch_size),
         batch_size(qbatch_size * num_tokens),
+        runtime_config{.attention_impl = attention_impl},
         tokens(num_tokens),
         attention_storage_(model_state.config, model_state.layer_config,
-                           batch_size, num_tokens, attention_impl,
+                           batch_size, num_tokens, runtime_config,
                            state.ctx.allocator, row_ptrs_),
         attention(model_state.config, num_tokens, attention_storage_) {
     for (size_t i = 0; i < qbatch_size; ++i) {
@@ -276,8 +277,8 @@ const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
      -66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
     {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5,
       -63.75, -46, -22, -19.375, -16.125, -148, 20.875},
-     {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5,
-      33, 10.9375, -52.5, 23.25, 75}},
+     {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
+      10.9375, -52.5, 23.25, 75}},
     {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
       -34.75, 18, -52, 100, -186, -75.5, 50.75},
      {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
@@ -366,10 +367,9 @@ const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = {
      -4.42512083, 1.78077614, -3.25167561, 0.864362717, 0.474019766,
      -7.92327404, -2.27795148, -0.436354101, -3.15722394, 0.415780187,
      2.60931611}},
-    {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216,
-      -16.3434925, -4.75156116, -1.99114823, 3.99918842, -5.95400572,
-      10.8700314, 1.07596064, 0.30389142, 8.39548779, -5.11913681, 5.45641088,
-      -5.63240337},
+    {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216, -16.3434925,
+      -4.75156116, -1.99114823, 3.99918842, -5.95400572, 10.8700314, 1.07596064,
+      0.30389142, 8.39548779, -5.11913681, 5.45641088, -5.63240337},
     {-1.22347319, 9.57339382, -1.31736016, -5.02770805, -4.81617355,
      -1.96618557, -0.456317186, 12.6451035, -1.50221801, 6.7991147,
      -5.97842169, 1.85410941, -8.44729, 0.378282309, 0.0442156792, 17.6773052,
@@ -381,14 +381,12 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
    {{2.77553034, -7.67514181, -1.60433948, 4.67795134, -1.75084186, 8.57896423,
      -1.15065813, -3.75088787, -4.7442131, -1.68890858, -10.0202332,
      -4.20167446, 9.36844635, 13.7364845, 11.5634, 2.95288706, 2.89380026},
-    {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325,
-     3.52626801, -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271,
-     1.42195463, 0.301399827, -4.40214968, -2.12298298, 9.27825642,
-     -0.690600872}},
+    {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325, 3.52626801,
+     -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271, 1.42195463,
+     0.301399827, -4.40214968, -2.12298298, 9.27825642, -0.690600872}},
    {{-10.6566734, 4.12785721, 4.54053593, -1.39667869, -1.55028772, 0.20508635,
-     -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249,
-     9.5985508, -10.6630878, -11.9006901, 0.851743698, 0.581826329,
-     5.21927929},
+     -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249, 9.5985508,
+     -10.6630878, -11.9006901, 0.851743698, 0.581826329, 5.21927929},
     {-0.322291255, 2.63848567, -2.30808377, -13.0153809, 2.74378228,
      3.21460533, 0.688529968, 2.37544608, 6.06825066, 4.57566404, 1.17124248,
      -7.96587658, -2.65279341, 4.75271225, -4.09937954, -10.3570251,
@@ -411,13 +409,11 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
      -7.11484337, 2.53943753, -0.652261257, 9.77392, 3.53345847, -9.62052822,
      16.0471916},
     {6.89768124, 2.36394405, -2.08569574, -0.682706833, 3.38872, -6.28313875,
-     4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208,
-     2.44881392, 1.99794042, -9.19855404, -4.02383137, -3.63013959,
-     -5.65853405}},
-    {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828,
-      11.8203125, 1.81672478, -1.42535269, -5.26496315, -5.31612349,
-      -4.19499826, 7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617,
-      3.5296216},
+     4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208, 2.44881392,
+     1.99794042, -9.19855404, -4.02383137, -3.63013959, -5.65853405}},
+    {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828, 11.8203125,
+      1.81672478, -1.42535269, -5.26496315, -5.31612349, -4.19499826,
+      7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617, 3.5296216},
     {7.52353811, 3.56836724, 0.414305687, 0.340799928, 2.44263697, 7.52111912,
      0.246491909, -11.1172791, -3.82061529, 3.24794388, 0.751524329,
      3.14019632, 6.33881855, -0.169233799, 7.82640171, 1.5389179, 8.15851307}},
diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc
index cecede0..f0a90fa 100644
--- a/gemma/flash_attention_test.cc
+++ b/gemma/flash_attention_test.cc
@@ -112,7 +112,9 @@ void TestFlashAttention(size_t target_parallelism) {
   const LayerConfig& layer_config = config.layer_configs[0];
   const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
   InferenceArgs inference_args;
+  inference_args.attention_impl = "flash";
   RuntimeConfig runtime_config;
+  inference_args.CopyTo(runtime_config);
   KVCache kv_cache(config, inference_args, ctx.allocator);
   MatMulEnv env(ctx);
   Activations activations(runtime_config, config,
@@ -127,8 +129,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, AttentionImpl::kFlash,
-                                         ctx.allocator, row_ptrs);
+                                         kOuter, runtime_config, ctx.allocator,
+                                         row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);
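
Migration note (a minimal sketch, not part of the diff): `AttentionActivations` now takes the whole `RuntimeConfig` instead of a bare `AttentionImpl`, so callers build or copy a `RuntimeConfig` first and pass it through unchanged. The snippet below mirrors the updated test call sites above; `config`, `layer_config`, `ctx`, `batch_size`, `seq_len`, and `row_ptrs` are assumed to exist exactly as in those tests, and `row_ptrs` keeps whatever element type it had before this change.

  // Sketch only: the attention implementation is now selected via RuntimeConfig.
  RuntimeConfig runtime_config;
  runtime_config.attention_impl = AttentionImpl::kFlash;
  // Alternatively, as flash_attention_test.cc does above: set "flash" on
  // InferenceArgs and call inference_args.CopyTo(runtime_config).
  AttentionActivations attention_storage(config, layer_config, batch_size,
                                         seq_len, runtime_config, ctx.allocator,
                                         row_ptrs);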