mirror of https://github.com/google/gemma.cpp.git
parent
a3d994915f
commit
fbd44cee42
|
|
@ -1328,7 +1328,6 @@ HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAddTileUpTo8_BF16(
|
|||
if constexpr (N >= 8) {
|
||||
Compress2(df, c7_p0, c7_p1, cs_span, 7 * kMaxLanes * 2);
|
||||
}
|
||||
VF zero = hn::Zero(df);
|
||||
size_t i = 0;
|
||||
HWY_DASSERT(qkv_dim % (NF * 2) == 0);
|
||||
while (i + NF * 2 <= qkv_dim) {
|
||||
|
|
|
|||
|
|
@ -481,8 +481,8 @@ decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
|
|||
for ([[maybe_unused]] auto&& mat : base) {
|
||||
HWY_DASSERT(mat.GetType() == type);
|
||||
}
|
||||
auto convert_to_matptr_t = [&base]<typename T>() {
|
||||
std::vector<MatPtrT<T>> matptrs;
|
||||
auto make_matptr_vec = [&base](auto element) {
|
||||
std::vector<MatPtrT<decltype(element)>> matptrs;
|
||||
matptrs.reserve(base.size());
|
||||
for (auto&& mat : base) {
|
||||
matptrs.emplace_back(mat);
|
||||
|
|
@ -490,12 +490,12 @@ decltype(auto) CallUpcastedKVs(hwy::Span<const MatPtr> base, const Func& func,
|
|||
return matptrs;
|
||||
};
|
||||
if (type == Type::kF32) {
|
||||
auto matptrs = convert_to_matptr_t.template operator()<float>();
|
||||
auto matptrs = make_matptr_vec(float{});
|
||||
hwy::Span<const MatPtrT<float>> matptrs_span(matptrs.data(),
|
||||
matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
} else if (type == Type::kBF16) {
|
||||
auto matptrs = convert_to_matptr_t.template operator()<BF16>();
|
||||
auto matptrs = make_matptr_vec(BF16{});
|
||||
hwy::Span<const MatPtrT<BF16>> matptrs_span(matptrs.data(), matptrs.size());
|
||||
return func(matptrs_span, std::forward<Args>(args)...);
|
||||
} else {
|
||||
|
|
|
|||
Loading…
Reference in New Issue