mirror of https://github.com/google/gemma.cpp.git
Remove Bf16ToF32EO and use PromoteEvenTo and PromoteOddTo.
This commit is contained in:
parent
aa0b113214
commit
f608337fef
|
|
@ -51,20 +51,6 @@ namespace gcpp {
|
|||
namespace HWY_NAMESPACE {
|
||||
namespace hn = hwy::HWY_NAMESPACE;
|
||||
|
||||
template <class DF, HWY_IF_F32_D(DF)>
|
||||
HWY_INLINE void Bf16ToF32EO(const DF df32,
|
||||
const hwy::bfloat16_t* HWY_RESTRICT in,
|
||||
hn::Vec<DF>& v_even,
|
||||
hn::Vec<DF>& v_odd) {
|
||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
|
||||
const hn::RebindToUnsigned<decltype(df32)> du32;
|
||||
|
||||
const auto odd = Set(du32, 0xFFFF0000u);
|
||||
const auto interleaved = BitCast(du32, LoadU(dbf16, in));
|
||||
v_even = BitCast(df32, hn::ShiftLeft<16>(interleaved));
|
||||
v_odd = BitCast(df32, And(interleaved, odd));
|
||||
}
|
||||
|
||||
// Enables generic code independent of compression type.
|
||||
template <typename T> // primary, must specialize
|
||||
struct CompressTraits {};
|
||||
|
|
@ -263,19 +249,19 @@ struct CompressTraits<hwy::bfloat16_t> {
|
|||
|
||||
VF32 be0, bo0, be1, bo1;
|
||||
for (size_t i = 0; i < num; /* i += 2 * N */) {
|
||||
const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||
const VF32 ae0 = Load(df32, vec_aligned + i);
|
||||
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
|
||||
Bf16ToF32EO(df32, in + in_ofs + i, be0, bo0);
|
||||
sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
|
||||
sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
|
||||
i += N;
|
||||
sum0 = hn::MulAdd(ae0, be0, sum0);
|
||||
sum1 = hn::MulAdd(ao0, bo0, sum1);
|
||||
|
||||
const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||
const VF32 ae1 = Load(df32, vec_aligned + i);
|
||||
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
|
||||
Bf16ToF32EO(df32, in + in_ofs + i, be1, bo1);
|
||||
sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
|
||||
sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
|
||||
i += N;
|
||||
sum2 = hn::MulAdd(ae1, be1, sum2);
|
||||
sum3 = hn::MulAdd(ao1, bo1, sum3);
|
||||
}
|
||||
|
||||
sum0 = Add(sum0, sum1);
|
||||
|
|
|
|||
|
|
@ -104,11 +104,10 @@ HWY_INLINE void ToEvenOddF32(
|
|||
HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
|
||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||
|
||||
VF32 veven, vodd;
|
||||
for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
|
||||
Bf16ToF32EO(df, vec_aligned + i, veven, vodd);
|
||||
hn::Store(veven, df, out + i);
|
||||
hn::Store(vodd, df, out + i + hn::Lanes(df));
|
||||
const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
|
||||
hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
|
||||
hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue