mirror of https://github.com/google/gemma.cpp.git
Move conditional branch out of `pos2` loop
This commit is contained in:
parent
c75d2eb635
commit
8fc6959950
11
gemma.cc
11
gemma.cc
|
|
@ -373,12 +373,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
|
||||
Rope(q, kQKVDim, pos);
|
||||
MulByConst(kQueryScale, q, kQKVDim);
|
||||
|
||||
const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0;
|
||||
|
||||
// Compute Q dot K scores
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_offset =
|
||||
kHeads == kKVHeads
|
||||
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
|
||||
: pos2 * kCachePosSize + layer * kCacheLayerSize;
|
||||
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
|
||||
const float score = Dot(q, k2, kQKVDim);
|
||||
head_att[pos2] = score;
|
||||
|
|
@ -391,9 +392,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
|||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||
const size_t cache_offset =
|
||||
kHeads == kKVHeads
|
||||
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
|
||||
: pos2 * kCachePosSize + layer * kCacheLayerSize;
|
||||
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
|
||||
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue