Internal change

PiperOrigin-RevId: 835115693
Author: Martin Stolle, 2025-11-21 01:17:06 -08:00 (committed by Copybara-Service)
parent 49d420aeaf
commit 5a500872b8
2 changed files with 6 additions and 3 deletions
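
Despite the generic commit message, the diff is self-describing: it threads a new AttentionImpl parameter through the AttentionActivations constructor and updates the two call sites, the Activations constructor and the flash-attention test.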


@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
         // MHA and does not use an external KV cache.
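
In effect, callers now pick the attention implementation when the activation storage is built. Below is a minimal, self-contained C++ sketch of the new call shape, not the real gemma.cpp code: Allocator is a stand-in, the kOld enumerator is hypothetical, and only the AttentionImpl parameter name, its position before the allocator, and the kFlash enumerator come from this diff.

#include <cstddef>
#include <cstdio>

// Stand-ins for the real gemma.cpp types; only the parameter threading
// pattern mirrors this commit. kOld is a hypothetical enumerator.
enum class AttentionImpl { kOld, kFlash };
struct Allocator {};

struct AttentionActivations {
  // After this commit, attention_impl is passed just before the allocator.
  AttentionActivations(size_t batch_size, size_t seq_len,
                       AttentionImpl attention_impl,
                       const Allocator& allocator)
      : batch_size_(batch_size),
        seq_len_(seq_len),
        attention_impl_(attention_impl) {
    (void)allocator;  // the real class uses it to size its buffers
  }

  size_t batch_size_;
  size_t seq_len_;
  AttentionImpl attention_impl_;
};

int main() {
  Allocator allocator;
  // Callers now choose the attention implementation at construction time.
  AttentionActivations storage(/*batch_size=*/4, /*seq_len=*/128,
                               AttentionImpl::kFlash, allocator);
  std::printf("uses flash: %d\n",
              storage.attention_impl_ == AttentionImpl::kFlash);
  return 0;
}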
@@ -217,7 +218,8 @@ struct Activations {
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          ctx.allocator, row_ptrs),
+                          runtime_config.attention_impl, ctx.allocator,
+                          row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
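
Note that attention_storage receives runtime_config.attention_impl directly rather than the freshly initialized attention_impl member. Since C++ initializes members in declaration order, not initializer-list order, reading the value from runtime_config again is, presumably, the choice that stays safe regardless of how the members happen to be declared.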


@@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, ctx.allocator, row_ptrs);
+                                         kOuter, AttentionImpl::kFlash,
+                                         ctx.allocator, row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);
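
The test pins AttentionImpl::kFlash explicitly, matching the implementation that TestFlashAttention exercises; existing call sites simply gain one argument ahead of the allocator.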