Update tests to use the kv_transcoding library to reduce their code size

PiperOrigin-RevId: 888600365
Krzysztof Rymski 2026-03-24 05:05:35 -07:00 committed by Copybara-Service
parent 1dedcfd50d
commit 8a5e37eeb7
3 changed files with 114 additions and 886 deletions

View File

@@ -141,6 +141,7 @@ cc_test(
":gemma_args",
":gemma_lib",
":kv_cache",
":kv_transcoding",
":mat",
":matmul",
":test_util",
@@ -151,6 +152,7 @@ cc_test(
"//compression:types",
"@highway//:hwy",
"@highway//:hwy_test_util",
"@highway//:nanobenchmark",
],
)

View File

@@ -24,9 +24,11 @@
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "gemma/kv_transcoding.h"
#include "gemma/weights.h"
#include "ops/matmul.h"
#include "util/test_util.h"
#include "hwy/nanobenchmark.h"
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS GEMMA_DISABLED_TARGETS
#endif // HWY_DISABLED_TARGETS
@@ -102,6 +104,33 @@ void AssertClose(const MatPtrT<float>& a, const MatPtrT<float>& b) {
}
}
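// Fills every tile of `kv` with deterministic synthetic K/V values via a
// DecodedTile, then writes each tile into its cache row with EncodeTile
// using the requested encoding.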
template <typename T>
void PopulateTestKVCache(MatStorageT<T>& kv, gcpp::KVEncoding encoding,
size_t qkv_dim) {
gcpp::DecodedTile tile(qkv_dim, gcpp::KVCache::kTileSize);
size_t num_tiles = kv.Rows();
float unpredictable = hwy::Unpredictable1() * 0.01f;
for (size_t tile_idx = 0; tile_idx < num_tiles; ++tile_idx) {
for (size_t in_tile = 0; in_tile < gcpp::KVCache::kTileSize; ++in_tile) {
size_t i = tile_idx * gcpp::KVCache::kTileSize + in_tile;
for (size_t j = 0; j < qkv_dim; ++j) {
tile.k_elem(in_tile, j) = unpredictable * (i + 1) / (j + 1);
tile.v_elem(in_tile, j) = unpredictable * 2 * (i + 1) / (j + 1);
}
}
size_t row_bytes = kv.Cols() * sizeof(T);
HWY_ASSERT(gcpp::EncodeTile(
encoding, tile, qkv_dim,
hwy::Span<char>(reinterpret_cast<char*>(kv.Row(tile_idx)), row_bytes)));
}
}
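The helper above replaces the hand-rolled, layout-specific fill loops that this diff deletes from the tests below. As a minimal usage sketch, mirroring the calls added later in this diff (the extents follow the tiled layout of one cache row per tile of KVCache::kTileSize tokens), the three storage types are populated as follows. hwy::Unpredictable1(), declared in hwy/nanobenchmark.h and the reason for the new @highway//:nanobenchmark dependency, returns 1 at runtime but keeps the compiler from constant-folding the synthetic data.

  const size_t num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
  // f32 and BF16 rows hold K then V for one tile: 2 * qkv_dim * kTileSize.
  MatStorageT<float> kv_f32(
      "kv", Extents2D(num_tiles, 2 * qkv_dim * gcpp::KVCache::kTileSize),
      ctx.allocator, MatPadding::kPacked);
  PopulateTestKVCache(kv_f32, gcpp::KVEncoding::kF32, qkv_dim);
  MatStorageT<BF16> kv_bf16(
      "kv", Extents2D(num_tiles, 2 * qkv_dim * gcpp::KVCache::kTileSize),
      ctx.allocator, MatPadding::kPacked);
  PopulateTestKVCache(kv_bf16, gcpp::KVEncoding::kBF16TwoTranspositions,
                      qkv_dim);
  // Int8 rows additionally carry per-token BF16 scales after the K/V bytes.
  const size_t tile_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
                            2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
  MatStorageT<int8_t> kv_int8("kv", Extents2D(num_tiles, tile_bytes),
                              ctx.allocator, MatPadding::kPacked);
  PopulateTestKVCache(kv_int8, gcpp::KVEncoding::kInt8, qkv_dim);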
struct AttentionTestEnv {
AttentionTestEnv(size_t num_queries, size_t kv_seq_len, size_t qkv_dim,
AttentionImpl attention_impl);
};
void TestFlashAttention(size_t target_parallelism,
AttentionImpl attention_impl) {
ThreadingArgs threading_args;
@@ -283,39 +312,29 @@ const std::vector<float> att_out_gold = {
0.009653};
void TestTiledFlashAttention() {
int qkv_dim = 64;
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
// tile size to test the padding logic.
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
size_t qkv_dim = 64;
size_t kv_seq_len = 60; // number of tokens we will attend to.
// Not divisible by tile size to test the padding logic.
size_t padded_kv_seq_len =
hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
float att_cap = 10.0f;
int num_queries = 8;
int num_queries_per_timestep = 4;
int num_tokens = num_queries / num_queries_per_timestep;
int kv_seq_end =
size_t num_queries = 8;
size_t num_queries_per_timestep = 4;
size_t num_tokens = num_queries / num_queries_per_timestep;
size_t kv_seq_end =
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
MatStorageT<float> kv(
"kv",
Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize),
ctx.allocator, MatPadding::kPacked);
// fill in kvs with predictable, synthetic data
for (int i = 0; i < padded_kv_seq_len; ++i) {
for (int j = 0; j < qkv_dim; ++j) {
const int tile_idx = i / gcpp::KVCache::kTileSize;
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
const float val_k = 0.01f * (i + 1) / (j + 1);
const float val_v = 0.02f * (i + 1) / (j + 1);
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset] = val_k;
const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j] = val_v;
}
}
MatStorageT<float> kv("kv",
Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize,
2 * qkv_dim * gcpp::KVCache::kTileSize),
ctx.allocator, MatPadding::kPacked);
PopulateTestKVCache(kv, gcpp::KVEncoding::kF32, qkv_dim);
std::vector<float> q_float(4 * qkv_dim);
std::vector<float> q_float2(4 * qkv_dim);
// fill in qs with predictable, synthetic data
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < qkv_dim; j++) {
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < qkv_dim; j++) {
float val_1 = 0.01f * (i + 1) / (j + 1);
float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
q_float[j * 4 + i] = val_1;
@@ -342,11 +361,11 @@ void TestTiledFlashAttention() {
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
start_pos_per_query.reserve(num_queries);
last_pos_per_query.reserve(num_queries);
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
ssize_t query_last_pos = kv_seq_end + token_idx;
ssize_t query_start_pos =
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep;
++q_head_idx) {
start_pos_per_query.push_back(query_start_pos);
last_pos_per_query.push_back(query_last_pos);
@@ -365,67 +384,43 @@ void TestTiledFlashAttention() {
// and output looked good. Not ideal but should be good enough to test the
// plumbing and detect regressions.
PrintMatPtr(att_out);
for (int i = 0; i < num_queries; ++i) {
for (size_t i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-3f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-6f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {
for (size_t j = 0; j < qkv_dim; ++j) {
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-5f);
}
}
}
void TestTiledFlashAttentionBF16() {
int qkv_dim = 64;
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
// tile size to test the padding logic.
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
size_t qkv_dim = 64;
size_t kv_seq_len = 60; // number of tokens we will attend to.
// Not divisible by tile size to test the padding logic.
size_t padded_kv_seq_len =
hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
float att_cap = 10.0f;
int num_queries = 8;
int num_queries_per_timestep = 4;
int num_tokens = num_queries / num_queries_per_timestep;
int kv_seq_end =
size_t num_queries = 8;
size_t num_queries_per_timestep = 4;
size_t num_tokens = num_queries / num_queries_per_timestep;
size_t kv_seq_end =
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
MatStorageT<BF16> kv(
"kv",
Extents2D(padded_kv_seq_len, 2 * qkv_dim * gcpp::KVCache::kTileSize),
ctx.allocator, MatPadding::kPacked);
// fill in kvs with predictable, synthetic data
for (int i = 0; i < padded_kv_seq_len; i++) {
for (int j = 0; j < qkv_dim; j+=2) {
const int tile_idx = i / gcpp::KVCache::kTileSize;
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
const float val_k_1 = 0.01f * (i + 1) / (j + 1);
const float val_k_2 = 0.01f * (i + 1) / (j + 2);
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2] =
hwy::ConvertScalarTo<BF16>(val_k_1);
kv.Row(tile_idx)[j * gcpp::KVCache::kTileSize + in_tile_offset * 2 + 1] =
hwy::ConvertScalarTo<BF16>(val_k_2);
}
}
const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
for (int i = 0; i < padded_kv_seq_len; i += 2) {
for (int j = 0; j < qkv_dim; j++) {
const int tile_idx = i / gcpp::KVCache::kTileSize;
const int in_tile_offset = i % gcpp::KVCache::kTileSize;
const float val_v_1 = 0.02f * (i + 1) / (j + 1);
const float val_v_2 = 0.02f * (i + 2) / (j + 1);
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2] =
hwy::ConvertScalarTo<BF16>(val_v_1);
kv.Row(tile_idx)[v_offset + in_tile_offset * qkv_dim + j * 2 + 1] =
hwy::ConvertScalarTo<BF16>(val_v_2);
}
}
MatStorageT<BF16> kv("kv",
Extents2D(padded_kv_seq_len / gcpp::KVCache::kTileSize,
2 * qkv_dim * gcpp::KVCache::kTileSize),
ctx.allocator, MatPadding::kPacked);
PopulateTestKVCache(kv, gcpp::KVEncoding::kBF16TwoTranspositions, qkv_dim);
std::vector<BF16> q_float(num_queries_per_timestep * qkv_dim);
std::vector<BF16> q_float2(num_queries_per_timestep * qkv_dim);
// fill in qs with predictable, synthetic data
for (int i = 0; i < num_queries_per_timestep; ++i) {
for (int j = 0; j < qkv_dim; j += 2) {
for (size_t i = 0; i < num_queries_per_timestep; ++i) {
for (size_t j = 0; j < qkv_dim; j += 2) {
q_float[j * num_queries_per_timestep + i * 2] =
hwy::ConvertScalarTo<BF16>(0.01f * (i + 1) / (j + 1));
q_float[j * num_queries_per_timestep + i * 2 + 1] =
@@ -458,11 +453,11 @@ void TestTiledFlashAttentionBF16() {
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
start_pos_per_query.reserve(num_queries);
last_pos_per_query.reserve(num_queries);
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
ssize_t query_last_pos = kv_seq_end + token_idx;
ssize_t query_start_pos =
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep;
++q_head_idx) {
start_pos_per_query.push_back(query_start_pos);
last_pos_per_query.push_back(query_last_pos);
@@ -480,90 +475,47 @@ void TestTiledFlashAttentionBF16() {
// and output looked good. Not ideal but should be good enough to test the
// plumbing and detect regressions.
PrintMatPtr(att_out);
for (int i = 0; i < num_queries; ++i) {
for (size_t i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 4e-2f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {
for (size_t j = 0; j < qkv_dim; ++j) {
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 1e-3f);
}
}
}
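For reference, the kBF16TwoTranspositions layout that PopulateTestKVCache now produces, and which the deleted fill loops above wrote by hand, interleaves adjacent pairs of elements so that BF16 dot products consume contiguous lane pairs. A sketch of the per-tile index mapping, reconstructed from the deleted loops (the helper names are illustrative, not part of the library):

  // K block: outer blocks of dim pairs; within a block, tokens advance with
  // the two dims of the pair interleaved.
  size_t KIndexBF16(size_t token, size_t dim) {
    return (dim - dim % 2) * gcpp::KVCache::kTileSize + token * 2 + dim % 2;
  }
  // V block (after the K block): outer blocks of token pairs; within a
  // block, dims advance with the two tokens of the pair interleaved.
  size_t VIndexBF16(size_t token, size_t dim, size_t qkv_dim) {
    const size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
    return v_offset + (token - token % 2) * qkv_dim + dim * 2 + token % 2;
  }

The same mapping appears in the manual fill removed from the attention test further below, as the expression (dim - dim_mod_2) * kTileSize + in_tile_offset * 2 + dim_mod_2.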
void TestTiledFlashAttentionInt8() {
int qkv_dim = 64;
int kv_seq_len = 60; // number of tokens we will attend to. Not divisible by
// tile size to test the padding logic.
int padded_kv_seq_len = hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
size_t qkv_dim = 64;
// number of tokens we will attend to.
// Not divisible by tile size to test the padding logic.
size_t kv_seq_len = 60;
size_t padded_kv_seq_len =
hwy::RoundUpTo(kv_seq_len, gcpp::KVCache::kTileSize);
float att_cap = 10.0f;
int num_queries = 8;
int num_queries_per_timestep = 4;
int num_tokens = num_queries / num_queries_per_timestep;
int kv_seq_end =
size_t num_queries = 8;
size_t num_queries_per_timestep = 4;
size_t num_tokens = num_queries / num_queries_per_timestep;
size_t kv_seq_end =
kv_seq_len - hwy::DivCeil(num_queries, num_queries_per_timestep);
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
int num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
int tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
size_t num_tiles = padded_kv_seq_len / gcpp::KVCache::kTileSize;
size_t tile_size_bytes = 2 * qkv_dim * gcpp::KVCache::kTileSize +
2 * sizeof(BF16) * gcpp::KVCache::kTileSize;
MatStorageT<int8_t> kv("kv", Extents2D(num_tiles, tile_size_bytes),
ctx.allocator, MatPadding::kPacked);
// fill in kvs with predictable, synthetic data
for (int i = 0; i < padded_kv_seq_len; ++i) {
int tile_idx = i / gcpp::KVCache::kTileSize;
int in_tile_offset = i % gcpp::KVCache::kTileSize;
int8_t* tile_ptr = kv.Row(tile_idx);
BF16* scales_ptr = HWY_RCAST_ALIGNED(
BF16*, tile_ptr + 2 * qkv_dim * gcpp::KVCache::kTileSize);
// Generate float values for K and V
std::vector<float> k_vals(qkv_dim);
std::vector<float> v_vals(qkv_dim);
float max_abs_k = 0.0f;
float max_abs_v = 0.0f;
for (int j = 0; j < qkv_dim; ++j) {
k_vals[j] = 0.01f * (i + 1) / (j + 1);
v_vals[j] = 0.02f * (i + 1) / (j + 1);
max_abs_k = std::max(max_abs_k, std::abs(k_vals[j]));
max_abs_v = std::max(max_abs_v, std::abs(v_vals[j]));
}
// Quantize K
float scale_k = max_abs_k / 127.0f;
if (scale_k == 0.0f) scale_k = 1.0f;
scales_ptr[in_tile_offset] = hwy::ConvertScalarTo<BF16>(scale_k);
for (int j = 0; j < qkv_dim; ++j) {
int val = std::round(k_vals[j] / scale_k);
val = std::max(-127, std::min(127, val));
tile_ptr[j * gcpp::KVCache::kTileSize + in_tile_offset] =
static_cast<int8_t>(val);
}
// Quantize V
float scale_v = max_abs_v / 127.0f;
if (scale_v == 0.0f) scale_v = 1.0f;
scales_ptr[gcpp::KVCache::kTileSize + in_tile_offset] =
hwy::ConvertScalarTo<BF16>(scale_v);
size_t v_offset = qkv_dim * gcpp::KVCache::kTileSize;
for (int j = 0; j < qkv_dim; ++j) {
int val = std::round(v_vals[j] / scale_v);
val = std::max(-127, std::min(127, val));
tile_ptr[v_offset + in_tile_offset * qkv_dim + j] =
static_cast<int8_t>(val);
}
}
PopulateTestKVCache(kv, gcpp::KVEncoding::kInt8, qkv_dim);
std::vector<float> q_float(4 * qkv_dim);
std::vector<float> q_float2(4 * qkv_dim);
// fill in qs with predictable, synthetic data
for (int i = 0; i < 4; ++i) {
for (int j = 0; j < qkv_dim; j++) {
for (size_t i = 0; i < 4; ++i) {
for (size_t j = 0; j < qkv_dim; j++) {
float val_1 = 0.01f * (i + 1) / (j + 1);
float val_2 = 0.01f * (i + 4 + 1) / (j + 1);
q_float[j * 4 + i] = val_1;
@@ -590,11 +542,11 @@ void TestTiledFlashAttentionInt8() {
std::vector<size_t, hwy::AlignedAllocator<size_t>> last_pos_per_query;
start_pos_per_query.reserve(num_queries);
last_pos_per_query.reserve(num_queries);
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
ssize_t query_last_pos = kv_seq_end + token_idx;
ssize_t query_start_pos =
std::max(query_last_pos - 100000 + 1, static_cast<ssize_t>(0));
for (int q_head_idx = 0; q_head_idx < num_queries_per_timestep;
for (size_t q_head_idx = 0; q_head_idx < num_queries_per_timestep;
++q_head_idx) {
start_pos_per_query.push_back(query_start_pos);
last_pos_per_query.push_back(query_last_pos);
@@ -613,13 +565,13 @@ void TestTiledFlashAttentionInt8() {
// and output looked good. Not ideal but should be good enough to test the
// plumbing and detect regressions.
PrintMatPtr(att_out);
for (int i = 0; i < num_queries; ++i) {
for (size_t i = 0; i < num_queries; ++i) {
std::cerr << "exp_d: " << exp_denominator_sums[i]
<< " max_logit: " << max_logits[i] << std::endl;
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 1e-2f)
EXPECT_NEAR(exp_denominator_sums[i], exp_denominator_sums_gold[i], 2e-2f)
<< "i=" << i;
EXPECT_NEAR(max_logits[i], max_logits_gold[i], 1e-3f) << "i=" << i;
for (int j = 0; j < qkv_dim; ++j) {
for (size_t j = 0; j < qkv_dim; ++j) {
EXPECT_NEAR(att_out.Row(i)[j], att_out_gold[i * qkv_dim + j], 5e-3f);
}
}
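The kInt8 encoding takes over the symmetric per-token quantization that the deleted loop above performed inline. A condensed sketch of that deleted logic (illustrative helper names; <algorithm>, <cmath> and <vector> assumed):

  // One BF16 scale per token for K and one for V, stored after the K/V bytes.
  float SymScale(const std::vector<float>& vals) {
    float max_abs = 0.0f;
    for (float v : vals) max_abs = std::max(max_abs, std::abs(v));
    const float scale = max_abs / 127.0f;
    return scale == 0.0f ? 1.0f : scale;  // keep 1.0 for all-zero tokens
  }
  int8_t QuantizeSym(float v, float scale) {
    const int q = static_cast<int>(std::round(v / scale));
    return static_cast<int8_t>(std::max(-127, std::min(127, q)));
  }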

View File

@@ -14,6 +14,7 @@
#include "gemma/gemma.h"
#include "gemma/gemma_args.h"
#include "gemma/kv_cache.h"
#include "gemma/kv_transcoding.h"
#include "gemma/weights.h"
#include "util/mat.h"
#include "util/threading_context.h"
@@ -42,9 +43,10 @@ using ::testing::Pointwise;
struct AttentionTestEnv {
AttentionTestEnv(
int qkv_dim, int kv_seq_len, int attention_window_size, int num_kv_heads,
int num_heads, int num_tokens, int last_pos, float att_cap, int layer_idx,
int layers_total, int qbatch_size, AttentionImpl attention_impl,
size_t qkv_dim, size_t kv_seq_len, size_t attention_window_size,
size_t num_kv_heads, size_t num_heads, size_t num_tokens, size_t last_pos,
float att_cap, size_t layer_idx, size_t layers_total, size_t qbatch_size,
AttentionImpl attention_impl,
std::optional<Type> kv_cache_type = {})
: ctx(threading_args), env(ctx) {
layer_config.heads = num_heads;
@@ -60,7 +62,7 @@ struct AttentionTestEnv {
model_config.model_dim = layer_config.model_dim;
model_config.vocab_size = 1; // not vit
for (int i = 0; i < model_config.num_layers; ++i) {
for (size_t i = 0; i < model_config.num_layers; ++i) {
model_config.layer_configs.push_back(layer_config);
}
tensor_info_registry = std::make_unique<TensorInfoRegistry>(model_config);
@@ -73,755 +75,27 @@ struct AttentionTestEnv {
all_queries.Reserve(qbatch_size);
kv_caches.reserve(qbatch_size);
for (int q = 0; q < qbatch_size; ++q) {
float unpredictable = hwy::Unpredictable1() * 0.01f;
for (size_t q = 0; q < qbatch_size; ++q) {
kv_caches.emplace_back(model_config, inference_args, runtime_config,
ctx.allocator);
if (attention_impl == AttentionImpl::kFlashTransposedQsBF16 &&
kv_caches.back().compact_kv_cache_ptr.GetType() == Type::kBF16) {
MatPtrT<BF16> compact_kv_cache = kv_caches.back().compact_kv_cache_ptr;
for (int i = 0; i < compact_kv_cache.Rows(); ++i) {
for (int j = 0; j < compact_kv_cache.Cols(); ++j) {
BF16 val = hwy::ConvertScalarTo<BF16>(hwy::Unpredictable1() *
0.01f * (i + j + 1));
// split j by whether it addresses K or V
if (j < qkv_dim * gcpp::KVCache::kTileSize) {
// split j into dim and in tile offset
const int dim = j / gcpp::KVCache::kTileSize;
const int in_tile_offset = j % gcpp::KVCache::kTileSize;
const int dim_mod_2 = dim % 2;
compact_kv_cache.Row(
i)[(dim - dim_mod_2) * gcpp::KVCache::kTileSize +
in_tile_offset * 2 + dim_mod_2] = val;
} else {
const int in_tile_offset = j / qkv_dim;
const int dim = j % qkv_dim;
const int in_tile_offset_mod_2 = in_tile_offset % 2;
compact_kv_cache.Row(
i)[(in_tile_offset - in_tile_offset_mod_2) * qkv_dim +
dim * 2 + in_tile_offset_mod_2] = val;
if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) {
const size_t tile_size = gcpp::KVCache::kTileSize;
gcpp::DecodedTile decoded(qkv_dim, tile_size);
for (size_t i = 0; i < kv_caches.back().compact_kv_cache_ptr.Rows();
++i) {
for (size_t token = 0; token < tile_size; ++token) {
for (size_t dim = 0; dim < qkv_dim; ++dim) {
size_t j_k = dim * tile_size + token;
decoded.k_elem(token, dim) = unpredictable * (i + j_k + 1);
size_t j_v = qkv_dim * tile_size + token * qkv_dim + dim;
decoded.v_elem(token, dim) = unpredictable * (i + j_v + 1);
}
}
}
} else if (kv_caches.back().compact_kv_cache_ptr.HasPtr()) {
if (kv_caches.back().compact_kv_cache_ptr.GetType() == Type::kInt8) {
MatPtrT<int8_t> compact_kv_cache =
kv_caches.back().compact_kv_cache_ptr;
for (int i = 0; i < compact_kv_cache.Rows(); ++i) {
BF16* scales_ptr = HWY_RCAST_ALIGNED(
BF16*, compact_kv_cache.Row(i) +
2 * qkv_dim * gcpp::KVCache::kTileSize);
for (int in_tile_idx = 0; in_tile_idx < gcpp::KVCache::kTileSize;
++in_tile_idx) {
// Compute scale and fill K
float max_k = 0.0f;
for (int dim = 0; dim < qkv_dim; ++dim) {
int j = dim * gcpp::KVCache::kTileSize + in_tile_idx;
float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1);
max_k = std::max(max_k, expected);
}
float scale_k = max_k / 127.0f;
if (scale_k == 0.0f) scale_k = 1.0f;
scales_ptr[in_tile_idx] = hwy::ConvertScalarTo<BF16>(scale_k);
for (int dim = 0; dim < qkv_dim; ++dim) {
int j = dim * gcpp::KVCache::kTileSize + in_tile_idx;
float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1);
compact_kv_cache.Row(i)[j] =
static_cast<int8_t>(std::round(expected / scale_k));
}
// Compute scale and fill V
float max_v = 0.0f;
for (int dim = 0; dim < qkv_dim; ++dim) {
int j = qkv_dim * gcpp::KVCache::kTileSize +
in_tile_idx * qkv_dim + dim;
float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1);
max_v = std::max(max_v, expected);
}
float scale_v = max_v / 127.0f;
if (scale_v == 0.0f) scale_v = 1.0f;
scales_ptr[gcpp::KVCache::kTileSize + in_tile_idx] =
hwy::ConvertScalarTo<BF16>(scale_v);
for (int dim = 0; dim < qkv_dim; ++dim) {
int j = qkv_dim * gcpp::KVCache::kTileSize +
in_tile_idx * qkv_dim + dim;
float expected = hwy::Unpredictable1() * 0.01f * (i + j + 1);
compact_kv_cache.Row(i)[j] =
static_cast<int8_t>(std::round(expected / scale_v));
}
}
}
} else if (kv_caches.back().compact_kv_cache_ptr.GetType() ==
Type::kBF16) {
MatPtrT<BF16> compact_kv_cache =
kv_caches.back().compact_kv_cache_ptr;
FillMatPtrT(compact_kv_cache);
} else {
MatPtrT<float> compact_kv_cache =
kv_caches.back().compact_kv_cache_ptr;
FillMatPtrT(compact_kv_cache);
}
} else {
FillMatPtrT(kv_caches.back().kv_cache);
}
all_queries.Append({
.prompt = PromptTokens({1, 2, 3}),
.mutable_pos = static_cast<size_t>(last_pos),
.initial_pos = 0,
.prefix_end = 0,
.kv_cache = kv_caches.back().ToPtr(),
});
}
activations = std::make_unique<Activations>(runtime_config, model_config,
qbatch_size * num_tokens,
kv_seq_len, ctx, env.row_ptrs);
qbatch =
std::make_unique<QBatch>(/*start_pos=*/0, qbatch_size, all_queries);
}
void SetupWeights() {
int model_dim = layer_config.model_dim;
int qkv_dim = layer_config.qkv_dim;
int num_heads = layer_config.heads;
int num_kv_heads = layer_config.kv_heads;
qkv1_w_storage =
MatStorageT<float>("qkv1", Extents2D(model_dim, qkv_dim * num_heads),
ctx.allocator, MatPadding::kPacked);
qkv2_w_storage = MatStorageT<float>(
"qkv2", Extents2D(model_dim, num_kv_heads * 2 * qkv_dim), ctx.allocator,
MatPadding::kPacked);
wo_w_storage = MatStorageT<float>("wo", Extents2D(model_dim, model_dim),
ctx.allocator, MatPadding::kPacked);
FillMatPtrT(wo_w_storage);
layer->att_weights = wo_w_storage;
FillMatPtrT(qkv1_w_storage);
FillMatPtrT(qkv2_w_storage);
layer->qkv_einsum_w1 = qkv1_w_storage;
layer->qkv_einsum_w2 = qkv2_w_storage;
query_norm_scale = MatStorageT<float>("query_norm", qkv_dim, ctx.allocator);
FillMatPtrT(query_norm_scale);
layer->query_norm_scale = query_norm_scale;
key_norm_scale = MatStorageT<float>("key_norm", qkv_dim, ctx.allocator);
FillMatPtrT(key_norm_scale);
layer->key_norm_scale = key_norm_scale;
}
AttentionTestEnv(const AttentionTestEnv&) = delete;
AttentionTestEnv& operator=(const AttentionTestEnv&) = delete;
AttentionTestEnv(AttentionTestEnv&&) = delete;
AttentionTestEnv& operator=(AttentionTestEnv&&) = delete;
ThreadingArgs threading_args;
ThreadingContext ctx;
MatMulEnv env;
LayerConfig layer_config;
ModelConfig model_config;
std::unique_ptr<TensorInfoRegistry> tensor_info_registry;
std::unique_ptr<LayerWeightsPtrs> layer;
RuntimeConfig runtime_config;
InferenceArgs inference_args;
AllQueries all_queries;
std::vector<KVCache> kv_caches;
std::unique_ptr<Activations> activations;
std::unique_ptr<QBatch> qbatch;
// Weights storage for later tests
MatStorageT<float> qkv1_w_storage;
MatStorageT<float> qkv2_w_storage;
MatStorageT<float> wo_w_storage;
MatStorageT<float> query_norm_scale;
MatStorageT<float> key_norm_scale;
};
void TestTransposeStridedQueries() {
ThreadingArgs threading_args;
ThreadingContext ctx(threading_args);
int qkv_dim = 64;
int num_queries = 24;
AlignedPtr<float[]> input_queries =
ctx.allocator.Alloc<float>(qkv_dim * num_queries);
AlignedPtr<float[]> output_queries =
ctx.allocator.Alloc<float>(qkv_dim * num_queries);
for (int i = 0; i < num_queries; ++i) {
for (int j = 0; j < qkv_dim; ++j) {
input_queries[i * qkv_dim + j] = i * qkv_dim + j;
}
}
std::vector<float*> queries;
for (int i = 0; i < num_queries; ++i) {
queries.push_back(input_queries.get() + i * qkv_dim);
}
hwy::Span<float*> queries_span(queries.data(), queries.size());
TransposeStridedQueries(
queries_span, qkv_dim,
hwy::Span<float>(output_queries.get(), qkv_dim * num_queries));
for (int i = 0; i < num_queries; ++i) {
for (int j = 0; j < qkv_dim; ++j) {
EXPECT_EQ(output_queries[j * num_queries + i],
input_queries[i * qkv_dim + j])
<< "i=" << i << " j=" << j;
}
}
}
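As a mental model for the check above: TransposeStridedQueries gathers per-query vectors, addressed through a span of row pointers so the sources need not be contiguous, into one dense buffer transposed to [qkv_dim][num_queries]. A scalar reference of the expected mapping (not the vectorized implementation):

  // out[j * num_queries + i] == queries[i][j], as the EXPECT_EQs verify.
  void TransposeStridedQueriesRef(hwy::Span<float*> queries, size_t qkv_dim,
                                  hwy::Span<float> out) {
    const size_t num_queries = queries.size();
    for (size_t i = 0; i < num_queries; ++i) {
      for (size_t j = 0; j < qkv_dim; ++j) {
        out[j * num_queries + i] = queries[i][j];
      }
    }
  }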
void TestLocalAttentionForAllHeadsTokensAndBatch() {
int qkv_dim = 64;
int kv_seq_len = 64;
int num_kv_heads = 2;
int num_heads = 2;
int num_tokens = 2;
int last_pos = 62; // so token 0 will have 63 and token 1 will have 64 tokens
// to attend to.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
num_heads, num_tokens, last_pos, att_cap, layer_idx,
layers_total, qbatch_size, attention_impl);
FillMatPtrT(test_env.activations->attention.q);
LocalAttentionForAllHeadsTokensAndBatch(
attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch, test_env.ctx);
// Expected softmax statistics and attention outputs:
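// With this synthetic fill every logit saturates at att_cap, so max_logit
// equals the cap (10) and the softmax denominator equals the number of
// attended positions: 63 for token 0 and 64 for token 1.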
std::vector<float> exp_denominator_sums_gold = {63, 63, 64, 64,
63, 63, 64, 64};
std::vector<float> max_logits_gold = {10, 10, 10, 10, 10, 10, 10, 10};
std::vector<float> att_out_gold = {
30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275,
30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075,
30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875,
30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675,
30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475,
30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275,
30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075,
30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875,
30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475,
30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275,
30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075,
30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875,
30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675,
30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475,
30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275,
30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075,
30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485,
30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565,
30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645,
30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725,
30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805,
30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885,
30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965,
30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045,
30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505,
30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585,
30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665,
30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745,
30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825,
30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905,
30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985,
30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065,
30.2575, 30.2675, 30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275,
30.3375, 30.3475, 30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075,
30.4175, 30.4275, 30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875,
30.4975, 30.5075, 30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675,
30.5775, 30.5875, 30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475,
30.6575, 30.6675, 30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275,
30.7375, 30.7475, 30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075,
30.8175, 30.8275, 30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875,
30.2775, 30.2875, 30.2975, 30.3075, 30.3175, 30.3275, 30.3375, 30.3475,
30.3575, 30.3675, 30.3775, 30.3875, 30.3975, 30.4075, 30.4175, 30.4275,
30.4375, 30.4475, 30.4575, 30.4675, 30.4775, 30.4875, 30.4975, 30.5075,
30.5175, 30.5275, 30.5375, 30.5475, 30.5575, 30.5675, 30.5775, 30.5875,
30.5975, 30.6075, 30.6175, 30.6275, 30.6375, 30.6475, 30.6575, 30.6675,
30.6775, 30.6875, 30.6975, 30.7075, 30.7175, 30.7275, 30.7375, 30.7475,
30.7575, 30.7675, 30.7775, 30.7875, 30.7975, 30.8075, 30.8175, 30.8275,
30.8375, 30.8475, 30.8575, 30.8675, 30.8775, 30.8875, 30.8975, 30.9075,
30.415, 30.425, 30.435, 30.445, 30.455, 30.465, 30.475, 30.485,
30.495, 30.505, 30.515, 30.525, 30.535, 30.545, 30.555, 30.565,
30.575, 30.585, 30.595, 30.605, 30.615, 30.625, 30.635, 30.645,
30.655, 30.665, 30.675, 30.685, 30.695, 30.705, 30.715, 30.725,
30.735, 30.745, 30.755, 30.765, 30.775, 30.785, 30.795, 30.805,
30.815, 30.825, 30.835, 30.845, 30.855, 30.865, 30.875, 30.885,
30.895, 30.905, 30.915, 30.925, 30.935, 30.945, 30.955, 30.965,
30.975, 30.985, 30.995, 31.005, 31.015, 31.025, 31.035, 31.045,
30.435, 30.445, 30.455, 30.465, 30.475, 30.485, 30.495, 30.505,
30.515, 30.525, 30.535, 30.545, 30.555, 30.565, 30.575, 30.585,
30.595, 30.605, 30.615, 30.625, 30.635, 30.645, 30.655, 30.665,
30.675, 30.685, 30.695, 30.705, 30.715, 30.725, 30.735, 30.745,
30.755, 30.765, 30.775, 30.785, 30.795, 30.805, 30.815, 30.825,
30.835, 30.845, 30.855, 30.865, 30.875, 30.885, 30.895, 30.905,
30.915, 30.925, 30.935, 30.945, 30.955, 30.965, 30.975, 30.985,
30.995, 31.005, 31.015, 31.025, 31.035, 31.045, 31.055, 31.065,
};
const int group_size = num_heads / num_kv_heads;
for (int token_idx = 0; token_idx < num_tokens; ++token_idx) {
for (int q_batch_idx = 0; q_batch_idx < qbatch_size; ++q_batch_idx) {
int b = token_idx * qbatch_size + q_batch_idx;
EXPECT_THAT(
absl::MakeSpan(test_env.activations->attention.softmax_d.Row(b),
num_heads),
Pointwise(FloatNear(1e-3f), absl::MakeSpan(exp_denominator_sums_gold)
.subspan(b * num_heads, num_heads)));
EXPECT_THAT(
absl::MakeSpan(test_env.activations->attention.softmax_max.Row(b),
num_heads),
Pointwise(FloatNear(1e-3f), absl::MakeSpan(max_logits_gold)
.subspan(b * num_heads, num_heads)));
for (int kv_h = 0; kv_h < num_kv_heads; ++kv_h) {
for (int g = 0; g < group_size; ++g) {
const int q_h = kv_h * group_size + g;
size_t expected_q_idx = b * num_heads + q_h;
EXPECT_THAT(
absl::MakeSpan(test_env.activations->attention.att_out.Row(b) +
q_h * qkv_dim,
qkv_dim),
Pointwise(FloatNear(1e-3f),
absl::MakeSpan(att_out_gold)
.subspan(expected_q_idx * qkv_dim, qkv_dim)));
}
}
}
}
}
const std::vector<float> AttentionMultipleTokensAttentionGoldens = {
34.7414, 34.7717, 34.8022, 34.8327, 34.8631, 34.8936, 34.9241, 34.9545,
34.985, 35.0156, 35.046, 35.0765, 35.1068, 35.1373, 35.1678, 35.1982,
35.2286, 35.2592, 35.2895, 35.32, 35.3506, 35.381, 35.4115, 35.4421,
35.4725, 35.503, 35.5334, 35.5638, 35.5943, 35.6247, 35.6552, 35.6857,
35.7161, 35.7466, 35.7772, 35.8076, 35.8381, 35.8685, 35.8989, 35.9294,
35.9598, 35.9902, 36.0208, 36.0512, 36.0816, 36.1122, 36.1426, 36.1731,
36.2037, 36.2341, 36.2646, 36.295, 36.3254, 36.356, 36.3863, 36.4168,
36.4474, 36.4778, 36.5082, 36.5388, 36.5692, 36.5997, 36.6301, 36.6605,
34.6687, 34.6987, 34.7288, 34.759, 34.7891, 34.8192, 34.8495, 34.8795,
34.9097, 34.9399, 34.97, 35.0002, 35.0302, 35.0604, 35.0906, 35.1206,
35.1507, 35.181, 35.211, 35.2412, 35.2714, 35.3015, 35.3317, 35.3619,
35.3921, 35.4222, 35.4523, 35.4824, 35.5126, 35.5427, 35.5728, 35.603,
35.6331, 35.6633, 35.6935, 35.7236, 35.7538, 35.7838, 35.814, 35.8442,
35.8742, 35.9043, 35.9346, 35.9646, 35.9948, 36.025, 36.0551, 36.0853,
36.1155, 36.1456, 36.1759, 36.2059, 36.236, 36.2662, 36.2963, 36.3264,
36.3566, 36.3867, 36.4169, 36.4471, 36.4772, 36.5074, 36.5374, 36.5676,
37.0338, 37.0634, 37.0929, 37.1222, 37.1519, 37.1813, 37.2107, 37.2403,
37.2698, 37.2992, 37.3288, 37.3584, 37.3877, 37.4174, 37.447, 37.4764,
37.5056, 37.5352, 37.5646, 37.5938, 37.6234, 37.6528, 37.6821, 37.7117,
37.7412, 37.7705, 37.8001, 37.8295, 37.8589, 37.8885, 37.918, 37.9473,
37.977, 38.0065, 38.0358, 38.0655, 38.095, 38.1244, 38.1541, 38.1836,
38.213, 38.2422, 38.2718, 38.3012, 38.3305, 38.36, 38.3895, 38.4187,
38.4484, 38.4778, 38.5071, 38.5367, 38.5662, 38.5955, 38.6251, 38.6546,
38.6839, 38.7136, 38.7431, 38.7725, 38.8021, 38.8316, 38.861, 38.8907,
36.9872, 37.0167, 37.046, 37.0752, 37.1047, 37.1341, 37.1633, 37.1928,
37.2222, 37.2514, 37.2809, 37.3103, 37.3396, 37.3691, 37.3985, 37.4278,
37.4569, 37.4863, 37.5156, 37.5447, 37.5742, 37.6035, 37.6326, 37.6621,
37.6914, 37.7206, 37.7501, 37.7794, 37.8086, 37.8381, 37.8674, 37.8966,
37.9262, 37.9555, 37.9848, 38.0143, 38.0437, 38.0729, 38.1025, 38.1319,
38.1612, 38.1903, 38.2197, 38.249, 38.2781, 38.3075, 38.3368, 38.366,
38.3955, 38.4248, 38.4539, 38.4834, 38.5127, 38.5419, 38.5714, 38.6008,
38.63, 38.6595, 38.6889, 38.7181, 38.7477, 38.777, 38.8063, 38.8358,
39.0984, 39.1479, 39.1976, 39.2475, 39.297, 39.3468, 39.3967, 39.4463,
39.4961, 39.546, 39.5957, 39.6455, 39.695, 39.7447, 39.7946, 39.8441,
39.8939, 39.9438, 39.9934, 40.0431, 40.0931, 40.1427, 40.1925, 40.2425,
40.2921, 40.342, 40.3915, 40.4412, 40.4911, 40.5407, 40.5904, 40.6403,
40.6899, 40.7397, 40.7897, 40.8393, 40.8892, 40.9387, 40.9884, 41.0382,
41.0878, 41.1375, 41.1874, 41.237, 41.2868, 41.3367, 41.3863, 41.4361,
41.4861, 41.5358, 41.5856, 41.6351, 41.6849, 41.7347, 41.7843, 41.834,
41.884, 41.9336, 41.9834, 42.0333, 42.083, 42.1328, 42.1823, 42.232,
38.9699, 39.0188, 39.068, 39.1173, 39.1663, 39.2155, 39.2648, 39.3138,
39.3631, 39.4124, 39.4615, 39.5108, 39.5597, 39.6089, 39.6581, 39.7071,
39.7563, 39.8056, 39.8546, 39.9039, 39.9532, 40.0023, 40.0515, 40.1009,
40.15, 40.1993, 40.2483, 40.2974, 40.3467, 40.3957, 40.4449, 40.4942,
40.5433, 40.5925, 40.6419, 40.691, 40.7402, 40.7892, 40.8383, 40.8876,
40.9366, 40.9857, 41.035, 41.0841, 41.1333, 41.1826, 41.2317, 41.2809,
41.3303, 41.3794, 41.4287, 41.4777, 41.5268, 41.5761, 41.6251, 41.6743,
41.7237, 41.7727, 41.8219, 41.8713, 41.9204, 41.9697, 42.0186, 42.0677,
43.4945, 43.5425, 43.5902, 43.6376, 43.6856, 43.7334, 43.7808, 43.8289,
43.8766, 43.9241, 43.9722, 44.02, 44.0675, 44.1157, 44.1635, 44.2111,
44.2583, 44.3062, 44.3538, 44.4011, 44.449, 44.4966, 44.544, 44.5919,
44.6396, 44.6869, 44.735, 44.7826, 44.8301, 44.8781, 44.9258, 44.9733,
45.0213, 45.0691, 45.1166, 45.1647, 45.2125, 45.26, 45.3081, 45.356,
45.4035, 45.4508, 45.4987, 45.5462, 45.5936, 45.6415, 45.6891, 45.7364,
45.7844, 45.832, 45.8794, 45.9274, 45.9751, 46.0225, 46.0705, 46.1183,
46.1657, 46.2138, 46.2615, 46.309, 46.3571, 46.4049, 46.4525, 46.5006,
43.4125, 43.4603, 43.5077, 43.5549, 43.6027, 43.6502, 43.6974, 43.7453,
43.7928, 43.84, 43.8879, 43.9355, 43.9828, 44.0307, 44.0783, 44.1256,
44.1726, 44.2203, 44.2676, 44.3147, 44.3624, 44.4098, 44.4569, 44.5046,
44.552, 44.5992, 44.6469, 44.6944, 44.7416, 44.7894, 44.8369, 44.8841,
44.9319, 44.9795, 45.0267, 45.0746, 45.1222, 45.1694, 45.2173, 45.265,
45.3123, 45.3593, 45.407, 45.4543, 45.5014, 45.5491, 45.5965, 45.6436,
45.6913, 45.7387, 45.7859, 45.8336, 45.8811, 45.9283, 45.9761, 46.0236,
46.0708, 46.1186, 46.1661, 46.2134, 46.2613, 46.3088, 46.3561, 46.404,
34.7729, 34.8035, 34.8341, 34.8648, 34.8953, 34.9259, 34.9567, 34.9872,
35.0179, 35.0486, 35.0792, 35.1098, 35.1404, 35.171, 35.2016, 35.2322,
35.2628, 35.2935, 35.324, 35.3547, 35.3854, 35.416, 35.4466, 35.4774,
35.508, 35.5387, 35.5692, 35.5998, 35.6305, 35.661, 35.6916, 35.7224,
35.7529, 35.7836, 35.8143, 35.8449, 35.8755, 35.9061, 35.9367, 35.9674,
35.9979, 36.0285, 36.0592, 36.0898, 36.1204, 36.1511, 36.1817, 36.2123,
36.2431, 36.2737, 36.3044, 36.3349, 36.3655, 36.3962, 36.4267, 36.4574,
36.4881, 36.5186, 36.5493, 36.58, 36.6106, 36.6413, 36.6718, 36.7024,
34.6995, 34.7297, 34.76, 34.7904, 34.8206, 34.8509, 34.8813, 34.9115,
34.9418, 34.9722, 35.0025, 35.0328, 35.063, 35.0933, 35.1237, 35.1539,
35.1842, 35.2146, 35.2448, 35.2751, 35.3055, 35.3357, 35.3661, 35.3965,
35.4268, 35.4571, 35.4873, 35.5176, 35.548, 35.5782, 35.6085, 35.6389,
35.6691, 35.6994, 35.7298, 35.7601, 35.7904, 35.8206, 35.8509, 35.8813,
35.9115, 35.9418, 35.9721, 36.0024, 36.0327, 36.0631, 36.0933, 36.1237,
36.1541, 36.1843, 36.2147, 36.2449, 36.2752, 36.3056, 36.3358, 36.3661,
36.3965, 36.4267, 36.457, 36.4874, 36.5177, 36.548, 36.5782, 36.6085,
37.0829, 37.1127, 37.1423, 37.1717, 37.2015, 37.2312, 37.2607, 37.2905,
37.3201, 37.3496, 37.3795, 37.4091, 37.4386, 37.4685, 37.4982, 37.5277,
37.5571, 37.5868, 37.6164, 37.6458, 37.6755, 37.7051, 37.7346, 37.7643,
37.7939, 37.8234, 37.8531, 37.8827, 37.9122, 37.942, 37.9716, 38.0011,
38.0309, 38.0606, 38.0901, 38.1199, 38.1496, 38.1791, 38.209, 38.2387,
38.2682, 38.2976, 38.3273, 38.3569, 38.3863, 38.416, 38.4456, 38.475,
38.5048, 38.5344, 38.5638, 38.5936, 38.6232, 38.6527, 38.6825, 38.7121,
38.7416, 38.7714, 38.8011, 38.8306, 38.8604, 38.8901, 38.9196, 38.9494,
37.0359, 37.0655, 37.095, 37.1243, 37.154, 37.1835, 37.2129, 37.2425,
37.2721, 37.3014, 37.3311, 37.3607, 37.39, 37.4198, 37.4493, 37.4787,
37.508, 37.5376, 37.567, 37.5963, 37.6259, 37.6553, 37.6846, 37.7142,
37.7437, 37.773, 37.8027, 37.8322, 37.8615, 37.8911, 37.9207, 37.95,
37.9797, 38.0092, 38.0386, 38.0683, 38.0978, 38.1272, 38.1569, 38.1865,
38.2159, 38.2451, 38.2747, 38.3042, 38.3334, 38.363, 38.3925, 38.4218,
38.4514, 38.4809, 38.5102, 38.5398, 38.5693, 38.5986, 38.6283, 38.6578,
38.6872, 38.7168, 38.7464, 38.7757, 38.8054, 38.835, 38.8644, 38.8941,
39.1594, 39.2093, 39.2593, 39.3095, 39.3594, 39.4094, 39.4597, 39.5096,
39.5597, 39.61, 39.6599, 39.7101, 39.7599, 39.8099, 39.8601, 39.91,
39.96, 40.0102, 40.0601, 40.1102, 40.1605, 40.2104, 40.2605, 40.3108,
40.3608, 40.411, 40.4608, 40.5108, 40.561, 40.6109, 40.661, 40.7112,
40.7611, 40.8112, 40.8615, 40.9115, 40.9616, 41.0114, 41.0614, 41.1116,
41.1615, 41.2115, 41.2617, 41.3116, 41.3617, 41.412, 41.4619, 41.512,
41.5624, 41.6123, 41.6625, 41.7123, 41.7623, 41.8126, 41.8624, 41.9125,
41.9627, 42.0127, 42.0628, 42.113, 42.163, 42.2131, 42.263, 42.313,
39.0297, 39.079, 39.1284, 39.1781, 39.2274, 39.2769, 39.3265, 39.3759,
39.4254, 39.4751, 39.5245, 39.5741, 39.6233, 39.6727, 39.7224, 39.7716,
39.8211, 39.8708, 39.9201, 39.9696, 40.0193, 40.0686, 40.1182, 40.1679,
40.2173, 40.2669, 40.3162, 40.3656, 40.4153, 40.4646, 40.514, 40.5637,
40.6131, 40.6626, 40.7123, 40.7617, 40.8112, 40.8605, 40.9099, 40.9595,
41.0088, 41.0583, 41.1079, 41.1573, 41.2068, 41.2565, 41.3058, 41.3554,
41.4051, 41.4545, 41.5041, 41.5534, 41.6028, 41.6524, 41.7017, 41.7512,
41.8009, 41.8502, 41.8998, 41.9495, 41.9988, 42.0484, 42.0977, 42.1471,
43.5891, 43.6374, 43.6854, 43.7331, 43.7814, 43.8294, 43.8772, 43.9255,
43.9736, 44.0214, 44.0698, 44.1179, 44.1657, 44.2141, 44.2623, 44.3101,
44.3577, 44.4058, 44.4537, 44.5013, 44.5495, 44.5974, 44.6451, 44.6933,
44.7413, 44.7889, 44.8372, 44.8852, 44.9329, 44.9812, 45.0293, 45.077,
45.1254, 45.1734, 45.2212, 45.2696, 45.3177, 45.3655, 45.414, 45.4621,
45.5099, 45.5575, 45.6057, 45.6535, 45.7011, 45.7493, 45.7973, 45.8449,
45.8931, 45.9411, 45.9888, 46.037, 46.085, 46.1327, 46.1811, 46.2291,
46.2768, 46.3252, 46.3733, 46.421, 46.4694, 46.5175, 46.5653, 46.6138,
43.5064, 43.5544, 43.6022, 43.6497, 43.6978, 43.7456, 43.7931, 43.8412,
43.889, 43.9366, 43.9847, 44.0326, 44.0802, 44.1284, 44.1763, 44.2239,
44.2712, 44.3191, 44.3668, 44.4141, 44.4621, 44.5098, 44.5572, 44.6052,
44.6529, 44.7004, 44.7484, 44.7962, 44.8436, 44.8918, 44.9395, 44.987,
45.0352, 45.083, 45.1305, 45.1787, 45.2266, 45.2742, 45.3223, 45.3703,
45.4179, 45.4652, 45.5131, 45.5608, 45.6081, 45.6561, 45.7038, 45.7512,
45.7992, 45.8469, 45.8944, 45.9424, 45.9902, 46.0376, 46.0857, 46.1335,
46.181, 46.2292, 46.277, 46.3245, 46.3727, 46.4206, 46.4682, 46.5164,
};
void TestAttentionMultipleTokens() {
int qkv_dim = 64;
int kv_seq_len = 64;
int num_kv_heads = 2;
int num_heads = 4;
int num_tokens = 2;
int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1
// will have 64 tokens to attend to.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
num_heads, num_tokens, last_pos, att_cap, layer_idx,
layers_total, qbatch_size, attention_impl);
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
FillMatPtrT(test_env.activations->attention.softmax_d);
int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch,
test_env.env, flags);
std::cerr << "att_out\n";
PrintMatPtr(test_env.activations->attention.att_out);
for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
EXPECT_TRUE(hwy::CompareArraySimilar(
AttentionMultipleTokensAttentionGoldens.data() +
i * test_env.activations->attention.att_out.Cols(),
test_env.activations->attention.att_out.Row(i),
test_env.activations->attention.att_out.Cols(), 1e-3,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_out mismatch for query: " << i;
}
}
void TestAttentionMultipleTokensAttentionWindowSizeEdgeCase() {
int qkv_dim = 64;
int kv_seq_len = 34;
int num_kv_heads = 2;
int num_heads = 4;
int num_tokens = 2;
int last_pos = 31; // so in the tbatch token 0 will have 32 and token 1
// will have 33 tokens to attend to, exercising the
// attention_window_size = 32 clipping edge case.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
int attention_window_size = 32;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQs;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, attention_window_size,
num_kv_heads, num_heads, num_tokens, last_pos,
att_cap, layer_idx, layers_total, qbatch_size,
attention_impl);
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
FillMatPtrT(test_env.activations->attention.softmax_d);
int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch,
test_env.env, flags);
std::cerr << "att_out\n";
std::vector<float> att_out_golden_test_local = {
39.3051, 39.3556, 39.4062, 39.4571, 39.5075, 39.5582, 39.6091, 39.6596,
39.7103, 39.7612, 39.8118, 39.8626, 39.913, 39.9636, 40.0144, 40.0649,
40.1155, 40.1664, 40.2169, 40.2676, 40.3185, 40.369, 40.4198, 40.4707,
40.5213, 40.572, 40.6225, 40.6731, 40.724, 40.7744, 40.8251, 40.876,
40.9265, 40.9772, 41.0281, 41.0787, 41.1295, 41.1799, 41.2305, 41.2813,
41.3318, 41.3824, 41.4333, 41.4838, 41.5345, 41.5854, 41.6359, 41.6867,
41.7376, 41.7882, 41.839, 41.8894, 41.94, 41.9908, 42.0413, 42.092,
42.1429, 42.1934, 42.2441, 42.295, 42.3456, 42.3964, 42.4468, 42.4974,
39.1614, 39.2113, 39.2613, 39.3114, 39.3613, 39.4113, 39.4616, 39.5115,
39.5616, 39.6118, 39.6618, 39.7119, 39.7617, 39.8117, 39.8618, 39.9117,
39.9617, 40.0119, 40.0618, 40.1118, 40.1621, 40.212, 40.2621, 40.3124,
40.3623, 40.4125, 40.4623, 40.5123, 40.5625, 40.6123, 40.6624, 40.7126,
40.7625, 40.8126, 40.8629, 40.9128, 40.9629, 41.0127, 41.0627, 41.1129,
41.1627, 41.2127, 41.2629, 41.3128, 41.3629, 41.4131, 41.463, 41.5131,
41.5634, 41.6134, 41.6635, 41.7133, 41.7634, 41.8135, 41.8634, 41.9134,
41.9637, 42.0135, 42.0636, 42.1139, 42.1638, 42.214, 42.2637, 42.3137,
43.8459, 43.895, 43.9437, 43.9921, 44.0411, 44.0898, 44.1383, 44.1874,
44.2361, 44.2846, 44.3337, 44.3825, 44.4311, 44.4802, 44.529, 44.5776,
44.6258, 44.6747, 44.7233, 44.7716, 44.8205, 44.8692, 44.9175, 44.9665,
45.0151, 45.0635, 45.1125, 45.1612, 45.2096, 45.2586, 45.3074, 45.3558,
45.4049, 45.4537, 45.5021, 45.5513, 45.6001, 45.6486, 45.6977, 45.7466,
45.7951, 45.8434, 45.8923, 45.9409, 45.9891, 46.0381, 46.0867, 46.135,
46.184, 46.2327, 46.281, 46.33, 46.3787, 46.4271, 46.4762, 46.5249,
46.5733, 46.6224, 46.6712, 46.7197, 46.7688, 46.8176, 46.8661, 46.9153,
43.7538, 43.8026, 43.851, 43.8992, 43.948, 43.9964, 44.0446, 44.0934,
44.142, 44.1902, 44.239, 44.2876, 44.3358, 44.3847, 44.4333, 44.4816,
44.5296, 44.5782, 44.6266, 44.6746, 44.7232, 44.7716, 44.8197, 44.8684,
44.9168, 44.9649, 45.0136, 45.0621, 45.1102, 45.159, 45.2075, 45.2557,
45.3045, 45.353, 45.4012, 45.4501, 45.4986, 45.5469, 45.5958, 45.6444,
45.6927, 45.7406, 45.7893, 45.8376, 45.8856, 45.9343, 45.9827, 46.0307,
46.0794, 46.1278, 46.1759, 46.2247, 46.2731, 46.3213, 46.3701, 46.4185,
46.4667, 46.5155, 46.564, 46.6123, 46.6611, 46.7097, 46.7579, 46.8068,
48.7531, 48.8438, 48.9348, 49.0262, 49.1169, 49.208, 49.2995, 49.3903,
49.4815, 49.573, 49.6639, 49.7552, 49.8458, 49.9368, 50.0281, 50.1188,
50.2099, 50.3013, 50.3921, 50.4832, 50.5747, 50.6656, 50.7568, 50.8484,
50.9393, 51.0306, 51.1213, 51.2123, 51.3037, 51.3944, 51.4855, 51.577,
51.6678, 51.759, 51.8505, 51.9414, 52.0327, 52.1233, 52.2143, 52.3056,
52.3963, 52.4874, 52.5788, 52.6696, 52.7607, 52.8522, 52.9431, 53.0343,
53.1259, 53.2168, 53.3081, 53.3988, 53.4898, 53.5812, 53.6719, 53.763,
53.8545, 53.9453, 54.0365, 54.128, 54.2189, 54.3102, 54.4008, 54.4918,
48.4943, 48.5838, 48.6737, 48.7639, 48.8535, 48.9435, 49.0338, 49.1235,
49.2135, 49.3039, 49.3937, 49.4838, 49.5732, 49.6631, 49.7533, 49.8428,
49.9328, 50.023, 50.1127, 50.2027, 50.293, 50.3827, 50.4728, 50.5632,
50.653, 50.7432, 50.8327, 50.9226, 51.0128, 51.1024, 51.1924, 51.2827,
51.3724, 51.4624, 51.5528, 51.6425, 51.7327, 51.8221, 51.912, 52.0022,
52.0917, 52.1817, 52.2719, 52.3616, 52.4516, 52.5419, 52.6316, 52.7217,
52.8121, 52.9019, 52.9921, 53.0816, 53.1715, 53.2617, 53.3513, 53.4413,
53.5316, 53.6212, 53.7113, 53.8017, 53.8914, 53.9815, 54.071, 54.1609,
57.7208, 57.8084, 57.8954, 57.9818, 58.0694, 58.1564, 58.2429, 58.3306,
58.4177, 58.5043, 58.5921, 58.6793, 58.7659, 58.8537, 58.941, 59.0277,
59.1137, 59.2011, 59.2878, 59.374, 59.4614, 59.5482, 59.6345, 59.722,
59.8089, 59.8952, 59.9827, 60.0697, 60.1561, 60.2437, 60.3308, 60.4172,
60.505, 60.5921, 60.6786, 60.7664, 60.8536, 60.9402, 61.0281, 61.1153,
61.202, 61.2881, 61.3755, 61.4622, 61.5483, 61.6358, 61.7226, 61.8088,
61.8963, 61.9832, 62.0695, 62.1571, 62.244, 62.3304, 62.4181, 62.5051,
62.5916, 62.6793, 62.7664, 62.853, 62.9407, 63.0279, 63.1146, 63.2024,
57.5554, 57.6426, 57.729, 57.815, 57.9021, 57.9887, 58.0747, 58.162,
58.2486, 58.3347, 58.422, 58.5087, 58.5949, 58.6823, 58.7691, 58.8553,
58.9409, 59.0278, 59.114, 59.1997, 59.2867, 59.373, 59.4588, 59.5458,
59.6323, 59.7181, 59.8052, 59.8917, 59.9776, 60.0648, 60.1514, 60.2374,
60.3246, 60.4113, 60.4974, 60.5847, 60.6714, 60.7576, 60.8449, 60.9317,
61.018, 61.1036, 61.1905, 61.2767, 61.3624, 61.4494, 61.5357, 61.6215,
61.7085, 61.7949, 61.8808, 61.9679, 62.0544, 62.1403, 62.2275, 62.3141,
62.4001, 62.4873, 62.574, 62.66, 62.7474, 62.8341, 62.9202, 63.0076,
39.3678, 39.4186, 39.4696, 39.5207, 39.5715, 39.6225, 39.6737, 39.7246,
39.7756, 39.8268, 39.8777, 39.9288, 39.9796, 40.0305, 40.0816, 40.1324,
40.1834, 40.2346, 40.2854, 40.3364, 40.3876, 40.4385, 40.4896, 40.5408,
40.5917, 40.6428, 40.6936, 40.7446, 40.7957, 40.8466, 40.8975, 40.9487,
40.9996, 41.0506, 41.1019, 41.1528, 41.2038, 41.2546, 41.3055, 41.3567,
41.4075, 41.4584, 41.5096, 41.5605, 41.6115, 41.6627, 41.7136, 41.7646,
41.8159, 41.8668, 41.9179, 41.9687, 42.0196, 42.0708, 42.1216, 42.1726,
42.2238, 42.2746, 42.3256, 42.3769, 42.4278, 42.4789, 42.5296, 42.5806,
39.2228, 39.2729, 39.3232, 39.3737, 39.4239, 39.4743, 39.5248, 39.575,
39.6254, 39.676, 39.7263, 39.7767, 39.8268, 39.8771, 39.9276, 39.9778,
40.0281, 40.0786, 40.1288, 40.1792, 40.2298, 40.28, 40.3304, 40.381,
40.4313, 40.4818, 40.5319, 40.5822, 40.6327, 40.6829, 40.7333, 40.7838,
40.834, 40.8844, 40.935, 40.9853, 41.0357, 41.0858, 41.1361, 41.1866,
41.2368, 41.2871, 41.3376, 41.3878, 41.4382, 41.4888, 41.539, 41.5894,
41.64, 41.6903, 41.7408, 41.7909, 41.8412, 41.8917, 41.9419, 41.9922,
42.0428, 42.093, 42.1434, 42.194, 42.2442, 42.2947, 42.3448, 42.3951,
43.9435, 43.9928, 44.0418, 44.0905, 44.1399, 44.1889, 44.2376, 44.287,
44.3361, 44.3849, 44.4343, 44.4834, 44.5322, 44.5817, 44.6308, 44.6797,
44.7283, 44.7774, 44.8263, 44.8749, 44.9241, 44.9731, 45.0217, 45.071,
45.12, 45.1686, 45.2179, 45.2669, 45.3156, 45.365, 45.414, 45.4628,
45.5122, 45.5613, 45.61, 45.6595, 45.7086, 45.7574, 45.8068, 45.856,
45.9048, 45.9534, 46.0026, 46.0515, 46.1001, 46.1493, 46.1982, 46.2469,
46.2961, 46.3451, 46.3938, 46.4431, 46.4921, 46.5408, 46.5901, 46.6392,
46.6879, 46.7373, 46.7864, 46.8352, 46.8846, 46.9337, 46.9825, 47.032,
43.8506, 43.8996, 43.9484, 43.9968, 44.0459, 44.0947, 44.1432, 44.1923,
44.2411, 44.2896, 44.3388, 44.3876, 44.4362, 44.4854, 44.5343, 44.5829,
44.6312, 44.6801, 44.7287, 44.7771, 44.826, 44.8747, 44.9231, 44.9721,
45.0208, 45.0692, 45.1182, 45.167, 45.2154, 45.2645, 45.3133, 45.3617,
45.4109, 45.4597, 45.5082, 45.5574, 45.6062, 45.6548, 45.704, 45.7529,
45.8015, 45.8498, 45.8987, 45.9473, 45.9957, 46.0446, 46.0933, 46.1416,
46.1906, 46.2394, 46.2878, 46.3368, 46.3856, 46.434, 46.4831, 46.5319,
46.5803, 46.6295, 46.6783, 46.7268, 46.776, 46.8248, 46.8734, 46.9226,
48.8777, 48.969, 49.0607, 49.1527, 49.2441, 49.3358, 49.4279, 49.5194,
49.6112, 49.7034, 49.7949, 49.8868, 49.9781, 50.0697, 50.1617, 50.2531,
50.3448, 50.4368, 50.5283, 50.62, 50.7122, 50.8037, 50.8956, 50.9878,
51.0794, 51.1713, 51.2626, 51.3543, 51.4463, 51.5377, 51.6294, 51.7215,
51.813, 51.9048, 51.997, 52.0885, 52.1805, 52.2717, 52.3633, 52.4553,
52.5467, 52.6384, 52.7305, 52.8219, 52.9137, 53.0058, 53.0973, 53.1892,
53.2814, 53.373, 53.4649, 53.5562, 53.6479, 53.7399, 53.8313, 53.923,
54.0152, 54.1066, 54.1984, 54.2906, 54.3821, 54.4741, 54.5653, 54.6569,
48.6164, 48.7066, 48.7971, 48.888, 48.9782, 49.0688, 49.1597, 49.25,
49.3407, 49.4317, 49.5221, 49.6129, 49.703, 49.7934, 49.8843, 49.9745,
50.065, 50.1559, 50.2462, 50.3368, 50.4278, 50.5181, 50.6089, 50.6999,
50.7903, 50.8811, 50.9713, 51.0618, 51.1527, 51.2429, 51.3335, 51.4244,
51.5147, 51.6054, 51.6964, 51.7868, 51.8776, 51.9677, 52.0581, 52.149,
52.2392, 52.3297, 52.4206, 52.5109, 52.6015, 52.6925, 52.7828, 52.8736,
52.9646, 53.055, 53.1458, 53.236, 53.3265, 53.4174, 53.5076, 53.5982,
53.6891, 53.7794, 53.8701, 53.9611, 54.0515, 54.1423, 54.2324, 54.3228,
57.914, 58.0021, 58.0897, 58.1767, 58.265, 58.3526, 58.4397, 58.528,
58.6157, 58.7028, 58.7912, 58.879, 58.9662, 59.0547, 59.1426, 59.2299,
59.3165, 59.4045, 59.4918, 59.5786, 59.6666, 59.754, 59.8408, 59.9289,
60.0165, 60.1033, 60.1915, 60.2791, 60.3661, 60.4544, 60.542, 60.629,
60.7174, 60.8051, 60.8922, 60.9806, 61.0684, 61.1556, 61.2441, 61.332,
61.4193, 61.5059, 61.5939, 61.6812, 61.768, 61.856, 61.9434, 62.0302,
62.1183, 62.2059, 62.2927, 62.3809, 62.4685, 62.5555, 62.6437, 62.7314,
62.8184, 62.9068, 62.9945, 63.0816, 63.17, 63.2578, 63.345, 63.4335,
57.7471, 57.8348, 57.9219, 58.0084, 58.0962, 58.1834, 58.27, 58.3578,
58.4451, 58.5317, 58.6197, 58.707, 58.7937, 58.8817, 58.9691, 59.0559,
59.1421, 59.2296, 59.3165, 59.4028, 59.4903, 59.5773, 59.6636, 59.7512,
59.8383, 59.9247, 60.0124, 60.0995, 60.186, 60.2738, 60.361, 60.4476,
60.5354, 60.6227, 60.7093, 60.7973, 60.8846, 60.9713, 61.0593, 61.1467,
61.2335, 61.3197, 61.4072, 61.4941, 61.5804, 61.6679, 61.7549, 61.8412,
61.9289, 62.0159, 62.1023, 62.19, 62.2772, 62.3636, 62.4514, 62.5386,
62.6252, 62.7131, 62.8003, 62.887, 62.9749, 63.0622, 63.1489, 63.237};
PrintMatPtr(test_env.activations->attention.att_out);
for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
EXPECT_TRUE(hwy::CompareArraySimilar(
att_out_golden_test_local.data() +
i * test_env.activations->attention.att_out.Cols(),
test_env.activations->attention.att_out.Row(i),
test_env.activations->attention.att_out.Cols(), 1e-3,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_out mismatch for query: " << i;
}
}
void TestAttentionMultipleTokensBF16() {
int qkv_dim = 64;
int kv_seq_len = 64;
int num_kv_heads = 2;
int num_heads = 4;
int num_tokens = 2;
int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1
// will have 64 tokens to attend to.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
num_heads, num_tokens, last_pos, att_cap, layer_idx,
layers_total, qbatch_size, attention_impl);
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
FillMatPtrT(test_env.activations->attention.softmax_d);
int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch,
test_env.env, flags);
std::cerr << "att_out\n";
PrintMatPtr(test_env.activations->attention.att_out);
for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
EXPECT_TRUE(hwy::CompareArraySimilar(
AttentionMultipleTokensAttentionGoldens.data() +
i * test_env.activations->attention.att_out.Cols(),
test_env.activations->attention.att_out.Row(i),
test_env.activations->attention.att_out.Cols(), 1e-1,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_out mismatch for query: " << i;
}
}
void TestAttentionMultipleTokensInt8() {
int qkv_dim = 64;
int kv_seq_len = 64;
int num_kv_heads = 2;
int num_heads = 4;
int num_tokens = 2;
int last_pos = 62; // so in the tbatch token 0 will have 63 and token 1
// will have 64 tokens to attend to.
float att_cap = 10.0f;
int layer_idx = 0;
int layers_total = 1;
int qbatch_size = 2;
AttentionImpl attention_impl = AttentionImpl::kFlashTransposedQsBF16;
AttentionTestEnv test_env(qkv_dim, kv_seq_len, kv_seq_len, num_kv_heads,
num_heads, num_tokens, last_pos, att_cap, layer_idx,
layers_total, qbatch_size, attention_impl,
Type::kInt8);
test_env.SetupWeights();
FillMatPtrT(test_env.activations->attention.pre_att_rms_out);
FillMatPtrT(test_env.activations->attention.q);
FillMatPtrT(test_env.activations->attention.att);
FillMatPtrT(test_env.activations->attention.att_out);
FillMatPtrT(test_env.activations->attention.softmax_max);
FillMatPtrT(test_env.activations->attention.softmax_d);
int flags = AttentionImplToFlags(attention_impl, HWY_NATIVE_DOT_BF16);
TiledAttention(attention_impl, num_tokens, layer_idx, *test_env.layer,
test_env.activations->attention, *test_env.qbatch,
test_env.env, flags);
std::cerr << "att_out\n";
PrintMatPtr(test_env.activations->attention.att_out);
for (size_t i = 0; i < test_env.activations->attention.att_out.Rows(); ++i) {
EXPECT_TRUE(hwy::CompareArraySimilar(
AttentionMultipleTokensAttentionGoldens.data() +
i * test_env.activations->attention.att_out.Cols(),
test_env.activations->attention.att_out.Row(i),
test_env.activations->attention.att_out.Cols(), 1e-1,
hwy::TargetName(HWY_TARGET), __FILE__, __LINE__))
<< "att_out mismatch for query: " << i;
}
}
bool transposed =
attention_impl == AttentionImpl::kFlashTransposedQsBF16
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE