diff --git a/BUILD.bazel b/BUILD.bazel index 3ec04e7..deb376b 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -141,6 +141,7 @@ cc_test( ":gemma_args", ":gemma_lib", ":kv_cache", + ":kv_transcoding", ":mat", ":matmul", ":test_util", @@ -151,6 +152,7 @@ cc_test( "//compression:types", "@highway//:hwy", "@highway//:hwy_test_util", + "@highway//:nanobenchmark", ], ) diff --git a/gemma/flash_attention_test.cc b/gemma/flash_attention_test.cc index 64f3c0d..f271c18 100644 --- a/gemma/flash_attention_test.cc +++ b/gemma/flash_attention_test.cc @@ -24,9 +24,11 @@ #include "gemma/gemma.h" #include "gemma/gemma_args.h" #include "gemma/kv_cache.h" +#include "gemma/kv_transcoding.h" #include "gemma/weights.h" #include "ops/matmul.h" #include "util/test_util.h" +#include "hwy/nanobenchmark.h" #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS #endif // HWY_DISABLED_TARGETS @@ -102,6 +104,33 @@ void AssertClose(const MatPtrT& a, const MatPtrT& b) { } } +template +void PopulateTestKVCache(MatStorageT& kv, gcpp::KVEncoding encoding, + size_t qkv_dim) { + gcpp::DecodedTile tile(qkv_dim, gcpp::KVCache::kTileSize); + + size_t num_tiles = kv.Rows(); + float unpredictable = hwy::Unpredictable1() * 0.01f; + for (size_t tile_idx = 0; tile_idx < num_tiles; ++tile_idx) { + for (size_t in_tile = 0; in_tile < gcpp::KVCache::kTileSize; ++in_tile) { + size_t i = tile_idx * gcpp::KVCache::kTileSize + in_tile; + for (size_t j = 0; j < qkv_dim; ++j) { + tile.k_elem(in_tile, j) = unpredictable * (i + 1) / (j + 1); + tile.v_elem(in_tile, j) = unpredictable * 2 * (i + 1) / (j + 1); + } + } + size_t row_bytes = kv.Cols() * sizeof(T); + HWY_ASSERT(gcpp::EncodeTile( + encoding, tile, qkv_dim, + hwy::Span(reinterpret_cast(kv.Row(tile_idx)), row_bytes))); + } +} + +struct AttentionTestEnv { + AttentionTestEnv(size_t num_queries, size_t kv_seq_len, size_t qkv_dim, + AttentionImpl attention_impl); +}; + void TestFlashAttention(size_t target_parallelism, AttentionImpl attention_impl) { ThreadingArgs threading_args; @@ -283,39 +312,29 @@ const std::vector att_out_gold = { 0.009653}; void TestTiledFlashAttention() { - int qkv_dim = 64; - int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by - // tiles size to test the padding logic. - int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + size_t qkv_dim = 64; + size_t kv_seq_len = 60; // number of tokens we will attend to. + // Not divisible by tiles size to test the padding logic. + size_t padded_kv_seq_len = + hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); float att_cap = 10.0f; - int num_queries = 8; - int num_queries_per_timestep = 4; - int num_tokens = num_queries / num_queries_per_timestep; - int kv_seq_end = + size_t num_queries = 8; + size_t num_queries_per_timestep = 4; + size_t num_tokens = num_queries / num_queries_per_timestep; + size_t kv_seq_end = kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - MatStorageT kv( - "kv", - Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize), - ctx.allocator, MatPadding::kPacked); - // fill in kvs with predictable, synthetic data - for (int i = 0; i < padded_kv_seq_len; ++i) { - for (int j = 0; j < qkv_dim; ++j) { - const int tile_idx = i / gcpp::KVCache::kTileSize; - const int in_tile_offset = i % gcpp::KVCache::kTileSize; - const float val_k = 0.01f * (i + 1) / (j + 1); - const float val_v = 0.02f * (i + 1) / (j + 1); - kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset] = val_k; - const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize; - kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j] = val_v; - } - } + MatStorageT kv("kv", + Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, + 2 * qkv_dim * gcpp::KVCache::kTileSize), + ctx.allocator, MatPadding::kPacked); + PopulateTestKVCache(kv, gcpp::KVEncoding::kF32, qkv_dim); std::vector q_float(4 * qkv_dim); std::vector q_float2(4 * qkv_dim); // fill in qs with predictable, synthetic data - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < qkv_dim; j++) { + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < qkv_dim; j++) { float val_1 = 0.01f * (i + 1) / (j + 1); float val_2 = 0.01f * (i + 4 + 1) / (j + 1); q_float[j * 4 + i] = val_1; @@ -342,11 +361,11 @@ void TestTiledFlashAttention() { std::vector> last_pos_per_query; start_pos_per_query.reserve(num_queries); last_pos_per_query.reserve(num_queries); - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { ssize_t query_last_pos = kv_seq_end + token_idx; ssize_t query_start_pos = std::max(query_last_pos - 100000 + 1, static_cast(0)); - for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; + for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep; ++q_head_idx) { start_pos_per_query.push_back(query_start_pos); last_pos_per_query.push_back(query_last_pos); @@ -365,67 +384,43 @@ void TestTiledFlashAttention() { // and output looked good. Not ideal but should be good enough to test the // plumbing and detect regressions. PrintMatPtr(att_out); - for (int i = 0; i < num_queries; ++i) { + for (size_t i = 0; i < num_queries; ++i) { std::cerr << "exp_d: " << exp_denominator_sums[i] << " max_logit: " << max_logits[i] << std::endl; EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-3f) << "i=" << i; EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i; - for (int j = 0; j < qkv_dim; ++j) { + for (size_t j = 0; j < qkv_dim; ++j) { EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-5f); } } } void TestTiledFlashAttentionBF16() { - int qkv_dim = 64; - int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by - // tiles size to test the padding logic. - int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + size_t qkv_dim = 64; + size_t kv_seq_len = 60; // number of tokens we will attend to. + // Not divisible by tiles size to test the padding logic. + size_t padded_kv_seq_len = + hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); float att_cap = 10.0f; - int num_queries = 8; - int num_queries_per_timestep = 4; - int num_tokens = num_queries / num_queries_per_timestep; - int kv_seq_end = + size_t num_queries = 8; + size_t num_queries_per_timestep = 4; + size_t num_tokens = num_queries / num_queries_per_timestep; + size_t kv_seq_end = kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - MatStorageT kv( - "kv", - Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize), - ctx.allocator, MatPadding::kPacked); - // fill in kvs with predictable, synthetic data - for (int i = 0; i < padded_kv_seq_len; i++) { - for (int j = 0; j < qkv_dim; j+=2) { - const int tile_idx = i / gcpp::KVCache::kTileSize; - const int in_tile_offset = i % gcpp::KVCache::kTileSize; - const float val_k_1 = 0.01f * (i + 1) / (j + 1); - const float val_k_2 = 0.01f * (i + 1) / (j + 2); - kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2] = - hwy::ConvertScalarTo(val_k_1); - kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2 + 1] = - hwy::ConvertScalarTo(val_k_2); - } - } - const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize; - for (int i = 0; i < padded_kv_seq_len; i += 2) { - for (int j = 0; j < qkv_dim; j++) { - const int tile_idx = i / gcpp::KVCache::kTileSize; - const int in_tile_offset = i % gcpp::KVCache::kTileSize; - const float val_v_1 = 0.02f * (i + 1) / (j + 1); - const float val_v_2 = 0.02f * (i + 2) / (j + 1); - kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2] = - hwy::ConvertScalarTo(val_v_1); - kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2 + 1] = - hwy::ConvertScalarTo(val_v_2); - } - } + MatStorageT kv("kv", + Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize, + 2 * qkv_dim * gcpp::KVCache::kTileSize), + ctx.allocator, MatPadding::kPacked); + PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16TwoTranspositions, qkv_dim); std::vector q_float(num_queries_per_timestep * qkv_dim); std::vector q_float2(num_queries_per_timestep * qkv_dim); // fill in qs with predictable, synthetic data - for (int i = 0; i < num_queries_per_timestep; ++i) { - for (int j = 0; j < qkv_dim; j += 2) { + for (size_t i = 0; i < num_queries_per_timestep; ++i) { + for (size_t j = 0; j < qkv_dim; j += 2) { q_float[j * num_queries_per_timestep + i * 2] = hwy::ConvertScalarTo(0.01f * (i + 1) / (j + 1)); q_float[j * num_queries_per_timestep + i * 2 + 1] = @@ -458,11 +453,11 @@ void TestTiledFlashAttentionBF16() { std::vector> last_pos_per_query; start_pos_per_query.reserve(num_queries); last_pos_per_query.reserve(num_queries); - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { ssize_t query_last_pos = kv_seq_end + token_idx; ssize_t query_start_pos = std::max(query_last_pos - 100000 + 1, static_cast(0)); - for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; + for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep; ++q_head_idx) { start_pos_per_query.push_back(query_start_pos); last_pos_per_query.push_back(query_last_pos); @@ -480,90 +475,47 @@ void TestTiledFlashAttentionBF16() { // and output looked good. Not ideal but should be good enough to test the // plumbing and detect regressions. PrintMatPtr(att_out); - for (int i = 0; i < num_queries; ++i) { + for (size_t i = 0; i < num_queries; ++i) { std::cerr << "exp_d: " << exp_denominator_sums[i] << " max_logit: " << max_logits[i] << std::endl; EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f) << "i=" << i; EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; - for (int j = 0; j < qkv_dim; ++j) { + for (size_t j = 0; j < qkv_dim; ++j) { EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f); } } } void TestTiledFlashAttentionInt8() { - int qkv_dim = 64; - int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by - // tiles size to test the padding logic. - int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); + size_t qkv_dim = 64; + // number of tokens we will attend to. + // Not divisible by tiles size to test the padding logic. + size_t kv_seq_len = 60; + size_t padded_kv_seq_len = + hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize); float att_cap = 10.0f; - int num_queries = 8; - int num_queries_per_timestep = 4; - int num_tokens = num_queries / num_queries_per_timestep; - int kv_seq_end = + size_t num_queries = 8; + size_t num_queries_per_timestep = 4; + size_t num_tokens = num_queries / num_queries_per_timestep; + size_t kv_seq_end = kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep); ThreadingArgs threading_args; ThreadingContext ctx(threading_args); - int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; - int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize + - 2 * sizeof(BF16) * gcpp::KVCache::kTileSize; + size_t num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize; + size_t tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize + + 2 * sizeof(BF16) * gcpp::KVCache::kTileSize; MatStorageT kv("kv", Extents2D(num_tiles, tile_size_bytes), ctx.allocator, MatPadding::kPacked); - - // fill in kvs with predictable, synthetic data - for (int i = 0; i < padded_kv_seq_len; ++i) { - int tile_idx = i / gcpp::KVCache::kTileSize; - int in_tile_offset = i % gcpp::KVCache::kTileSize; - int8_t* tile_ptr = kv.Row(tile_idx); - BF16* scales_ptr = HWY_RCAST_ALIGNED( - BF16*, tile_ptr + 2 * qkv_dim * gcpp::KVCache::kTileSize); - - // Generate float values for K and V - std::vector k_vals(qkv_dim); - std::vector v_vals(qkv_dim); - float max_abs_k = 0.0f; - float max_abs_v = 0.0f; - - for (int j = 0; j < qkv_dim; ++j) { - k_vals[j] = 0.01f * (i + 1) / (j + 1); - v_vals[j] = 0.02f * (i + 1) / (j + 1); - max_abs_k = std::max(max_abs_k, std::abs(k_vals[j])); - max_abs_v = std::max(max_abs_v, std::abs(v_vals[j])); - } - - // Quantize K - float scale_k = max_abs_k / 127.0f; - if (scale_k == 0.0f) scale_k = 1.0f; - scales_ptr[in_tile_offset] = hwy::ConvertScalarTo(scale_k); - for (int j = 0; j < qkv_dim; ++j) { - int val = std::round(k_vals[j] / scale_k); - val = std::max(-127, std::min(127, val)); - tile_ptr[j * gcpp::KVCache::kTileSize + in_tile_offset] = - static_cast(val); - } - - // Quantize V - float scale_v = max_abs_v / 127.0f; - if (scale_v == 0.0f) scale_v = 1.0f; - scales_ptr[gcpp::KVCache::kTileSize + in_tile_offset] = - hwy::ConvertScalarTo(scale_v); - size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize; - for (int j = 0; j < qkv_dim; ++j) { - int val = std::round(v_vals[j] / scale_v); - val = std::max(-127, std::min(127, val)); - tile_ptr[v_offset + in_tile_offset * qkv_dim + j] = - static_cast(val); - } - } + PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8, qkv_dim); std::vector q_float(4 * qkv_dim); std::vector q_float2(4 * qkv_dim); // fill in qs with predictable, synthetic data - for (int i = 0; i < 4; ++i) { - for (int j = 0; j < qkv_dim; j++) { + for (size_t i = 0; i < 4; ++i) { + for (size_t j = 0; j < qkv_dim; j++) { float val_1 = 0.01f * (i + 1) / (j + 1); float val_2 = 0.01f * (i + 4 + 1) / (j + 1); q_float[j * 4 + i] = val_1; @@ -590,11 +542,11 @@ void TestTiledFlashAttentionInt8() { std::vector> last_pos_per_query; start_pos_per_query.reserve(num_queries); last_pos_per_query.reserve(num_queries); - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { + for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { ssize_t query_last_pos = kv_seq_end + token_idx; ssize_t query_start_pos = std::max(query_last_pos - 100000 + 1, static_cast(0)); - for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep; + for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep; ++q_head_idx) { start_pos_per_query.push_back(query_start_pos); last_pos_per_query.push_back(query_last_pos); @@ -613,13 +565,13 @@ void TestTiledFlashAttentionInt8() { // and output looked good. Not ideal but should be good enough to test the // plumbing and detect regressions. PrintMatPtr(att_out); - for (int i = 0; i < num_queries; ++i) { + for (size_t i = 0; i < num_queries; ++i) { std::cerr << "exp_d: " << exp_denominator_sums[i] << " max_logit: " << max_logits[i] << std::endl; - EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-2f) + EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f) << "i=" << i; EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i; - for (int j = 0; j < qkv_dim; ++j) { + for (size_t j = 0; j < qkv_dim; ++j) { EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f); } } diff --git a/gemma/tiled_attention_test.cc b/gemma/tiled_attention_test.cc index 3efdd69..79282c2 100644 --- a/gemma/tiled_attention_test.cc +++ b/gemma/tiled_attention_test.cc @@ -14,6 +14,7 @@ #include "gemma/gemma.h" #include "gemma/gemma_args.h" #include "gemma/kv_cache.h" +#include "gemma/kv_transcoding.h" #include "gemma/weights.h" #include "util/mat.h" #include "util/threading_context.h" @@ -42,9 +43,10 @@ using ::testing::Pointwise; struct AttentionTestEnv { AttentionTestEnv( - int qkv_dim, int kv_seq_len, int attention_window_size, int num_kv_heads, - int num_heads, int num_tokens, int last_pos, float att_cap, int layer_idx, - int layers_total, int qbatch_size, AttentionImpl attention_impl, + size_t qkv_dim, size_t kv_seq_len, size_t attention_window_size, + size_t num_kv_heads, size_t num_heads, size_t num_tokens, size_t last_pos, + float att_cap, size_t layer_idx, size_t layers_total, size_t qbatch_size, + AttentionImpl attention_impl, std::optional kv_cache_type = {} ) : ctx(threading_args), env(ctx) { layer_config.heads = num_heads; @@ -60,7 +62,7 @@ struct AttentionTestEnv { model_config.model_dim = layer_config.model_dim; model_config.vocab_size = 1; // not vit - for (int i = 0; i < model_config.num_layers; ++i) { + for (size_t i = 0; i < model_config.num_layers; ++i) { model_config.layer_configs.push_back(layer_config); } tensor_info_registry = std::make_unique(model_config); @@ -73,755 +75,27 @@ struct AttentionTestEnv { all_queries.Reserve(qbatch_size); kv_caches.reserve(qbatch_size); - for (int q = 0; q < qbatch_size; ++q) { + float unpredictable = hwy::Unpredictable1() * 0.01f; + for (size_t q = 0; q < qbatch_size; ++q) { kv_caches.emplace_back(model_config, inference_args, runtime_config, ctx.allocator); - if (attention_impl == AttentionImpl::kFlashTransposedQsBF16 && - kv_caches.back().compact_kv_cache_ptr.GetType() == Type::kBF16) { - MatPtrT compact_kv_cache = kv_caches.back().compact_kv_cache_ptr; - for (int i = 0; i < compact_kv_cache.Rows(); ++i) { - for (int j = 0; j < compact_kv_cache.Cols(); ++j) { - BF16 val = hwy::ConvertScalarTo(hwy::Unpredictable1() * - 0.01f * (i + j + 1)); - // split j into if k/v - if (j < qkv_dim * gcpp::KVCache::kTileSize) { - // split j into dim and in tile offset - const int dim = j / gcpp::KVCache::kTileSize; - const int in_tile_offset = j % gcpp::KVCache::kTileSize; - const int dim_mod_2 = dim % 2; - compact_kv_cache.Row( - i)[(dim - dim_mod_2) * gcpp::KVCache::kTileSize + - in_tile_offset * 2 + dim_mod_2] = val; - } else { - const int in_tile_offset = j / qkv_dim; - const int dim = j % qkv_dim; - const int in_tile_offset_mod_2 = in_tile_offset % 2; - compact_kv_cache.Row( - i)[(in_tile_offset - in_tile_offset_mod_2) * qkv_dim + - dim * 2 + in_tile_offset_mod_2] = val; + if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) { + const size_t tile_size = gcpp::KVCache::kTileSize; + gcpp::DecodedTile decoded(qkv_dim, tile_size); + for (size_t i = 0; i < kv_caches.back().compact_kv_cache_ptr.Rows(); + ++i) { + for (size_t token = 0; token < tile_size; ++token) { + for (size_t dim = 0; dim < qkv_dim; ++dim) { + size_t j_k = dim * tile_size + token; + decoded.k_elem(token, dim) = unpredictable * (i + j_k + 1); + + size_t j_v = qkv_dim * tile_size + token * qkv_dim + dim; + decoded.v_elem(token, dim) = unpredictable * (i + j_v + 1); } } - } - } else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) { - if (kv_caches.back().compact_kv_cache_ptr.GetType() == Type::kInt8) { - MatPtrT compact_kv_cache = - kv_caches.back().compact_kv_cache_ptr; - for (int i = 0; i < compact_kv_cache.Rows(); ++i) { - BF16* scales_ptr = HWY_RCAST_ALIGNED( - BF16*, compact_kv_cache.Row(i) + - 2 * qkv_dim * gcpp::KVCache::kTileSize); - for (int in_tile_idx = 0; in_tile_idx < gcpp::KVCache::kTileSize; - ++in_tile_idx) { - // Compute scale and fill K - float max_k = 0.0f; - for (int dim = 0; dim < qkv_dim; ++dim) { - int j = dim * gcpp::KVCache::kTileSize + in_tile_idx; - float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1); - max_k = std::max(max_k, expected); - } - float scale_k = max_k / 127.0f; - if (scale_k == 0.0f) scale_k = 1.0f; - scales_ptr[in_tile_idx] = hwy::ConvertScalarTo(scale_k); - for (int dim = 0; dim < qkv_dim; ++dim) { - int j = dim * gcpp::KVCache::kTileSize + in_tile_idx; - float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1); - compact_kv_cache.Row(i)[j] = - static_cast(std::round(expected / scale_k)); - } - - // Compute scale and fill V - float max_v = 0.0f; - for (int dim = 0; dim < qkv_dim; ++dim) { - int j = qkv_dim * gcpp::KVCache::kTileSize + - in_tile_idx * qkv_dim + dim; - float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1); - max_v = std::max(max_v, expected); - } - float scale_v = max_v / 127.0f; - if (scale_v == 0.0f) scale_v = 1.0f; - scales_ptr[gcpp::KVCache::kTileSize + in_tile_idx] = - hwy::ConvertScalarTo(scale_v); - - for (int dim = 0; dim < qkv_dim; ++dim) { - int j = qkv_dim * gcpp::KVCache::kTileSize + - in_tile_idx * qkv_dim + dim; - float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1); - compact_kv_cache.Row(i)[j] = - static_cast(std::round(expected / scale_v)); - } - } - } - } else if (kv_caches.back().compact_kv_cache_ptr.GetType() == - Type::kBF16) { - MatPtrT compact_kv_cache = - kv_caches.back().compact_kv_cache_ptr; - FillMatPtrT(compact_kv_cache); - } else { - MatPtrT compact_kv_cache = - kv_caches.back().compact_kv_cache_ptr; - FillMatPtrT(compact_kv_cache); - } - } else { - FillMatPtrT(kv_caches.back().kv_cache); - } - all_queries.Append({ - .prompt = PromptTokens({1, 2, 3}), - .mutable_pos = static_cast(last_pos), - .initial_pos = 0, - .prefix_end = 0, - .kv_cache = kv_caches.back().ToPtr(), - }); - } - - activations = std::make_unique(runtime_config, model_config, - qbatch_size * num_tokens, - kv_seq_len, ctx, env.row_ptrs); - - qbatch = - std::make_unique(/*start_pos=*/0, qbatch_size, all_queries); - } - - void SetupWeights() { - int model_dim = layer_config.model_dim; - int qkv_dim = layer_config.qkv_dim; - int num_heads = layer_config.heads; - int num_kv_heads = layer_config.kv_heads; - - qkv1_w_storage = - MatStorageT("qkv1", Extents2D(model_dim, qkv_dim * num_heads), - ctx.allocator, MatPadding::kPacked); - qkv2_w_storage = MatStorageT( - "qkv2", Extents2D(model_dim, num_kv_heads * 2 * qkv_dim), ctx.allocator, - MatPadding::kPacked); - wo_w_storage = MatStorageT("wo", Extents2D(model_dim, model_dim), - ctx.allocator, MatPadding::kPacked); - - FillMatPtrT(wo_w_storage); - layer->att_weights = wo_w_storage; - FillMatPtrT(qkv1_w_storage); - FillMatPtrT(qkv2_w_storage); - layer->qkv_einsum_w1 = qkv1_w_storage; - layer->qkv_einsum_w2 = qkv2_w_storage; - - query_norm_scale = MatStorageT("query_norm", qkv_dim, ctx.allocator); - FillMatPtrT(query_norm_scale); - layer->query_norm_scale = query_norm_scale; - - key_norm_scale = MatStorageT("key_norm", qkv_dim, ctx.allocator); - FillMatPtrT(key_norm_scale); - layer->key_norm_scale = key_norm_scale; - } - - AttentionTestEnv(const AttentionTestEnv&) = delete; - AttentionTestEnv& operator=(const AttentionTestEnv&) = delete; - AttentionTestEnv(AttentionTestEnv&&) = delete; - AttentionTestEnv& operator=(AttentionTestEnv&&) = delete; - - ThreadingArgs threading_args; - ThreadingContext ctx; - MatMulEnv env; - LayerConfig layer_config; - ModelConfig model_config; - std::unique_ptr tensor_info_registry; - std::unique_ptr layer; - RuntimeConfig runtime_config; - InferenceArgs inference_args; - AllQueries all_queries; - std::vector kv_caches; - std::unique_ptr activations; - std::unique_ptr qbatch; - - // Weights storage for later tests - MatStorageT qkv1_w_storage; - MatStorageT qkv2_w_storage; - MatStorageT wo_w_storage; - MatStorageT query_norm_scale; - MatStorageT key_norm_scale; -}; - -void TestTransposeStridedQueries() { - ThreadingArgs threading_args; - ThreadingContext ctx(threading_args); - int qkv_dim = 64; - int num_queries = 24; - AlignedPtr input_queries = - ctx.allocator.Alloc(qkv_dim * num_queries); - AlignedPtr output_queries = - ctx.allocator.Alloc(qkv_dim * num_queries); - for (int i = 0; i < num_queries; ++i) { - for (int j = 0; j < qkv_dim; ++j) { - input_queries[i * qkv_dim + j] = i * qkv_dim + j; - } - } - std::vector queries; - for (int i = 0; i < num_queries; ++i) { - queries.push_back(input_queries.get() + i * qkv_dim); - } - hwy::Span queries_span(queries.data(), queries.size()); - - TransposeStridedQueries( - queries_span, qkv_dim, - hwy::Span(output_queries.get(), qkv_dim * num_queries)); - for (int i = 0; i < num_queries; ++i) { - for (int j = 0; j < qkv_dim; ++j) { - EXPECT_EQ(output_queries[j * num_queries + i], - input_queries[i * qkv_dim + j]) - << "i=" << i << " j=" << j; - } - } -} - -void TestLocalAttentionForAllHeadsTokensAndBatch() { - int qkv_dim = 64; - int kv_seq_len = 64; - int num_kv_heads = 2; - int num_heads = 2; - int num_tokens = 2; - int last_pos = 62; // so token 0 will have 63 and token 1 will have 64 tokens - // to attend to. - float att_cap = 10.0f; - int layer_idx = 0; - int layers_total = 1; - int qbatch_size = 2; - AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; - AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, - num_heads, num_tokens, last_pos, att_cap, layer_idx, - layers_total, qbatch_size, attention_impl); - FillMatPtrT(test_env.activations->attention.q); - LocalAttentionForAllHeadsTokensAndBatch( - attention_impl, num_tokens, layer_idx, *test_env.layer, - test_env.activations->attention, *test_env.qbatch, test_env.ctx); - - // print states; - std::vector exp_denominator_sums_gold = {63, 63, 64, 64, - 63, 63, 64, 64}; - std::vector max_logits_gold = {10, 10, 10, 10, 10, 10, 10, 10}; - std::vector att_out_gold = { - 30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, - 30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, - 30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, - 30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, - 30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, - 30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, - 30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, - 30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, - 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475, - 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275, - 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075, - 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875, - 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675, - 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475, - 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275, - 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075, - 30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, - 30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, - 30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, - 30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, - 30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, - 30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, - 30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, - 30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, - 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505, - 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585, - 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665, - 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745, - 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825, - 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905, - 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985, - 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065, - 30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, - 30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, - 30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, - 30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, - 30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, - 30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, - 30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, - 30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, - 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475, - 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275, - 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075, - 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875, - 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675, - 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475, - 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275, - 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075, - 30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, - 30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, - 30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, - 30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, - 30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, - 30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, - 30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, - 30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, - 30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505, - 30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585, - 30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665, - 30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745, - 30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825, - 30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905, - 30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985, - 30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065, - }; - const int group_size = num_heads / num_kv_heads; - for (int token_idx = 0; token_idx < num_tokens; ++token_idx) { - for (int q_batch_idx = 0; q_batch_idx < qbatch_size; ++q_batch_idx) { - int b = token_idx * qbatch_size + q_batch_idx; - EXPECT_THAT( - absl::MakeSpan(test_env.activations->attention.softmax_d.Row(b), - num_heads), - Pointwise(FloatNear(1e-3f), absl::MakeSpan(exp_denominator_sums_gold) - .subspan(b * num_heads, num_heads))); - EXPECT_THAT( - absl::MakeSpan(test_env.activations->attention.softmax_max.Row(b), - num_heads), - Pointwise(FloatNear(1e-3f), absl::MakeSpan(max_logits_gold) - .subspan(b * num_heads, num_heads))); - for (int kv_h = 0; kv_h < num_kv_heads; ++kv_h) { - for (int g = 0; g < group_size; ++g) { - const int q_h = kv_h * group_size + g; - size_t expected_q_idx = b * num_heads + q_h; - EXPECT_THAT( - absl::MakeSpan(test_env.activations->attention.att_out.Row(b) + - q_h * qkv_dim, - qkv_dim), - Pointwise(FloatNear(1e-3f), - absl::MakeSpan(att_out_gold) - .subspan(expected_q_idx * qkv_dim, qkv_dim))); - } - } - } - } -} - -const std::vector AttentionMultipleTokensAttentionGoldens = { - 34.7414, 34.7717, 34.8022, 34.8327, 34.8631, 34.8936, 34.9241, 34.9545, - 34.985, 35.0156, 35.046, 35.0765, 35.1068, 35.1373, 35.1678, 35.1982, - 35.2286, 35.2592, 35.2895, 35.32, 35.3506, 35.381, 35.4115, 35.4421, - 35.4725, 35.503, 35.5334, 35.5638, 35.5943, 35.6247, 35.6552, 35.6857, - 35.7161, 35.7466, 35.7772, 35.8076, 35.8381, 35.8685, 35.8989, 35.9294, - 35.9598, 35.9902, 36.0208, 36.0512, 36.0816, 36.1122, 36.1426, 36.1731, - 36.2037, 36.2341, 36.2646, 36.295, 36.3254, 36.356, 36.3863, 36.4168, - 36.4474, 36.4778, 36.5082, 36.5388, 36.5692, 36.5997, 36.6301, 36.6605, - 34.6687, 34.6987, 34.7288, 34.759, 34.7891, 34.8192, 34.8495, 34.8795, - 34.9097, 34.9399, 34.97, 35.0002, 35.0302, 35.0604, 35.0906, 35.1206, - 35.1507, 35.181, 35.211, 35.2412, 35.2714, 35.3015, 35.3317, 35.3619, - 35.3921, 35.4222, 35.4523, 35.4824, 35.5126, 35.5427, 35.5728, 35.603, - 35.6331, 35.6633, 35.6935, 35.7236, 35.7538, 35.7838, 35.814, 35.8442, - 35.8742, 35.9043, 35.9346, 35.9646, 35.9948, 36.025, 36.0551, 36.0853, - 36.1155, 36.1456, 36.1759, 36.2059, 36.236, 36.2662, 36.2963, 36.3264, - 36.3566, 36.3867, 36.4169, 36.4471, 36.4772, 36.5074, 36.5374, 36.5676, - 37.0338, 37.0634, 37.0929, 37.1222, 37.1519, 37.1813, 37.2107, 37.2403, - 37.2698, 37.2992, 37.3288, 37.3584, 37.3877, 37.4174, 37.447, 37.4764, - 37.5056, 37.5352, 37.5646, 37.5938, 37.6234, 37.6528, 37.6821, 37.7117, - 37.7412, 37.7705, 37.8001, 37.8295, 37.8589, 37.8885, 37.918, 37.9473, - 37.977, 38.0065, 38.0358, 38.0655, 38.095, 38.1244, 38.1541, 38.1836, - 38.213, 38.2422, 38.2718, 38.3012, 38.3305, 38.36, 38.3895, 38.4187, - 38.4484, 38.4778, 38.5071, 38.5367, 38.5662, 38.5955, 38.6251, 38.6546, - 38.6839, 38.7136, 38.7431, 38.7725, 38.8021, 38.8316, 38.861, 38.8907, - 36.9872, 37.0167, 37.046, 37.0752, 37.1047, 37.1341, 37.1633, 37.1928, - 37.2222, 37.2514, 37.2809, 37.3103, 37.3396, 37.3691, 37.3985, 37.4278, - 37.4569, 37.4863, 37.5156, 37.5447, 37.5742, 37.6035, 37.6326, 37.6621, - 37.6914, 37.7206, 37.7501, 37.7794, 37.8086, 37.8381, 37.8674, 37.8966, - 37.9262, 37.9555, 37.9848, 38.0143, 38.0437, 38.0729, 38.1025, 38.1319, - 38.1612, 38.1903, 38.2197, 38.249, 38.2781, 38.3075, 38.3368, 38.366, - 38.3955, 38.4248, 38.4539, 38.4834, 38.5127, 38.5419, 38.5714, 38.6008, - 38.63, 38.6595, 38.6889, 38.7181, 38.7477, 38.777, 38.8063, 38.8358, - 39.0984, 39.1479, 39.1976, 39.2475, 39.297, 39.3468, 39.3967, 39.4463, - 39.4961, 39.546, 39.5957, 39.6455, 39.695, 39.7447, 39.7946, 39.8441, - 39.8939, 39.9438, 39.9934, 40.0431, 40.0931, 40.1427, 40.1925, 40.2425, - 40.2921, 40.342, 40.3915, 40.4412, 40.4911, 40.5407, 40.5904, 40.6403, - 40.6899, 40.7397, 40.7897, 40.8393, 40.8892, 40.9387, 40.9884, 41.0382, - 41.0878, 41.1375, 41.1874, 41.237, 41.2868, 41.3367, 41.3863, 41.4361, - 41.4861, 41.5358, 41.5856, 41.6351, 41.6849, 41.7347, 41.7843, 41.834, - 41.884, 41.9336, 41.9834, 42.0333, 42.083, 42.1328, 42.1823, 42.232, - 38.9699, 39.0188, 39.068, 39.1173, 39.1663, 39.2155, 39.2648, 39.3138, - 39.3631, 39.4124, 39.4615, 39.5108, 39.5597, 39.6089, 39.6581, 39.7071, - 39.7563, 39.8056, 39.8546, 39.9039, 39.9532, 40.0023, 40.0515, 40.1009, - 40.15, 40.1993, 40.2483, 40.2974, 40.3467, 40.3957, 40.4449, 40.4942, - 40.5433, 40.5925, 40.6419, 40.691, 40.7402, 40.7892, 40.8383, 40.8876, - 40.9366, 40.9857, 41.035, 41.0841, 41.1333, 41.1826, 41.2317, 41.2809, - 41.3303, 41.3794, 41.4287, 41.4777, 41.5268, 41.5761, 41.6251, 41.6743, - 41.7237, 41.7727, 41.8219, 41.8713, 41.9204, 41.9697, 42.0186, 42.0677, - 43.4945, 43.5425, 43.5902, 43.6376, 43.6856, 43.7334, 43.7808, 43.8289, - 43.8766, 43.9241, 43.9722, 44.02, 44.0675, 44.1157, 44.1635, 44.2111, - 44.2583, 44.3062, 44.3538, 44.4011, 44.449, 44.4966, 44.544, 44.5919, - 44.6396, 44.6869, 44.735, 44.7826, 44.8301, 44.8781, 44.9258, 44.9733, - 45.0213, 45.0691, 45.1166, 45.1647, 45.2125, 45.26, 45.3081, 45.356, - 45.4035, 45.4508, 45.4987, 45.5462, 45.5936, 45.6415, 45.6891, 45.7364, - 45.7844, 45.832, 45.8794, 45.9274, 45.9751, 46.0225, 46.0705, 46.1183, - 46.1657, 46.2138, 46.2615, 46.309, 46.3571, 46.4049, 46.4525, 46.5006, - 43.4125, 43.4603, 43.5077, 43.5549, 43.6027, 43.6502, 43.6974, 43.7453, - 43.7928, 43.84, 43.8879, 43.9355, 43.9828, 44.0307, 44.0783, 44.1256, - 44.1726, 44.2203, 44.2676, 44.3147, 44.3624, 44.4098, 44.4569, 44.5046, - 44.552, 44.5992, 44.6469, 44.6944, 44.7416, 44.7894, 44.8369, 44.8841, - 44.9319, 44.9795, 45.0267, 45.0746, 45.1222, 45.1694, 45.2173, 45.265, - 45.3123, 45.3593, 45.407, 45.4543, 45.5014, 45.5491, 45.5965, 45.6436, - 45.6913, 45.7387, 45.7859, 45.8336, 45.8811, 45.9283, 45.9761, 46.0236, - 46.0708, 46.1186, 46.1661, 46.2134, 46.2613, 46.3088, 46.3561, 46.404, - 34.7729, 34.8035, 34.8341, 34.8648, 34.8953, 34.9259, 34.9567, 34.9872, - 35.0179, 35.0486, 35.0792, 35.1098, 35.1404, 35.171, 35.2016, 35.2322, - 35.2628, 35.2935, 35.324, 35.3547, 35.3854, 35.416, 35.4466, 35.4774, - 35.508, 35.5387, 35.5692, 35.5998, 35.6305, 35.661, 35.6916, 35.7224, - 35.7529, 35.7836, 35.8143, 35.8449, 35.8755, 35.9061, 35.9367, 35.9674, - 35.9979, 36.0285, 36.0592, 36.0898, 36.1204, 36.1511, 36.1817, 36.2123, - 36.2431, 36.2737, 36.3044, 36.3349, 36.3655, 36.3962, 36.4267, 36.4574, - 36.4881, 36.5186, 36.5493, 36.58, 36.6106, 36.6413, 36.6718, 36.7024, - 34.6995, 34.7297, 34.76, 34.7904, 34.8206, 34.8509, 34.8813, 34.9115, - 34.9418, 34.9722, 35.0025, 35.0328, 35.063, 35.0933, 35.1237, 35.1539, - 35.1842, 35.2146, 35.2448, 35.2751, 35.3055, 35.3357, 35.3661, 35.3965, - 35.4268, 35.4571, 35.4873, 35.5176, 35.548, 35.5782, 35.6085, 35.6389, - 35.6691, 35.6994, 35.7298, 35.7601, 35.7904, 35.8206, 35.8509, 35.8813, - 35.9115, 35.9418, 35.9721, 36.0024, 36.0327, 36.0631, 36.0933, 36.1237, - 36.1541, 36.1843, 36.2147, 36.2449, 36.2752, 36.3056, 36.3358, 36.3661, - 36.3965, 36.4267, 36.457, 36.4874, 36.5177, 36.548, 36.5782, 36.6085, - 37.0829, 37.1127, 37.1423, 37.1717, 37.2015, 37.2312, 37.2607, 37.2905, - 37.3201, 37.3496, 37.3795, 37.4091, 37.4386, 37.4685, 37.4982, 37.5277, - 37.5571, 37.5868, 37.6164, 37.6458, 37.6755, 37.7051, 37.7346, 37.7643, - 37.7939, 37.8234, 37.8531, 37.8827, 37.9122, 37.942, 37.9716, 38.0011, - 38.0309, 38.0606, 38.0901, 38.1199, 38.1496, 38.1791, 38.209, 38.2387, - 38.2682, 38.2976, 38.3273, 38.3569, 38.3863, 38.416, 38.4456, 38.475, - 38.5048, 38.5344, 38.5638, 38.5936, 38.6232, 38.6527, 38.6825, 38.7121, - 38.7416, 38.7714, 38.8011, 38.8306, 38.8604, 38.8901, 38.9196, 38.9494, - 37.0359, 37.0655, 37.095, 37.1243, 37.154, 37.1835, 37.2129, 37.2425, - 37.2721, 37.3014, 37.3311, 37.3607, 37.39, 37.4198, 37.4493, 37.4787, - 37.508, 37.5376, 37.567, 37.5963, 37.6259, 37.6553, 37.6846, 37.7142, - 37.7437, 37.773, 37.8027, 37.8322, 37.8615, 37.8911, 37.9207, 37.95, - 37.9797, 38.0092, 38.0386, 38.0683, 38.0978, 38.1272, 38.1569, 38.1865, - 38.2159, 38.2451, 38.2747, 38.3042, 38.3334, 38.363, 38.3925, 38.4218, - 38.4514, 38.4809, 38.5102, 38.5398, 38.5693, 38.5986, 38.6283, 38.6578, - 38.6872, 38.7168, 38.7464, 38.7757, 38.8054, 38.835, 38.8644, 38.8941, - 39.1594, 39.2093, 39.2593, 39.3095, 39.3594, 39.4094, 39.4597, 39.5096, - 39.5597, 39.61, 39.6599, 39.7101, 39.7599, 39.8099, 39.8601, 39.91, - 39.96, 40.0102, 40.0601, 40.1102, 40.1605, 40.2104, 40.2605, 40.3108, - 40.3608, 40.411, 40.4608, 40.5108, 40.561, 40.6109, 40.661, 40.7112, - 40.7611, 40.8112, 40.8615, 40.9115, 40.9616, 41.0114, 41.0614, 41.1116, - 41.1615, 41.2115, 41.2617, 41.3116, 41.3617, 41.412, 41.4619, 41.512, - 41.5624, 41.6123, 41.6625, 41.7123, 41.7623, 41.8126, 41.8624, 41.9125, - 41.9627, 42.0127, 42.0628, 42.113, 42.163, 42.2131, 42.263, 42.313, - 39.0297, 39.079, 39.1284, 39.1781, 39.2274, 39.2769, 39.3265, 39.3759, - 39.4254, 39.4751, 39.5245, 39.5741, 39.6233, 39.6727, 39.7224, 39.7716, - 39.8211, 39.8708, 39.9201, 39.9696, 40.0193, 40.0686, 40.1182, 40.1679, - 40.2173, 40.2669, 40.3162, 40.3656, 40.4153, 40.4646, 40.514, 40.5637, - 40.6131, 40.6626, 40.7123, 40.7617, 40.8112, 40.8605, 40.9099, 40.9595, - 41.0088, 41.0583, 41.1079, 41.1573, 41.2068, 41.2565, 41.3058, 41.3554, - 41.4051, 41.4545, 41.5041, 41.5534, 41.6028, 41.6524, 41.7017, 41.7512, - 41.8009, 41.8502, 41.8998, 41.9495, 41.9988, 42.0484, 42.0977, 42.1471, - 43.5891, 43.6374, 43.6854, 43.7331, 43.7814, 43.8294, 43.8772, 43.9255, - 43.9736, 44.0214, 44.0698, 44.1179, 44.1657, 44.2141, 44.2623, 44.3101, - 44.3577, 44.4058, 44.4537, 44.5013, 44.5495, 44.5974, 44.6451, 44.6933, - 44.7413, 44.7889, 44.8372, 44.8852, 44.9329, 44.9812, 45.0293, 45.077, - 45.1254, 45.1734, 45.2212, 45.2696, 45.3177, 45.3655, 45.414, 45.4621, - 45.5099, 45.5575, 45.6057, 45.6535, 45.7011, 45.7493, 45.7973, 45.8449, - 45.8931, 45.9411, 45.9888, 46.037, 46.085, 46.1327, 46.1811, 46.2291, - 46.2768, 46.3252, 46.3733, 46.421, 46.4694, 46.5175, 46.5653, 46.6138, - 43.5064, 43.5544, 43.6022, 43.6497, 43.6978, 43.7456, 43.7931, 43.8412, - 43.889, 43.9366, 43.9847, 44.0326, 44.0802, 44.1284, 44.1763, 44.2239, - 44.2712, 44.3191, 44.3668, 44.4141, 44.4621, 44.5098, 44.5572, 44.6052, - 44.6529, 44.7004, 44.7484, 44.7962, 44.8436, 44.8918, 44.9395, 44.987, - 45.0352, 45.083, 45.1305, 45.1787, 45.2266, 45.2742, 45.3223, 45.3703, - 45.4179, 45.4652, 45.5131, 45.5608, 45.6081, 45.6561, 45.7038, 45.7512, - 45.7992, 45.8469, 45.8944, 45.9424, 45.9902, 46.0376, 46.0857, 46.1335, - 46.181, 46.2292, 46.277, 46.3245, 46.3727, 46.4206, 46.4682, 46.5164, -}; - -void TestAttentionMultipleTokens() { - int qkv_dim = 64; - int kv_seq_len = 64; - int num_kv_heads = 2; - int num_heads = 4; - int num_tokens = 2; - int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 - // will have 64 tokens to attend to. - float att_cap = 10.0f; - int layer_idx = 0; - int layers_total = 1; - int qbatch_size = 2; - AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; - AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, - num_heads, num_tokens, last_pos, att_cap, layer_idx, - layers_total, qbatch_size, attention_impl); - test_env.SetupWeights(); - FillMatPtrT(test_env.activations->attention.pre_att_rms_out); - FillMatPtrT(test_env.activations->attention.q); - FillMatPtrT(test_env.activations->attention.att); - FillMatPtrT(test_env.activations->attention.att_out); - FillMatPtrT(test_env.activations->attention.softmax_max); - FillMatPtrT(test_env.activations->attention.softmax_d); - - int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16); - TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, - test_env.activations->attention, *test_env.qbatch, - test_env.env, flags); - - std::cerr << "att_out\n"; - PrintMatPtr(test_env.activations->attention.att_out); - for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { - EXPECT_TRUE(hwy::CompareArraySimilar( - AttentionMultipleTokensAttentionGoldens.data() + - i * test_env.activations->attention.att_out.Cols(), - test_env.activations->attention.att_out.Row(i), - test_env.activations->attention.att_out.Cols(), 1e-3, - hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) - << "att_out mismatch for query: " << i; - } -} - -void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() { - int qkv_dim = 64; - int kv_seq_len = 34; - int num_kv_heads = 2; - int num_heads = 4; - int num_tokens = 2; - int last_pos = 31; // so in the tbatch token 0 will have 63 and token 1 - // will have 64 tokens to attend to. - float att_cap = 10.0f; - int layer_idx = 0; - int layers_total = 1; - int qbatch_size = 2; - int attention_window_size = 32; - AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs; - AttentionTestEnv test_env(qkv_dim, kv_seq_len, attention_window_size, - num_kv_heads, num_heads, num_tokens, last_pos, - att_cap, layer_idx, layers_total, qbatch_size, - attention_impl); - test_env.SetupWeights(); - FillMatPtrT(test_env.activations->attention.pre_att_rms_out); - FillMatPtrT(test_env.activations->attention.q); - FillMatPtrT(test_env.activations->attention.att); - FillMatPtrT(test_env.activations->attention.att_out); - FillMatPtrT(test_env.activations->attention.softmax_max); - FillMatPtrT(test_env.activations->attention.softmax_d); - - int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16); - TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, - test_env.activations->attention, *test_env.qbatch, - test_env.env, flags); - - std::cerr << "att_out\n"; - std::vector att_out_golden_test_local = { - 39.3051, 39.3556, 39.4062, 39.4571, 39.5075, 39.5582, 39.6091, 39.6596, - 39.7103, 39.7612, 39.8118, 39.8626, 39.913, 39.9636, 40.0144, 40.0649, - 40.1155, 40.1664, 40.2169, 40.2676, 40.3185, 40.369, 40.4198, 40.4707, - 40.5213, 40.572, 40.6225, 40.6731, 40.724, 40.7744, 40.8251, 40.876, - 40.9265, 40.9772, 41.0281, 41.0787, 41.1295, 41.1799, 41.2305, 41.2813, - 41.3318, 41.3824, 41.4333, 41.4838, 41.5345, 41.5854, 41.6359, 41.6867, - 41.7376, 41.7882, 41.839, 41.8894, 41.94, 41.9908, 42.0413, 42.092, - 42.1429, 42.1934, 42.2441, 42.295, 42.3456, 42.3964, 42.4468, 42.4974, - 39.1614, 39.2113, 39.2613, 39.3114, 39.3613, 39.4113, 39.4616, 39.5115, - 39.5616, 39.6118, 39.6618, 39.7119, 39.7617, 39.8117, 39.8618, 39.9117, - 39.9617, 40.0119, 40.0618, 40.1118, 40.1621, 40.212, 40.2621, 40.3124, - 40.3623, 40.4125, 40.4623, 40.5123, 40.5625, 40.6123, 40.6624, 40.7126, - 40.7625, 40.8126, 40.8629, 40.9128, 40.9629, 41.0127, 41.0627, 41.1129, - 41.1627, 41.2127, 41.2629, 41.3128, 41.3629, 41.4131, 41.463, 41.5131, - 41.5634, 41.6134, 41.6635, 41.7133, 41.7634, 41.8135, 41.8634, 41.9134, - 41.9637, 42.0135, 42.0636, 42.1139, 42.1638, 42.214, 42.2637, 42.3137, - 43.8459, 43.895, 43.9437, 43.9921, 44.0411, 44.0898, 44.1383, 44.1874, - 44.2361, 44.2846, 44.3337, 44.3825, 44.4311, 44.4802, 44.529, 44.5776, - 44.6258, 44.6747, 44.7233, 44.7716, 44.8205, 44.8692, 44.9175, 44.9665, - 45.0151, 45.0635, 45.1125, 45.1612, 45.2096, 45.2586, 45.3074, 45.3558, - 45.4049, 45.4537, 45.5021, 45.5513, 45.6001, 45.6486, 45.6977, 45.7466, - 45.7951, 45.8434, 45.8923, 45.9409, 45.9891, 46.0381, 46.0867, 46.135, - 46.184, 46.2327, 46.281, 46.33, 46.3787, 46.4271, 46.4762, 46.5249, - 46.5733, 46.6224, 46.6712, 46.7197, 46.7688, 46.8176, 46.8661, 46.9153, - 43.7538, 43.8026, 43.851, 43.8992, 43.948, 43.9964, 44.0446, 44.0934, - 44.142, 44.1902, 44.239, 44.2876, 44.3358, 44.3847, 44.4333, 44.4816, - 44.5296, 44.5782, 44.6266, 44.6746, 44.7232, 44.7716, 44.8197, 44.8684, - 44.9168, 44.9649, 45.0136, 45.0621, 45.1102, 45.159, 45.2075, 45.2557, - 45.3045, 45.353, 45.4012, 45.4501, 45.4986, 45.5469, 45.5958, 45.6444, - 45.6927, 45.7406, 45.7893, 45.8376, 45.8856, 45.9343, 45.9827, 46.0307, - 46.0794, 46.1278, 46.1759, 46.2247, 46.2731, 46.3213, 46.3701, 46.4185, - 46.4667, 46.5155, 46.564, 46.6123, 46.6611, 46.7097, 46.7579, 46.8068, - 48.7531, 48.8438, 48.9348, 49.0262, 49.1169, 49.208, 49.2995, 49.3903, - 49.4815, 49.573, 49.6639, 49.7552, 49.8458, 49.9368, 50.0281, 50.1188, - 50.2099, 50.3013, 50.3921, 50.4832, 50.5747, 50.6656, 50.7568, 50.8484, - 50.9393, 51.0306, 51.1213, 51.2123, 51.3037, 51.3944, 51.4855, 51.577, - 51.6678, 51.759, 51.8505, 51.9414, 52.0327, 52.1233, 52.2143, 52.3056, - 52.3963, 52.4874, 52.5788, 52.6696, 52.7607, 52.8522, 52.9431, 53.0343, - 53.1259, 53.2168, 53.3081, 53.3988, 53.4898, 53.5812, 53.6719, 53.763, - 53.8545, 53.9453, 54.0365, 54.128, 54.2189, 54.3102, 54.4008, 54.4918, - 48.4943, 48.5838, 48.6737, 48.7639, 48.8535, 48.9435, 49.0338, 49.1235, - 49.2135, 49.3039, 49.3937, 49.4838, 49.5732, 49.6631, 49.7533, 49.8428, - 49.9328, 50.023, 50.1127, 50.2027, 50.293, 50.3827, 50.4728, 50.5632, - 50.653, 50.7432, 50.8327, 50.9226, 51.0128, 51.1024, 51.1924, 51.2827, - 51.3724, 51.4624, 51.5528, 51.6425, 51.7327, 51.8221, 51.912, 52.0022, - 52.0917, 52.1817, 52.2719, 52.3616, 52.4516, 52.5419, 52.6316, 52.7217, - 52.8121, 52.9019, 52.9921, 53.0816, 53.1715, 53.2617, 53.3513, 53.4413, - 53.5316, 53.6212, 53.7113, 53.8017, 53.8914, 53.9815, 54.071, 54.1609, - 57.7208, 57.8084, 57.8954, 57.9818, 58.0694, 58.1564, 58.2429, 58.3306, - 58.4177, 58.5043, 58.5921, 58.6793, 58.7659, 58.8537, 58.941, 59.0277, - 59.1137, 59.2011, 59.2878, 59.374, 59.4614, 59.5482, 59.6345, 59.722, - 59.8089, 59.8952, 59.9827, 60.0697, 60.1561, 60.2437, 60.3308, 60.4172, - 60.505, 60.5921, 60.6786, 60.7664, 60.8536, 60.9402, 61.0281, 61.1153, - 61.202, 61.2881, 61.3755, 61.4622, 61.5483, 61.6358, 61.7226, 61.8088, - 61.8963, 61.9832, 62.0695, 62.1571, 62.244, 62.3304, 62.4181, 62.5051, - 62.5916, 62.6793, 62.7664, 62.853, 62.9407, 63.0279, 63.1146, 63.2024, - 57.5554, 57.6426, 57.729, 57.815, 57.9021, 57.9887, 58.0747, 58.162, - 58.2486, 58.3347, 58.422, 58.5087, 58.5949, 58.6823, 58.7691, 58.8553, - 58.9409, 59.0278, 59.114, 59.1997, 59.2867, 59.373, 59.4588, 59.5458, - 59.6323, 59.7181, 59.8052, 59.8917, 59.9776, 60.0648, 60.1514, 60.2374, - 60.3246, 60.4113, 60.4974, 60.5847, 60.6714, 60.7576, 60.8449, 60.9317, - 61.018, 61.1036, 61.1905, 61.2767, 61.3624, 61.4494, 61.5357, 61.6215, - 61.7085, 61.7949, 61.8808, 61.9679, 62.0544, 62.1403, 62.2275, 62.3141, - 62.4001, 62.4873, 62.574, 62.66, 62.7474, 62.8341, 62.9202, 63.0076, - 39.3678, 39.4186, 39.4696, 39.5207, 39.5715, 39.6225, 39.6737, 39.7246, - 39.7756, 39.8268, 39.8777, 39.9288, 39.9796, 40.0305, 40.0816, 40.1324, - 40.1834, 40.2346, 40.2854, 40.3364, 40.3876, 40.4385, 40.4896, 40.5408, - 40.5917, 40.6428, 40.6936, 40.7446, 40.7957, 40.8466, 40.8975, 40.9487, - 40.9996, 41.0506, 41.1019, 41.1528, 41.2038, 41.2546, 41.3055, 41.3567, - 41.4075, 41.4584, 41.5096, 41.5605, 41.6115, 41.6627, 41.7136, 41.7646, - 41.8159, 41.8668, 41.9179, 41.9687, 42.0196, 42.0708, 42.1216, 42.1726, - 42.2238, 42.2746, 42.3256, 42.3769, 42.4278, 42.4789, 42.5296, 42.5806, - 39.2228, 39.2729, 39.3232, 39.3737, 39.4239, 39.4743, 39.5248, 39.575, - 39.6254, 39.676, 39.7263, 39.7767, 39.8268, 39.8771, 39.9276, 39.9778, - 40.0281, 40.0786, 40.1288, 40.1792, 40.2298, 40.28, 40.3304, 40.381, - 40.4313, 40.4818, 40.5319, 40.5822, 40.6327, 40.6829, 40.7333, 40.7838, - 40.834, 40.8844, 40.935, 40.9853, 41.0357, 41.0858, 41.1361, 41.1866, - 41.2368, 41.2871, 41.3376, 41.3878, 41.4382, 41.4888, 41.539, 41.5894, - 41.64, 41.6903, 41.7408, 41.7909, 41.8412, 41.8917, 41.9419, 41.9922, - 42.0428, 42.093, 42.1434, 42.194, 42.2442, 42.2947, 42.3448, 42.3951, - 43.9435, 43.9928, 44.0418, 44.0905, 44.1399, 44.1889, 44.2376, 44.287, - 44.3361, 44.3849, 44.4343, 44.4834, 44.5322, 44.5817, 44.6308, 44.6797, - 44.7283, 44.7774, 44.8263, 44.8749, 44.9241, 44.9731, 45.0217, 45.071, - 45.12, 45.1686, 45.2179, 45.2669, 45.3156, 45.365, 45.414, 45.4628, - 45.5122, 45.5613, 45.61, 45.6595, 45.7086, 45.7574, 45.8068, 45.856, - 45.9048, 45.9534, 46.0026, 46.0515, 46.1001, 46.1493, 46.1982, 46.2469, - 46.2961, 46.3451, 46.3938, 46.4431, 46.4921, 46.5408, 46.5901, 46.6392, - 46.6879, 46.7373, 46.7864, 46.8352, 46.8846, 46.9337, 46.9825, 47.032, - 43.8506, 43.8996, 43.9484, 43.9968, 44.0459, 44.0947, 44.1432, 44.1923, - 44.2411, 44.2896, 44.3388, 44.3876, 44.4362, 44.4854, 44.5343, 44.5829, - 44.6312, 44.6801, 44.7287, 44.7771, 44.826, 44.8747, 44.9231, 44.9721, - 45.0208, 45.0692, 45.1182, 45.167, 45.2154, 45.2645, 45.3133, 45.3617, - 45.4109, 45.4597, 45.5082, 45.5574, 45.6062, 45.6548, 45.704, 45.7529, - 45.8015, 45.8498, 45.8987, 45.9473, 45.9957, 46.0446, 46.0933, 46.1416, - 46.1906, 46.2394, 46.2878, 46.3368, 46.3856, 46.434, 46.4831, 46.5319, - 46.5803, 46.6295, 46.6783, 46.7268, 46.776, 46.8248, 46.8734, 46.9226, - 48.8777, 48.969, 49.0607, 49.1527, 49.2441, 49.3358, 49.4279, 49.5194, - 49.6112, 49.7034, 49.7949, 49.8868, 49.9781, 50.0697, 50.1617, 50.2531, - 50.3448, 50.4368, 50.5283, 50.62, 50.7122, 50.8037, 50.8956, 50.9878, - 51.0794, 51.1713, 51.2626, 51.3543, 51.4463, 51.5377, 51.6294, 51.7215, - 51.813, 51.9048, 51.997, 52.0885, 52.1805, 52.2717, 52.3633, 52.4553, - 52.5467, 52.6384, 52.7305, 52.8219, 52.9137, 53.0058, 53.0973, 53.1892, - 53.2814, 53.373, 53.4649, 53.5562, 53.6479, 53.7399, 53.8313, 53.923, - 54.0152, 54.1066, 54.1984, 54.2906, 54.3821, 54.4741, 54.5653, 54.6569, - 48.6164, 48.7066, 48.7971, 48.888, 48.9782, 49.0688, 49.1597, 49.25, - 49.3407, 49.4317, 49.5221, 49.6129, 49.703, 49.7934, 49.8843, 49.9745, - 50.065, 50.1559, 50.2462, 50.3368, 50.4278, 50.5181, 50.6089, 50.6999, - 50.7903, 50.8811, 50.9713, 51.0618, 51.1527, 51.2429, 51.3335, 51.4244, - 51.5147, 51.6054, 51.6964, 51.7868, 51.8776, 51.9677, 52.0581, 52.149, - 52.2392, 52.3297, 52.4206, 52.5109, 52.6015, 52.6925, 52.7828, 52.8736, - 52.9646, 53.055, 53.1458, 53.236, 53.3265, 53.4174, 53.5076, 53.5982, - 53.6891, 53.7794, 53.8701, 53.9611, 54.0515, 54.1423, 54.2324, 54.3228, - 57.914, 58.0021, 58.0897, 58.1767, 58.265, 58.3526, 58.4397, 58.528, - 58.6157, 58.7028, 58.7912, 58.879, 58.9662, 59.0547, 59.1426, 59.2299, - 59.3165, 59.4045, 59.4918, 59.5786, 59.6666, 59.754, 59.8408, 59.9289, - 60.0165, 60.1033, 60.1915, 60.2791, 60.3661, 60.4544, 60.542, 60.629, - 60.7174, 60.8051, 60.8922, 60.9806, 61.0684, 61.1556, 61.2441, 61.332, - 61.4193, 61.5059, 61.5939, 61.6812, 61.768, 61.856, 61.9434, 62.0302, - 62.1183, 62.2059, 62.2927, 62.3809, 62.4685, 62.5555, 62.6437, 62.7314, - 62.8184, 62.9068, 62.9945, 63.0816, 63.17, 63.2578, 63.345, 63.4335, - 57.7471, 57.8348, 57.9219, 58.0084, 58.0962, 58.1834, 58.27, 58.3578, - 58.4451, 58.5317, 58.6197, 58.707, 58.7937, 58.8817, 58.9691, 59.0559, - 59.1421, 59.2296, 59.3165, 59.4028, 59.4903, 59.5773, 59.6636, 59.7512, - 59.8383, 59.9247, 60.0124, 60.0995, 60.186, 60.2738, 60.361, 60.4476, - 60.5354, 60.6227, 60.7093, 60.7973, 60.8846, 60.9713, 61.0593, 61.1467, - 61.2335, 61.3197, 61.4072, 61.4941, 61.5804, 61.6679, 61.7549, 61.8412, - 61.9289, 62.0159, 62.1023, 62.19, 62.2772, 62.3636, 62.4514, 62.5386, - 62.6252, 62.7131, 62.8003, 62.887, 62.9749, 63.0622, 63.1489, 63.237}; - PrintMatPtr(test_env.activations->attention.att_out); - for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { - EXPECT_TRUE(hwy::CompareArraySimilar( - att_out_golden_test_local.data() + - i * test_env.activations->attention.att_out.Cols(), - test_env.activations->attention.att_out.Row(i), - test_env.activations->attention.att_out.Cols(), 1e-3, - hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) - << "att_out mismatch for query: " << i; - } -} - -void TestAttentionMultipleTokensBF16() { - int qkv_dim = 64; - int kv_seq_len = 64; - int num_kv_heads = 2; - int num_heads = 4; - int num_tokens = 2; - int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 - // will have 64 tokens to attend to. - float att_cap = 10.0f; - int layer_idx = 0; - int layers_total = 1; - int qbatch_size = 2; - AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16; - AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, - num_heads, num_tokens, last_pos, att_cap, layer_idx, - layers_total, qbatch_size, attention_impl); - test_env.SetupWeights(); - FillMatPtrT(test_env.activations->attention.pre_att_rms_out); - FillMatPtrT(test_env.activations->attention.q); - FillMatPtrT(test_env.activations->attention.att); - FillMatPtrT(test_env.activations->attention.att_out); - FillMatPtrT(test_env.activations->attention.softmax_max); - FillMatPtrT(test_env.activations->attention.softmax_d); - - int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16); - TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, - test_env.activations->attention, *test_env.qbatch, - test_env.env, flags); - std::cerr << "att_out\n"; - PrintMatPtr(test_env.activations->attention.att_out); - for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { - EXPECT_TRUE(hwy::CompareArraySimilar( - AttentionMultipleTokensAttentionGoldens.data() + - i * test_env.activations->attention.att_out.Cols(), - test_env.activations->attention.att_out.Row(i), - test_env.activations->attention.att_out.Cols(), 1e-1, - hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) - << "att_out mismatch for query: " << i; - } -} - -void TestAttentionMultipleTokensInt8() { - int qkv_dim = 64; - int kv_seq_len = 64; - int num_kv_heads = 2; - int num_heads = 4; - int num_tokens = 2; - int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1 - // will have 64 tokens to attend to. - float att_cap = 10.0f; - int layer_idx = 0; - int layers_total = 1; - int qbatch_size = 2; - AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16; - AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads, - num_heads, num_tokens, last_pos, att_cap, layer_idx, - layers_total, qbatch_size, attention_impl, - Type::kInt8); - test_env.SetupWeights(); - FillMatPtrT(test_env.activations->attention.pre_att_rms_out); - FillMatPtrT(test_env.activations->attention.q); - FillMatPtrT(test_env.activations->attention.att); - FillMatPtrT(test_env.activations->attention.att_out); - FillMatPtrT(test_env.activations->attention.softmax_max); - FillMatPtrT(test_env.activations->attention.softmax_d); - - int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16); - TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer, - test_env.activations->attention, *test_env.qbatch, - test_env.env, flags); - std::cerr << "att_out\n"; - PrintMatPtr(test_env.activations->attention.att_out); - for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) { - EXPECT_TRUE(hwy::CompareArraySimilar( - AttentionMultipleTokensAttentionGoldens.data() + - i * test_env.activations->attention.att_out.Cols(), - test_env.activations->attention.att_out.Row(i), - test_env.activations->attention.att_out.Cols(), 1e-1, - hwy::TargetName(HWY_TARGET), __FILE__, __LINE__)) - << "att_out mismatch for query: " << i; - } -} + bool transposed = + attention_impl == AttentionImpl::kFlashTransposedQsBF16 // NOLINTNEXTLINE(google-readability-namespace-comments) } // namespace HWY_NAMESPACE