mirror of https://github.com/google/gemma.cpp.git
parent 49d420aeaf
commit 5a500872b8
@@ -46,7 +46,8 @@ static inline float ChooseQueryScale(const ModelConfig& config) {
 struct AttentionActivations {
   AttentionActivations(
       const ModelConfig& config, const LayerConfig& layer_config,
-      size_t batch_size, size_t seq_len, const Allocator& allocator,
+      size_t batch_size, size_t seq_len, AttentionImpl attention_impl,
+      const Allocator& allocator,
       std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>>& row_ptrs)
       : // `vocab_size == 0` means it is for Vit part, VitAttention is still
         // MHA and does not use an external KV cache.
@@ -217,7 +218,8 @@ struct Activations {
 
         attention_impl(runtime_config.attention_impl),
         attention_storage(config, layer_config, batch_size, seq_len,
-                          ctx.allocator, row_ptrs),
+                          runtime_config.attention_impl, ctx.allocator,
+                          row_ptrs),
         attention(config, seq_len, attention_storage) {
     HWY_ASSERT(batch_size != 0);
 
@@ -124,7 +124,8 @@ void TestFlashAttention(size_t target_parallelism) {
   const size_t batch_size = kOuter;
   std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
   AttentionActivations attention_storage(config, layer_config, batch_size,
-                                         kOuter, ctx.allocator, row_ptrs);
+                                         kOuter, AttentionImpl::kFlash,
+                                         ctx.allocator, row_ptrs);
   AttentionActivationsPtrs attention(config, kOuter, attention_storage);
   const size_t qkv_dim = layer_config.qkv_dim;
   ASSERT_EQ(qkv_dim, kInner);
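
Note: per the hunks above, the commit threads the attention implementation selection (`AttentionImpl`) through the `AttentionActivations` constructor, with the new parameter placed between `seq_len` and the allocator. A minimal caller sketch, assuming `config`, `layer_config`, `seq_len`, and `ctx` are already set up as in the surrounding code (any name not shown in the diff is illustrative only):

  // Hypothetical call site after this change; mirrors the test hunk above.
  std::vector<hwy::AlignedFreeUniquePtr<uint8_t*[]>> row_ptrs;
  AttentionActivations storage(config, layer_config, /*batch_size=*/1,
                               seq_len, AttentionImpl::kFlash,
                               ctx.allocator, row_ptrs);
  AttentionActivationsPtrs attention(config, seq_len, storage);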