mirror of https://github.com/google/gemma.cpp.git
parent
49d420aeaf
commit
5a500872b8
|
|
@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
|
|||
struct AttentionActivations {
|
||||
AttentionActivations(
|
||||
const ModelConfig& config, const LayerConfig& layer_config,
|
||||
size_t batch_size, size_t seq_len, const Allocator& allocator,
|
||||
size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
|
||||
const Allocator& allocator,
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
|
||||
: // `vocab_size == 0` means it is for Vit part, VitAttention is still
|
||||
// MHA and does not use an external KV cache.
|
||||
|
|
@ -217,7 +218,8 @@ struct Activations {
|
|||
|
||||
attention_impl(runtime_config.attention_impl),
|
||||
attention_storage(config, layer_config, batch_size, seq_len,
|
||||
ctx.allocator, row_ptrs),
|
||||
runtime_config.attention_impl, ctx.allocator,
|
||||
row_ptrs),
|
||||
attention(config, seq_len, attention_storage) {
|
||||
HWY_ASSERT(batch_size != 0);
|
||||
|
||||
|
|
|
|||
|
|
@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
|
|||
const size_t batch_size = kOuter;
|
||||
std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
|
||||
AttentionActivations attention_storage(config, layer_config, batch_size,
|
||||
kOuter, ctx.allocator, row_ptrs);
|
||||
kOuter, AttentionImpl::kFlash,
|
||||
ctx.allocator, row_ptrs);
|
||||
AttentionActivationsPtrs attention(config, kOuter, attention_storage);
|
||||
const size_t qkv_dim = layer_config.qkv_dim;
|
||||
ASSERT_EQ(qkv_dim, kInner);
|
||||
|
|
|
|||
Loading…
Reference in New Issue