mirror of https://github.com/google/gemma.cpp.git

commit bafb8382f8
Merge pull request #175 from szabadka:gemma2

PiperOrigin-RevId: 630044058
@@ -741,22 +741,6 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
         float* HWY_RESTRICT v2 = kv_cache.kv_cache.get() + cache_offset + kQKVDim;
         MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
       }
-      // linear projection from kQKVDim back to kModelDim, sum projections
-      // across heads
-      float* HWY_RESTRICT head_out =
-          head == 0
-              ? activations.att_post2.data() + batch_idx * kModelDim
-              : activations.att_post1.data() + head * kBatchSize * kModelDim;
-      float* even_odd = activations.even_odd.data() + thread * kQKVDim;
-      if (head == 0) {
-        MatVecAddLoop<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
-            layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim, att_out,
-            layer_weights->attention_output_biases.data(), even_odd, head_out);
-      } else {
-        MatVecLoop<kModelDim, kQKVDim>(layer_weights->attn_vec_einsum_w,
-                                       head * kModelDim * kQKVDim, att_out,
-                                       even_odd, head_out);
-      }
     };
 
     if constexpr (kHeads == kKVHeads) {
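The loop deleted above performed, per attention head, a linear projection from kQKVDim back to kModelDim, adding the attention output bias for head 0 only. A minimal scalar sketch of that projection math, assuming row-major weight slices (the real MatVecLoop/MatVecAddLoop are Highway SIMD kernels that also take an even_odd scratch buffer; this stand-in is illustrative only):

```cpp
#include <cstddef>

// Illustrative scalar stand-in for MatVecLoop/MatVecAddLoop: out = W * x,
// plus `bias` when non-null (head 0 above). W is a kRows x kCols row-major
// slice of the full weight tensor, starting at element offset `w_ofs`.
template <size_t kRows, size_t kCols>
void MatVecSketch(const float* w, size_t w_ofs, const float* x,
                  const float* bias, float* out) {
  for (size_t r = 0; r < kRows; ++r) {
    float dot = 0.0f;
    for (size_t c = 0; c < kCols; ++c) {
      dot += w[w_ofs + r * kCols + c] * x[c];
    }
    out[r] = bias ? dot + bias[r] : dot;
  }
}
```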
@@ -810,11 +794,24 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       });
     }
 
-    // accumulate output across all heads into att_post2. head 0 already wrote
-    // directly to att_post2.
-    for (size_t head = 1; head < kHeads; ++head) {
-      AddFrom(activations.att_post1.data() + head * kBatchSize * kModelDim,
-              activations.att_post2.data() + batch_idx * kModelDim, kModelDim);
+    // TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
+    // rearranging the weights.
+    float* HWY_RESTRICT att_out =
+        activations.att_out.data() + batch_idx * kHeads * kQKVDim;
+    float* HWY_RESTRICT layer_out =
+        activations.att_post2.data() + batch_idx * kModelDim;
+    MatVecAdd<TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
+        layer_weights->attn_vec_einsum_w, 0, att_out,
+        layer_weights->attention_output_biases.data(),
+        activations.even_odd.data(), layer_out, pool);
+    for (size_t head = 1; head < kHeads; ++head) {
+      float* HWY_RESTRICT head_out =
+          activations.att_post1.data() + head * kBatchSize * kModelDim;
+      MatVec<kModelDim, kQKVDim>(
+          layer_weights->attn_vec_einsum_w, head * kModelDim * kQKVDim,
+          att_out + head * kQKVDim,
+          activations.even_odd.data(), head_out, pool);
+      AddFrom(head_out, layer_out, kModelDim);
     }
   }
 
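The replacement code projects head 0 directly into layer_out with a bias-adding MatVecAdd, then for each remaining head runs MatVec into att_post1 scratch and accumulates into layer_out with AddFrom. A self-contained scalar sketch of that accumulation order (hypothetical names, row-major per-head weight slices assumed; not the Highway implementation):

```cpp
#include <cstddef>

// Scalar sketch of the new flow: head 0 initializes layer_out with
// W_0 * x_0 + bias (MatVecAdd); each later head adds W_h * x_h to it
// (MatVec followed by AddFrom).
template <size_t kModelDim, size_t kQKVDim, size_t kHeads>
void AttnOutProjSketch(const float* w,        // kHeads * kModelDim * kQKVDim
                       const float* att_out,  // kHeads * kQKVDim
                       const float* biases,   // kModelDim
                       float* layer_out) {    // kModelDim
  for (size_t head = 0; head < kHeads; ++head) {
    const float* wh = w + head * kModelDim * kQKVDim;  // per-head weights
    const float* xh = att_out + head * kQKVDim;        // per-head input
    for (size_t r = 0; r < kModelDim; ++r) {
      float dot = (head == 0) ? biases[r] : layer_out[r];
      for (size_t c = 0; c < kQKVDim; ++c) {
        dot += wh[r * kQKVDim + c] * xh[c];
      }
      layer_out[r] = dot;
    }
  }
}
```

Since every head reads a contiguous slice of the same attn_vec_einsum_w tensor and of att_out, the TODO in the diff presumably means folding all heads into one MatVecAdd over a rearranged kModelDim x (kHeads * kQKVDim) matrix, as GriffinRecurrent() already does.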