Merge pull request #175 from szabadka:gemma2

PiperOrigin-RevId: 630044058
Copybara-Service 2024-05-02 06:27:15 -07:00
commit bafb8382f8
1 changed file with 18 additions and 21 deletions

@@ -741,22 +741,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
       MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
     }
-    // linear projection from kQKVDim back to kModelDim, sum projections
-    // across heads
-    float* HWY_RESTRICT head_out =
-        head == 0
-            ? activations.att_post2.data() + batch_idx * kModelDim
-            : activations.att_post1.data() + head * kBatchSize * kModelDim;
-    float* even_odd = activations.even_odd.data() + thread * kQKVDim;
-    if (head == 0) {
-      MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
-          layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
-          layer_weights->attention_output_biases.data(), even_odd, head_out);
-    } else {
-      MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
-                                     head * kModelDim * kQKVDim, att_out,
-                                     even_odd, head_out);
-    }
   };
   if constexpr (kHeads == kKVHeads) {
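
For context, the block removed above projected each head's attention output from kQKVDim back to kModelDim inside the per-head lambda: head 0 went through MatVecAddLoop (which also adds the output bias) straight into att_post2, and every other head went through MatVecLoop into its own att_post1 slot. Below is a minimal single-threaded scalar sketch of what these helpers are assumed to compute; MatVecSketch and its signature are simplified stand-ins for illustration, not the library's actual API, which is vectorized and threads an even_odd scratch buffer through the call.

#include <cstddef>

// Hypothetical scalar stand-in for MatVecLoop / MatVecAddLoop:
// multiplies a row-major [kOuter x kInner] slice of w, starting at the flat
// index w_offset, by a kInner-length vector, optionally adding a bias.
template <size_t kOuter, size_t kInner>
void MatVecSketch(const float* w, size_t w_offset, const float* vec,
                  const float* bias_or_null, float* out) {
  for (size_t r = 0; r < kOuter; ++r) {
    float dot = bias_or_null ? bias_or_null[r] : 0.0f;
    for (size_t c = 0; c < kInner; ++c) {
      dot += w[w_offset + r * kInner + c] * vec[c];
    }
    out[r] = dot;
  }
}

The second hunk below moves this projection out of the per-head lambda and, judging by the added pool argument, switches from the single-threaded *Loop variants to pooled MatVecAdd / MatVec calls.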
@@ -810,11 +794,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     });
   }
-  // accumulate output across all heads into att_post2. head 0 already wrote
-  // directly to att_post2.
-  for (size_t head = 1; head < kHeads; ++head) {
-    AddFrom(activations.att_post1.data() + head * kBatchSize * kModelDim,
-            activations.att_post2.data() + batch_idx * kModelDim, kModelDim);
-  }
+  // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
+  // rearranging the weights.
+  float* HWY_RESTRICT att_out =
+      activations.att_out.data() + batch_idx * kHeads * kQKVDim;
+  float* HWY_RESTRICT layer_out =
+      activations.att_post2.data() + batch_idx * kModelDim;
+  MatVecAdd<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
+      layer_weights->attn_vec_einsum_w, 0, att_out,
+      layer_weights->attention_output_biases.data(),
+      activations.even_odd.data(), layer_out, pool);
+  for (size_t head = 1; head < kHeads; ++head) {
+    float* HWY_RESTRICT head_out =
+        activations.att_post1.data() + head * kBatchSize * kModelDim;
+    MatVec<kModelDim, kQKVDim>(
+        layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
+        att_out + head * kQKVDim,
+        activations.even_odd.data(), head_out, pool);
+    AddFrom(head_out, layer_out, kModelDim);
+  }
 }
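
The TODO in the added code points at a further simplification: attn_vec_einsum_w is addressed as kHeads consecutive [kModelDim x kQKVDim] blocks, so summing the per-head projections is mathematically a single [kModelDim x kHeads*kQKVDim] matrix applied to the concatenated head outputs, which att_out already holds contiguously (head h at att_out + h * kQKVDim). What follows is a hedged sketch of the rearrangement the TODO presumably has in mind, assuming uncompressed float weights for simplicity; RearrangeAttnVecEinsumW is a hypothetical helper, not part of the commit or the library.

#include <cstddef>
#include <cstring>

// Interleave the kHeads row-major [kModelDim x kQKVDim] blocks of w into one
// row-major [kModelDim x kHeads * kQKVDim] matrix w_out, so that
//   sum_h W_h * x_h == W' * concat(x_0, ..., x_{kHeads-1}).
// Hypothetical helper; assumes plain float weights, not the compressed form.
template <size_t kHeads, size_t kModelDim, size_t kQKVDim>
void RearrangeAttnVecEinsumW(const float* w, float* w_out) {
  for (size_t head = 0; head < kHeads; ++head) {
    for (size_t row = 0; row < kModelDim; ++row) {
      const float* src = w + (head * kModelDim + row) * kQKVDim;
      float* dst = w_out + row * (kHeads * kQKVDim) + head * kQKVDim;
      std::memcpy(dst, src, kQKVDim * sizeof(float));
    }
  }
}

With the weights in that layout, the MatVecAdd plus per-head MatVec/AddFrom sequence above could collapse into a single MatVecAdd over kModelDim x kHeads*kQKVDim applied directly to att_out, mirroring what the comment says GriffinRecurrent() already does.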