mirror of https://github.com/google/gemma.cpp.git
Add the missing `HWY_ATTR` of `ProjKV`
This commit is contained in:
parent
ce32f4db81
commit
c75d2eb635
17
gemma.cc
17
gemma.cc
|
|
@ -320,15 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
|
|
||||||
const size_t batch_offset = batch_idx * kModelDim;
|
const size_t batch_offset = batch_idx * kModelDim;
|
||||||
|
|
||||||
auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) {
|
auto ProjKV =
|
||||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
|
[&](size_t k_offset, size_t v_offset, size_t kv_offset) HWY_ATTR {
|
||||||
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
|
||||||
activations.pre_att_rms_out.data() + batch_offset,
|
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
||||||
kv_cache.key_cache.get() + kv_offset,
|
activations.pre_att_rms_out.data() + batch_offset,
|
||||||
kv_cache.value_cache.get() + kv_offset);
|
kv_cache.key_cache.get() + kv_offset,
|
||||||
|
kv_cache.value_cache.get() + kv_offset);
|
||||||
|
|
||||||
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
||||||
};
|
};
|
||||||
|
|
||||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||||
// linear projections to QKV
|
// linear projections to QKV
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue