mirror of https://github.com/google/gemma.cpp.git

commit ce32f4db81 (parent 6923aec853)
Streamline the implementation

 gemma.cc | 32 ++++++++++++++++----------------
 1 file changed, 16 insertions(+), 16 deletions(-)
@@ -320,6 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
 
   const size_t batch_offset = batch_idx * kModelDim;
 
+  auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) {
+    TwoOfsMatVecLoop<kQKVDim, kModelDim>(
+        c_layer->c_qkv_einsum_w, k_offset, v_offset,
+        activations.pre_att_rms_out.data() + batch_offset,
+        kv_cache.key_cache.get() + kv_offset,
+        kv_cache.value_cache.get() + kv_offset);
+
+    Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+  };
+
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
     // linear projections to QKV
     constexpr const size_t head_offset =
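The hunk above is the heart of the change: the K/V projection plus RoPE sequence, previously duplicated in two branches below, is hoisted into a ProjKV lambda that captures the surrounding locals by reference, so each call site only supplies offsets. A minimal standalone sketch of that pattern (ProjectKV, kDim, and cache are illustrative names, not gemma.cpp identifiers):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr size_t kDim = 4;
  std::vector<float> cache(2 * kDim, 0.0f);

  // Captures `cache` by reference, the way ProjKV captures c_layer,
  // activations, and kv_cache; only the offset differs per call site.
  auto ProjectKV = [&](size_t kv_offset) {
    for (size_t i = 0; i < kDim; ++i) cache[kv_offset + i] = float(i + 1);
  };

  ProjectKV(0);     // stands in for the per-head call site (hunk 2)
  ProjectKV(kDim);  // stands in for the shared-head call site (hunk 3)
  printf("%g %g\n", cache[0], cache[kDim]);  // prints: 1 1
  return 0;
}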
@@ -339,13 +349,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
       const size_t kv_offset =
           pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
 
-      TwoOfsMatVecLoop<kQKVDim, kModelDim>(
-          c_layer->c_qkv_einsum_w, k_offset, v_offset,
-          activations.pre_att_rms_out.data() + batch_offset,
-          kv_cache.key_cache.get() + kv_offset,
-          kv_cache.value_cache.get() + kv_offset);
-
-      Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+      ProjKV(k_offset, v_offset, kv_offset);
     }
   });
 
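For context on the Rope call that moved into the lambda: a scalar sketch of rotary position embedding, assuming the common convention of rotating dimension i against dimension i + dim/2 with base 10000 (this diff does not show gemma.cpp's actual pairing or base, so treat both as assumptions):

#include <cmath>
#include <cstddef>

// Rotates each (x[i], x[i + dim/2]) pair by an angle that grows linearly
// with pos and decays geometrically across dimensions, encoding the token's
// position directly into the cached key vector. Illustrative only.
void RopeSketch(float* x, size_t dim, size_t pos) {
  const size_t half = dim / 2;
  for (size_t i = 0; i < half; ++i) {
    const float theta = pos * std::pow(10000.0f, -2.0f * i / dim);
    const float c = std::cos(theta), s = std::sin(theta);
    const float x0 = x[i], x1 = x[i + half];
    x[i] = x0 * c - x1 * s;
    x[i + half] = x0 * s + x1 * c;
  }
}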
@@ -355,13 +359,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
     const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
 
-    TwoOfsMatVecLoop<kQKVDim, kModelDim>(
-        c_layer->c_qkv_einsum_w, k_offset, v_offset,
-        activations.pre_att_rms_out.data() + batch_offset,
-        kv_cache.key_cache.get() + kv_offset,
-        kv_cache.value_cache.get() + kv_offset);
-
-    Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
+    ProjKV(k_offset, v_offset, kv_offset);
   }
 
   pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
@@ -376,7 +374,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
     MulByConst(kQueryScale, q, kQKVDim);
     // Compute Q dot K scores
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-      const size_t cache_offset = kHeads == kKVHeads
+      const size_t cache_offset =
+          kHeads == kKVHeads
           ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
           : pos2 * kCachePosSize + layer * kCacheLayerSize;
       const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
@@ -390,7 +389,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
                               batch_idx * kHeads * kQKVDim;
     hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
     for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
-      const size_t cache_offset = kHeads == kKVHeads
+      const size_t cache_offset =
+          kHeads == kKVHeads
           ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
           : pos2 * kCachePosSize + layer * kCacheLayerSize;
       float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
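Hunks 4 and 5 only rewrap the cache_offset initializer, but the ternary is worth spelling out: when kHeads == kKVHeads each head owns its own KV-cache slot, otherwise all heads read one shared slot (multi-query attention). A runnable sketch of that indexing, assuming the flat per-position, per-layer, per-head layout the strides imply; the dimensions below are made up, and kCacheLayerSize/kCachePosSize are reconstructed from usage rather than copied from gemma.cc:

#include <cstddef>
#include <cstdio>

// Made-up dimensions; only the stride relationships matter here.
constexpr size_t kQKVDim = 8, kLayers = 2, kHeads = 4, kKVHeads = 1;
constexpr size_t kCacheLayerSize = kKVHeads * kQKVDim;       // one layer's KV row
constexpr size_t kCachePosSize = kLayers * kCacheLayerSize;  // one position

size_t CacheOffset(size_t pos2, size_t layer, size_t head) {
  return kHeads == kKVHeads
             ? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
             : pos2 * kCachePosSize + layer * kCacheLayerSize;
}

int main() {
  // With kKVHeads == 1, heads 0 and 3 resolve to the same shared slot.
  printf("%zu %zu\n", CacheOffset(5, 1, 0), CacheOffset(5, 1, 3));  // 88 88
  return 0;
}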