diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index 7ae43b7..eb87cf8 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -51,20 +51,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
-template <class DF>
-HWY_INLINE void Bf16ToF32EO(const DF df32,
-                            const hwy::bfloat16_t* HWY_RESTRICT in,
-                            hn::Vec<DF>& v_even,
-                            hn::Vec<DF>& v_odd) {
-  const hn::Repartition<hwy::bfloat16_t, DF> dbf16;
-  const hn::RebindToUnsigned<DF> du32;
-
-  const auto odd = Set(du32, 0xFFFF0000u);
-  const auto interleaved = BitCast(du32, LoadU(dbf16, in));
-  v_even = BitCast(df32, hn::ShiftLeft<16>(interleaved));
-  v_odd = BitCast(df32, And(interleaved, odd));
-}
-
 // Enables generic code independent of compression type.
 template <typename Packed>  // primary, must specialize
 struct CompressTraits {};
@@ -263,19 +249,19 @@ struct CompressTraits<hwy::bfloat16_t> {
 
     VF32 be0, bo0, be1, bo1;
     for (size_t i = 0; i < num; /* i += 2 * N */) {
+      const auto interleaved0 = hn::LoadU(dbf16, in + in_ofs + i);
       const VF32 ae0 = Load(df32, vec_aligned + i);
       const VF32 ao0 = Load(df32, vec_aligned + i + (N / 2));
-      Bf16ToF32EO(df32, in + in_ofs + i, be0, bo0);
+      sum0 = hn::MulAdd(ae0, hn::PromoteEvenTo(df32, interleaved0), sum0);
+      sum1 = hn::MulAdd(ao0, hn::PromoteOddTo(df32, interleaved0), sum1);
       i += N;
-      sum0 = hn::MulAdd(ae0, be0, sum0);
-      sum1 = hn::MulAdd(ao0, bo0, sum1);
+      const auto interleaved1 = hn::LoadU(dbf16, in + in_ofs + i);
       const VF32 ae1 = Load(df32, vec_aligned + i);
       const VF32 ao1 = Load(df32, vec_aligned + i + (N / 2));
-      Bf16ToF32EO(df32, in + in_ofs + i, be1, bo1);
+      sum2 = hn::MulAdd(ae1, hn::PromoteEvenTo(df32, interleaved1), sum2);
+      sum3 = hn::MulAdd(ao1, hn::PromoteOddTo(df32, interleaved1), sum3);
       i += N;
-      sum2 = hn::MulAdd(ae1, be1, sum2);
-      sum3 = hn::MulAdd(ao1, bo1, sum3);
     }
 
     sum0 = Add(sum0, sum1);
diff --git a/gemma/ops.h b/gemma/ops.h
index 0c29cda..1b1c29e 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -104,11 +104,10 @@ HWY_INLINE void ToEvenOddF32(
   HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
   HWY_DASSERT(hn::IsAligned(df, vec_aligned));
 
-  VF32 veven, vodd;
   for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
-    Bf16ToF32EO(df, vec_aligned + i, veven, vodd);
-    hn::Store(veven, df, out + i);
-    hn::Store(vodd, df, out + i + hn::Lanes(df));
+    const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
+    hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
+    hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
   }
 }