mirror of https://github.com/google/gemma.cpp.git
Merge pull request #175 from szabadka:gemma2
PiperOrigin-RevId: 630044058
commit bafb8382f8
@@ -741,22 +741,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
         MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
       }
-      // linear projection from kQKVDim back to kModelDim, sum projections
-      // across heads
-      float* HWY_RESTRICT head_out =
-          head == 0
-              ? activations.att_post2.data() + batch_idx * kModelDim
-              : activations.att_post1.data() + head * kBatchSize * kModelDim;
-      float* even_odd = activations.even_odd.data() + thread * kQKVDim;
-      if (head == 0) {
-        MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
-            layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
-            layer_weights->attention_output_biases.data(), even_odd, head_out);
-      } else {
-        MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
-                                       head * kModelDim * kQKVDim, att_out,
-                                       even_odd, head_out);
-      }
     };

   if constexpr (kHeads == kKVHeads) {
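Note: the block removed above is the old per-head output projection. Each head's attention output (kQKVDim floats) was multiplied by that head's [kModelDim x kQKVDim] slice of attn_vec_einsum_w, with head 0 writing (plus the optional bias) into att_post2 and the other heads into att_post1 for later accumulation. A minimal scalar sketch of that pattern follows; ProjectHeadAndAccumulate is a hypothetical helper, not gemma.cpp API, and the real MatVecAddLoop/MatVecLoop are Highway-vectorized single-threaded loops.

    // Scalar sketch of the removed per-head projection (hypothetical helper).
    #include <cstddef>

    // w holds kHeads row-major [model_dim x q_dim] blocks back to back.
    // Projects one head's attention output to model_dim floats and
    // accumulates into out, so "sum projections across heads" falls out of
    // calling this once per head.
    void ProjectHeadAndAccumulate(const float* w, size_t head,
                                  size_t model_dim, size_t q_dim,
                                  const float* att_out, float* out) {
      const float* w_head = w + head * model_dim * q_dim;  // this head's block
      for (size_t r = 0; r < model_dim; ++r) {
        float dot = 0.0f;
        for (size_t c = 0; c < q_dim; ++c) {
          dot += w_head[r * q_dim + c] * att_out[c];  // one row of W times att_out
        }
        out[r] += dot;  // head 0 is assumed to have initialized out (and bias)
      }
    }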
@@ -810,11 +794,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     });
   }

-  // accumulate output across all heads into att_post2. head 0 already wrote
-  // directly to att_post2.
-  for (size_t head = 1; head < kHeads; ++head) {
-    AddFrom(activations.att_post1.data() + head * kBatchSize * kModelDim,
-            activations.att_post2.data() + batch_idx * kModelDim, kModelDim);
+  // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
+  // rearranging the weights.
+  float* HWY_RESTRICT att_out =
+      activations.att_out.data() + batch_idx * kHeads * kQKVDim;
+  float* HWY_RESTRICT layer_out =
+      activations.att_post2.data() + batch_idx * kModelDim;
+  MatVecAdd<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
+      layer_weights->attn_vec_einsum_w, 0, att_out,
+      layer_weights->attention_output_biases.data(),
+      activations.even_odd.data(), layer_out, pool);
+  for (size_t head = 1; head < kHeads; ++head) {
+    float* HWY_RESTRICT head_out =
+        activations.att_post1.data() + head * kBatchSize * kModelDim;
+    MatVec<kModelDim, kQKVDim>(
+        layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
+        att_out + head * kQKVDim,
+        activations.even_odd.data(), head_out, pool);
+    AddFrom(head_out, layer_out, kModelDim);
   }
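Note: the replacement computes the same projection but distributes it over the worker pool: MatVecAdd handles head 0 and adds attention_output_biases (when TConfig::kSoftmaxAttnOutputBiases is set) exactly once into layer_out, and each remaining head runs a pooled MatVec into att_post1 scratch followed by AddFrom into layer_out. The TODO points at collapsing all of this into one matrix-vector product once the weights are rearranged; a scalar sketch of that end state follows, assuming a hypothetical transposed layout [kModelDim x (kHeads * kQKVDim)], which is not the current weight layout.

    // Scalar sketch of the TODO: with attn_vec_einsum_w rearranged to a
    // single row-major [model_dim x (heads * q_dim)] matrix, the per-head
    // loop above collapses into one biased matrix-vector product over the
    // concatenated per-head attention outputs.
    #include <cstddef>

    void ProjectAllHeads(const float* w, const float* att_out,
                         const float* bias, size_t model_dim, size_t heads,
                         size_t q_dim, float* layer_out) {
      const size_t cols = heads * q_dim;
      for (size_t r = 0; r < model_dim; ++r) {
        float dot = (bias != nullptr) ? bias[r] : 0.0f;  // bias added once
        for (size_t c = 0; c < cols; ++c) {
          dot += w[r * cols + c] * att_out[c];
        }
        layer_out[r] = dot;
      }
    }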