mirror of https://github.com/google/gemma.cpp.git
Add the missing `HWY_ATTR` of `ProjKV`
This commit is contained in:
parent
ce32f4db81
commit
c75d2eb635
17
gemma.cc
17
gemma.cc
|
|
@ -320,15 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
|
||||
const size_t batch_offset = batch_idx * kModelDim;
|
||||
|
||||
auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) {
|
||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
|
||||
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
||||
activations.pre_att_rms_out.data() + batch_offset,
|
||||
kv_cache.key_cache.get() + kv_offset,
|
||||
kv_cache.value_cache.get() + kv_offset);
|
||||
auto ProjKV =
|
||||
[&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR {
|
||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
|
||||
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
||||
activations.pre_att_rms_out.data() + batch_offset,
|
||||
kv_cache.key_cache.get() + kv_offset,
|
||||
kv_cache.value_cache.get() + kv_offset);
|
||||
|
||||
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
||||
};
|
||||
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
||||
};
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
// linear projections to QKV
|
||||
|
|
|
|||
Loading…
Reference in New Issue