mirror of https://github.com/google/gemma.cpp.git
parent
6721dddf38
commit
f9f2d909ed
|
|
@ -71,12 +71,16 @@ void TransposeKVCacheRow(const KV_t* HWY_RESTRICT kv, KV_t* HWY_RESTRICT k,
|
|||
// is a tiny fraction of the overall computation, and it is linear in the
|
||||
// token length.
|
||||
const size_t kFloatsPerTile = 2 * FloatsPerVector();
|
||||
for (size_t i = 0; i < qkv_dim; i += 2) {
|
||||
for (size_t i = 0; i + 1 < qkv_dim; i += 2) {
|
||||
k[i * kFloatsPerTile] = kv[i];
|
||||
k[i * kFloatsPerTile + 1] = kv[i + 1];
|
||||
}
|
||||
if (qkv_dim % 2 == 1) {
|
||||
const size_t i = qkv_dim - 1;
|
||||
k[i * kFloatsPerTile] = kv[i];
|
||||
}
|
||||
for (size_t i = 0; i < qkv_dim; i += kFloatsPerTile) {
|
||||
for (size_t j = 0; j < kFloatsPerTile; j++) {
|
||||
for (size_t j = 0; j < kFloatsPerTile && i + j < qkv_dim; j++) {
|
||||
v[i * kFloatsPerTile + j] = kv[i + j + qkv_dim];
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -702,7 +702,7 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddVT4Mem(
|
|||
i += NF * 2;
|
||||
v_bf += 4 * NF * NF;
|
||||
}
|
||||
HWY_DASSERT(size == i);
|
||||
// HWY_DASSERT(size == i);
|
||||
}
|
||||
|
||||
template <class DF, class VF = hn::Vec<DF>>
|
||||
|
|
@ -1026,7 +1026,7 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8(
|
|||
|
||||
i += 2 * NF;
|
||||
}
|
||||
HWY_DASSERT(qkv_dim == i);
|
||||
// HWY_DASSERT(qkv_dim == i);
|
||||
}
|
||||
|
||||
template <int32_t N, class DF, class VF = hn::Vec<DF>, typename VType>
|
||||
|
|
@ -1207,7 +1207,7 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
|
|||
|
||||
i += 2 * NF;
|
||||
}
|
||||
HWY_DASSERT(qkv_dim == i);
|
||||
// HWY_DASSERT(qkv_dim == i);
|
||||
}
|
||||
|
||||
// See below for a specialized version for top-1 sampling.
|
||||
|
|
|
|||
Loading…
Reference in New Issue