mirror of https://github.com/google/gemma.cpp.git
Streamline the implementation
This commit is contained in:
parent
6923aec853
commit
ce32f4db81
32
gemma.cc
32
gemma.cc
|
|
@ -320,6 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
|
||||
const size_t batch_offset = batch_idx * kModelDim;
|
||||
|
||||
auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) {
|
||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
|
||||
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
||||
activations.pre_att_rms_out.data() + batch_offset,
|
||||
kv_cache.key_cache.get() + kv_offset,
|
||||
kv_cache.value_cache.get() + kv_offset);
|
||||
|
||||
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
||||
};
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
// linear projections to QKV
|
||||
constexpr const size_t head_offset =
|
||||
|
|
@ -339,13 +349,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
const size_t kv_offset =
|
||||
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
|
||||
|
||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
|
||||
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
||||
activations.pre_att_rms_out.data() + batch_offset,
|
||||
kv_cache.key_cache.get() + kv_offset,
|
||||
kv_cache.value_cache.get() + kv_offset);
|
||||
|
||||
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
||||
ProjKV(k_offset, v_offset, kv_offset);
|
||||
}
|
||||
});
|
||||
|
||||
|
|
@ -355,13 +359,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
|
||||
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
|
||||
|
||||
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
|
||||
c_layer->c_qkv_einsum_w, k_offset, v_offset,
|
||||
activations.pre_att_rms_out.data() + batch_offset,
|
||||
kv_cache.key_cache.get() + kv_offset,
|
||||
kv_cache.value_cache.get() + kv_offset);
|
||||
|
||||
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
|
||||
ProjKV(k_offset, v_offset, kv_offset);
|
||||
}
|
||||
|
||||
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
|
||||
|
|
@ -376,7 +374,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
MulByConst(kQueryScale, q, kQKVDim);
|
||||
// Compute Q dot K scores
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_offset = kHeads == kKVHeads
|
||||
const size_t cache_offset =
|
||||
kHeads == kKVHeads
|
||||
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
|
||||
: pos2 * kCachePosSize + layer * kCacheLayerSize;
|
||||
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
|
||||
|
|
@ -390,7 +389,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
batch_idx * kHeads * kQKVDim;
|
||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_offset = kHeads == kKVHeads
|
||||
const size_t cache_offset =
|
||||
kHeads == kKVHeads
|
||||
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
|
||||
: pos2 * kCachePosSize + layer * kCacheLayerSize;
|
||||
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
|
||||
|
|
|
|||
Loading…
Reference in New Issue