mirror of https://github.com/google/gemma.cpp.git
Makes the entire runtime_config be passed into the activations constructor.
PiperOrigin-RevId: 845153671
commit baa69dfb78
parent 44dfd69b9b
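
The change is mechanical at each call site: rather than plucking the attention_impl field out of the RuntimeConfig and passing it alone, callers now forward the whole struct and the constructor reads what it needs. A minimal before/after sketch, using only names that appear in the diff below:

    // Before: only the attention-implementation enum was forwarded.
    AttentionActivations attention_storage(config, layer_config, batch_size,
                                           seq_len, runtime_config.attention_impl,
                                           ctx.allocator, row_ptrs);

    // After: the entire RuntimeConfig is forwarded, so the constructor can
    // read attention_impl (and any future runtime settings) itself.
    AttentionActivations attention_storage(config, layer_config, batch_size,
                                           seq_len, runtime_config,
                                           ctx.allocator, row_ptrs);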

BUILD.bazel

@@ -141,7 +141,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul",
-        ":query",
         ":test_util",
         ":threading_context",
         ":weights",
@@ -643,7 +642,6 @@ cc_test(
         ":kv_cache",
         ":mat",
         ":matmul_env",
-        ":query",
         ":test_util",
         ":threading_context",
         ":weights",

gemma/activations.h

@@ -48,7 +48,7 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      size_t batch_size, size_t seq_len, const RuntimeConfig& runtime_config,
       const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
@@ -279,8 +279,7 @@ struct Activations {
         s_w_linear_w(config.num_layers, max_workers),
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          runtime_config.attention_impl, ctx.allocator,
-                          row_ptrs),
+                          runtime_config, ctx.allocator, row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);

gemma/attention_test.cc

@@ -83,8 +83,8 @@ struct TestModelState {
                           state.mat_owners, 43);
     AllocateAndFillRandom(layer.gating_einsum_w, state.ctx.allocator,
                           state.mat_owners, 44);
-    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator,
-                          state.mat_owners, 45);
+    AllocateAndFillRandom(layer.linear_w, state.ctx.allocator, state.mat_owners,
+                          45);
     layer.Fixup(state.mat_owners, state.ctx);
   }
@@ -101,9 +101,10 @@ struct TestAttentionState {
       : num_tokens(num_tokens),
         qbatch_size(qbatch_size),
         batch_size(qbatch_size * num_tokens),
+        runtime_config{.attention_impl = attention_impl},
         tokens(num_tokens),
         attention_storage_(model_state.config, model_state.layer_config,
-                           batch_size, num_tokens, attention_impl,
+                           batch_size, num_tokens, runtime_config,
                            state.ctx.allocator, row_ptrs_),
         attention(model_state.config, num_tokens, attention_storage_) {
     for (size_t i = 0; i < qbatch_size; ++i) {
@@ -276,8 +277,8 @@ const float kGoldenAttSums[kNumTokens][kQBatchSize][kDimsToCompare] = {
      -66.5, -0.84765625, -46.5, -152, -2.9375, -81}},
     {{3.984375, 83, -41.75, 39.5, -203, 110, -76, 131, 0.4609375, -44.5, -63.75,
       -46, -22, -19.375, -16.125, -148, 20.875},
-     {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5,
-      33, 10.9375, -52.5, 23.25, 75}},
+     {-47, -19.5, 58, 81.5, 21.75, -30, -118, 44.25, -149, 22.5, 188, -66.5, 33,
+      10.9375, -52.5, 23.25, 75}},
     {{64, -31, -89, -92.5, -11.1875, -54.75, -302, 3.453125, -108, 39.25,
       -34.75, 18, -52, 100, -186, -75.5, 50.75},
      {7.6875, -80, -40, 32.25, -30.25, 90, -41, 44.25, -140, -2.4375, 82.5,
@@ -366,10 +367,9 @@ const float kGoldenK[kNumTokens][kQBatchSize][kDimsToCompare] = {
       -4.42512083, 1.78077614, -3.25167561, 0.864362717, 0.474019766,
       -7.92327404, -2.27795148, -0.436354101, -3.15722394, 0.415780187,
       2.60931611}},
-    {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216,
-      -16.3434925, -4.75156116, -1.99114823, 3.99918842, -5.95400572,
-      10.8700314, 1.07596064, 0.30389142, 8.39548779, -5.11913681, 5.45641088,
-      -5.63240337},
+    {{-9.43858051, 0.391518891, -2.74012518, 4.9842453, 7.48263216, -16.3434925,
+      -4.75156116, -1.99114823, 3.99918842, -5.95400572, 10.8700314, 1.07596064,
+      0.30389142, 8.39548779, -5.11913681, 5.45641088, -5.63240337},
      {-1.22347319, 9.57339382, -1.31736016, -5.02770805, -4.81617355,
       -1.96618557, -0.456317186, 12.6451035, -1.50221801, 6.7991147,
       -5.97842169, 1.85410941, -8.44729, 0.378282309, 0.0442156792, 17.6773052,
@@ -381,14 +381,12 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
     {{2.77553034, -7.67514181, -1.60433948, 4.67795134, -1.75084186, 8.57896423,
       -1.15065813, -3.75088787, -4.7442131, -1.68890858, -10.0202332,
       -4.20167446, 9.36844635, 13.7364845, 11.5634, 2.95288706, 2.89380026},
-     {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325,
-      3.52626801, -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271,
-      1.42195463, 0.301399827, -4.40214968, -2.12298298, 9.27825642,
-      -0.690600872}},
+     {-4.79950905, -1.66658688, 4.14471292, -4.95649052, -5.4200325, 3.52626801,
+      -10.9432049, 0.338347554, -1.53204226, 0.473476171, -0.58271, 1.42195463,
+      0.301399827, -4.40214968, -2.12298298, 9.27825642, -0.690600872}},
     {{-10.6566734, 4.12785721, 4.54053593, -1.39667869, -1.55028772, 0.20508635,
-      -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249,
-      9.5985508, -10.6630878, -11.9006901, 0.851743698, 0.581826329,
-      5.21927929},
+      -0.00620913506, 2.93214, -0.788117647, 2.78032446, -2.68898249, 9.5985508,
+      -10.6630878, -11.9006901, 0.851743698, 0.581826329, 5.21927929},
      {-0.322291255, 2.63848567, -2.30808377, -13.0153809, 2.74378228,
       3.21460533, 0.688529968, 2.37544608, 6.06825066, 4.57566404, 1.17124248,
       -7.96587658, -2.65279341, 4.75271225, -4.09937954, -10.3570251,
@@ -411,13 +409,11 @@ const float kGoldenV[kNumTokens][kQBatchSize][kDimsToCompare] = {
       -7.11484337, 2.53943753, -0.652261257, 9.77392, 3.53345847, -9.62052822,
       16.0471916},
      {6.89768124, 2.36394405, -2.08569574, -0.682706833, 3.38872, -6.28313875,
-      4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208,
-      2.44881392, 1.99794042, -9.19855404, -4.02383137, -3.63013959,
-      -5.65853405}},
-    {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828,
-      11.8203125, 1.81672478, -1.42535269, -5.26496315, -5.31612349,
-      -4.19499826, 7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617,
-      3.5296216},
+      4.79594612, 4.93417454, -6.40791416, -10.7355442, -5.66094208, 2.44881392,
+      1.99794042, -9.19855404, -4.02383137, -3.63013959, -5.65853405}},
+    {{1.64614546, -3.93421197, -0.48935914, 5.48284435, -7.69781828, 11.8203125,
+      1.81672478, -1.42535269, -5.26496315, -5.31612349, -4.19499826,
+      7.06049395, 0.18029356, -0.0519902706, 10.317358, 2.19345617, 3.5296216},
      {7.52353811, 3.56836724, 0.414305687, 0.340799928, 2.44263697, 7.52111912,
       0.246491909, -11.1172791, -3.82061529, 3.24794388, 0.751524329,
       3.14019632, 6.33881855, -0.169233799, 7.82640171, 1.5389179, 8.15851307}},

gemma/flash_attention_test.cc

@@ -112,7 +112,9 @@ void TestFlashAttention(size_t target_parallelism) {
   const LayerConfig& layer_config = config.layer_configs[0];
   const LayerWeightsPtrs layers(0, layer_config, tensor_info_registry);
   InferenceArgs inference_args;
+  inference_args.attention_impl = "flash";
   RuntimeConfig runtime_config;
+  inference_args.CopyTo(runtime_config);
   KVCache kv_cache(config, inference_args, ctx.allocator);
   MatMulEnv env(ctx);
   Activations activations(runtime_config, config,
@@ -127,8 +129,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, AttentionImpl::kFlash,
-                                         ctx.allocator, row_ptrs);
+                                         kOuter, runtime_config, ctx.allocator,
+                                         row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);
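
For tests that pin a particular attention implementation, the selection now travels inside RuntimeConfig instead of as a separate AttentionImpl argument. A hedged sketch of the resulting setup, assembled from the two hunks above (it assumes, as the diff suggests, that InferenceArgs::CopyTo fills runtime_config.attention_impl from the parsed string):

    InferenceArgs inference_args;
    inference_args.attention_impl = "flash";  // select the flash implementation
    RuntimeConfig runtime_config;
    inference_args.CopyTo(runtime_config);  // copies attention_impl into runtime_config

    // The storage constructor now receives the full RuntimeConfig and reads
    // attention_impl internally; AttentionImpl::kFlash is no longer passed directly.
    std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
    AttentionActivations attention_storage(config, layer_config, batch_size,
                                           kOuter, runtime_config, ctx.allocator,
                                           row_ptrs);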