mirror of https://github.com/google/gemma.cpp.git
Move conditional branch out of `pos2` loop
This commit is contained in:
parent
c75d2eb635
commit
8fc6959950
11
gemma.cc
11
gemma.cc
|
|
@ -373,12 +373,13 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
|
|
||||||
Rope(q, kQKVDim, pos);
|
Rope(q, kQKVDim, pos);
|
||||||
MulByConst(kQueryScale, q, kQKVDim);
|
MulByConst(kQueryScale, q, kQKVDim);
|
||||||
|
|
||||||
|
const size_t head_offset = kHeads == kKVHeads ? head * kQKVDim : 0;
|
||||||
|
|
||||||
// Compute Q dot K scores
|
// Compute Q dot K scores
|
||||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||||
const size_t cache_offset =
|
const size_t cache_offset =
|
||||||
kHeads == kKVHeads
|
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||||
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
|
|
||||||
: pos2 * kCachePosSize + layer * kCacheLayerSize;
|
|
||||||
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
|
const float* HWY_RESTRICT k2 = kv_cache.key_cache.get() + cache_offset;
|
||||||
const float score = Dot(q, k2, kQKVDim);
|
const float score = Dot(q, k2, kQKVDim);
|
||||||
head_att[pos2] = score;
|
head_att[pos2] = score;
|
||||||
|
|
@ -391,9 +392,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t batch_idx, size_t layer,
|
||||||
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
hwy::ZeroBytes(att_out, kQKVDim * sizeof(*att_out));
|
||||||
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
for (size_t pos2 = 0; pos2 <= pos; ++pos2) {
|
||||||
const size_t cache_offset =
|
const size_t cache_offset =
|
||||||
kHeads == kKVHeads
|
pos2 * kCachePosSize + layer * kCacheLayerSize + head_offset;
|
||||||
? pos2 * kCachePosSize + layer * kCacheLayerSize + head * kQKVDim
|
|
||||||
: pos2 * kCachePosSize + layer * kCacheLayerSize;
|
|
||||||
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
|
float* HWY_RESTRICT v2 = kv_cache.value_cache.get() + cache_offset;
|
||||||
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
|
MulByConstAndAdd(head_att[pos2], v2, att_out, kQKVDim);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue