mirror of https://github.com/google/gemma.cpp.git
Remove Bf16ToF32EO and use PromoteEvenTo and PromoteOddTo.
This commit is contained in:
parent
aa0b113214
commit
f608337fef
|
|
@ -51,20 +51,6 @@ namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
|
||||||
template <class DF, HWY_IF_F32_D(DF)>
|
|
||||||
HWY_INLINE void Bf16ToF32EO(const DF df32,
|
|
||||||
const hwy::bfloat16_t* HWY_RESTRICT in,
|
|
||||||
hn::Vec<DF>& v_even,
|
|
||||||
hn::Vec<DF>& v_odd) {
|
|
||||||
const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
|
|
||||||
const hn::RebindToUnsigned<decltype(df32)> du32;
|
|
||||||
|
|
||||||
const auto odd = Set(du32, 0xFFFF0000u);
|
|
||||||
const auto interleaved = BitCast(du32, LoadU(dbf16, in));
|
|
||||||
v_even = BitCast(df32, hn::ShiftLeft<16>(interleaved));
|
|
||||||
v_odd = BitCast(df32, And(interleaved, odd));
|
|
||||||
}
|
|
||||||
|
|
||||||
// Enables generic code independent of compression type.
|
// Enables generic code independent of compression type.
|
||||||
template <typename T> // primary, must specialize
|
template <typename T> // primary, must specialize
|
||||||
struct CompressTraits {};
|
struct CompressTraits {};
|
||||||
|
|
@ -263,19 +249,19 @@ struct CompressTraits<hwy::bfloat16_t> {
|
||||||
|
|
||||||
VF32 be0, bo0, be1, bo1;
|
VF32 be0, bo0, be1, bo1;
|
||||||
for (size_t i = 0; i < num; /* i += 2 * N */) {
|
for (size_t i = 0; i < num; /* i += 2 * N */) {
|
||||||
|
const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||||
const VF32 ae0 = Load(df32, vec_aligned + i);
|
const VF32 ae0 = Load(df32, vec_aligned + i);
|
||||||
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
|
const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
|
||||||
Bf16ToF32EO(df32, in + in_ofs + i, be0, bo0);
|
sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
|
||||||
|
sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
|
||||||
i += N;
|
i += N;
|
||||||
sum0 = hn::MulAdd(ae0, be0, sum0);
|
|
||||||
sum1 = hn::MulAdd(ao0, bo0, sum1);
|
|
||||||
|
|
||||||
|
const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
|
||||||
const VF32 ae1 = Load(df32, vec_aligned + i);
|
const VF32 ae1 = Load(df32, vec_aligned + i);
|
||||||
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
|
const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
|
||||||
Bf16ToF32EO(df32, in + in_ofs + i, be1, bo1);
|
sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
|
||||||
|
sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
|
||||||
i += N;
|
i += N;
|
||||||
sum2 = hn::MulAdd(ae1, be1, sum2);
|
|
||||||
sum3 = hn::MulAdd(ao1, bo1, sum3);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sum0 = Add(sum0, sum1);
|
sum0 = Add(sum0, sum1);
|
||||||
|
|
|
||||||
|
|
@ -104,11 +104,10 @@ HWY_INLINE void ToEvenOddF32(
|
||||||
HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
|
HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
|
||||||
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
|
||||||
|
|
||||||
VF32 veven, vodd;
|
|
||||||
for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
|
for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
|
||||||
Bf16ToF32EO(df, vec_aligned + i, veven, vodd);
|
const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
|
||||||
hn::Store(veven, df, out + i);
|
hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
|
||||||
hn::Store(vodd, df, out + i + hn::Lanes(df));
|
hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue