Move conditional branch out of `pos2` loop

This commit is contained in:
RangerUFO 2024-03-20 23:50:14 +08:00
parent c75d2eb635
commit 8fc6959950
1 changed files with 5 additions and 6 deletions

View File

@ -373,12 +373,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
Rope(q, kQKVDim, pos);
MulByConst(kQueryScale, q, kQKVDim);
const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0;
// Compute Q dot K scores
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t cache_offset =
kHeads == kKVHeads
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
: pos2 * kCachePosSize + layer * kCacheLayerSize;
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
const float score = Dot(q, k2, kQKVDim);
head_att[pos2] = score;
@ -391,9 +392,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
const size_t cache_offset =
kHeads == kKVHeads
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
: pos2 * kCachePosSize + layer * kCacheLayerSize;
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
}