Add the missing `HWY_ATTR` of `ProjKV`

This commit is contained in:
RangerUFO 2024-03-20 23:21:43 +08:00
parent ce32f4db81
commit c75d2eb635
1 changed files with 9 additions and 8 deletions

View File

@ -320,15 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
const size_t batch_offset = batch_idx * kModelDim;
auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) {
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
c_layer->c_qkv_einsum_w, k_offset, v_offset,
activations.pre_att_rms_out.data() + batch_offset,
kv_cache.key_cache.get() + kv_offset,
kv_cache.value_cache.get() + kv_offset);
auto ProjKV =
[&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR {
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
c_layer->c_qkv_einsum_w, k_offset, v_offset,
activations.pre_att_rms_out.data() + batch_offset,
kv_cache.key_cache.get() + kv_offset,
kv_cache.value_cache.get() + kv_offset);
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
};
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
};
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
// linear projections to QKV