Streamline the implementation

This commit is contained in:
RangerUFO 2024-03-20 22:39:31 +08:00
parent 6923aec853
commit ce32f4db81
1 changed file with 20 additions and 20 deletions

View File

@ -320,6 +320,16 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
const size_t batch_offset = batch_idx * kModelDim;
// Helper that runs the key/value projection for one KV-cache slot and then
// applies Rope to the key entry. Captures by reference; besides its
// parameters it reads c_layer, activations, kv_cache, batch_offset and pos
// from the enclosing scope.
//   k_offset  - first offset passed to TwoOfsMatVecLoop alongside
//               c_layer->c_qkv_einsum_w (presumably where the K weights
//               start -- confirm against TwoOfsMatVecLoop).
//   v_offset  - second offset passed alongside the same weight array
//               (presumably where the V weights start -- confirm).
//   kv_offset - element offset added to both kv_cache.key_cache and
//               kv_cache.value_cache to select the slot used here.
auto ProjKV = [&](size_t k_offset, size_t v_offset, size_t kv_offset) {
// A single call covers both weight offsets, using the same input vector
// (pre_att_rms_out at batch_offset). The key-cache and value-cache pointers
// are passed last, presumably as the two output buffers.
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
c_layer->c_qkv_einsum_w, k_offset, v_offset,
activations.pre_att_rms_out.data() + batch_offset,
kv_cache.key_cache.get() + kv_offset,
kv_cache.value_cache.get() + kv_offset);
// Rope is called on the key-cache slot only (kQKVDim elements, position
// pos); the value-cache slot is not passed to Rope.
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
};
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
// linear projections to QKV
constexpr const size_t head_offset =
@ -339,13 +349,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
const size_t kv_offset =
pos * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim;
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
c_layer->c_qkv_einsum_w, k_offset, v_offset,
activations.pre_att_rms_out.data() + batch_offset,
kv_cache.key_cache.get() + kv_offset,
kv_cache.value_cache.get() + kv_offset);
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
ProjKV(k_offset, v_offset, kv_offset);
}
});
@ -355,13 +359,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
constexpr const size_t v_offset = q_offset + 1 * kQKVDim * kModelDim;
const size_t kv_offset = pos * kCachePosSize + layer * kCacheLayerSize;
TwoOfsMatVecLoop<kQKVDim, kModelDim>(
c_layer->c_qkv_einsum_w, k_offset, v_offset,
activations.pre_att_rms_out.data() + batch_offset,
kv_cache.key_cache.get() + kv_offset,
kv_cache.value_cache.get() + kv_offset);
Rope(kv_cache.key_cache.get() + kv_offset, kQKVDim, pos);
ProjKV(k_offset, v_offset, kv_offset);
}
pool.Run(0, kHeads, [&](const uint64_t head, size_t /*thread*/) HWY_ATTR {
@ -376,7 +374,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
MulByConst(kQueryScale, q, kQKVDim);
// Compute Q dot K scores
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t cache_offset = kHeads == kKVHeads
const size_t cache_offset =
kHeads == kKVHeads
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
: pos2 * kCachePosSize + layer * kCacheLayerSize;
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
@ -390,7 +389,8 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
batch_idx * kHeads * kQKVDim;
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t cache_offset = kHeads == kKVHeads
const size_t cache_offset =
kHeads == kKVHeads
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
: pos2 * kCachePosSize + layer * kCacheLayerSize;
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;