diff --git a/compression/compress-inl.h b/compression/compress-inl.h index 10ce57c..42812ef 100644 --- a/compression/compress-inl.h +++ b/compression/compress-inl.h @@ -604,6 +604,13 @@ HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan& packed, Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num); } +// NOTE: the following are the recommended way to iterate over arrays of +// potentially compressed elements, including remainder handling. Prefer them +// over calling `Decompress2` directly, which does not handle remainders. +// `DecompressAndCall` is for algorithms expressed as `Kernel` objects, such as +// `Dot`. `Decompress*AndCompress*` are for varying numbers of input arrays and +// user code expressed as lambdas. + // Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from // both into groups of four vectors with lane type `Kernel::Raw`, passes them to // `kernel.Update4`; loads the final vector(s) with zero-padding, then passes @@ -733,8 +740,8 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan v, comp3); } -// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h. -// `DF` is the decompressed type, typically `float`. +// Similar to `hn::Transform*`, but for compressed `T`. Used by `ops-inl.h`. +// `DF` is the decompressed type, typically `float`. Calls `func(df, v_inout)`. template HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout, size_t num, Func&& func) { @@ -773,6 +780,7 @@ HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout, } // One extra argument. `DF` is the decompressed type, typically `float`. +// Calls `func(df, v_inout, v1)`. template HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, size_t num, @@ -821,8 +829,64 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout, } } +// Two extra arguments. `DF` is the decompressed type, typically `float`. +// Calls `func(df, v_inout, v1, v2)`. +template +HWY_INLINE void Decompress2AndCompressInplace( + DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1, + const T2* HWY_RESTRICT p2, const size_t p2_ofs, Func&& func) { + const auto packed_inout = MakeSpan(inout, num); + const auto packed1 = MakeSpan(p1, num); + const auto packed2 = MakeSpan(p2, p2_ofs + num); + + using VF = hn::Vec; + HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df); + size_t i = 0; + if (num >= 2 * NF) { + for (; i <= num - 2 * NF; i += 2 * NF) { + VF v0, v1; + Decompress2(df, packed_inout, i, v0, v1); + VF v10, v11; + Decompress2(df, packed1, i, v10, v11); + VF v20, v21; + Decompress2(df, packed2, p2_ofs + i, v20, v21); + const VF out0 = func(df, v0, v10, v20); + const VF out1 = func(df, v1, v11, v21); + Compress2(df, out0, out1, packed_inout, i); + } + } + + const size_t remaining = num - i; + HWY_DASSERT(remaining < 2 * NF); + if (HWY_UNLIKELY(remaining != 0)) { + HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf1[2 * hn::MaxLanes(df)]; + HWY_ALIGN float buf2[2 * hn::MaxLanes(df)]; + // Ensure the second vector is zeroed even if remaining <= NF. + hn::Store(hn::Zero(df), df, buf_inout + NF); + hn::Store(hn::Zero(df), df, buf1 + NF); + hn::Store(hn::Zero(df), df, buf2 + NF); + DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining); + DecompressAndZeroPad(df, packed1, i, buf1, remaining); + DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining); + const VF v0 = hn::Load(df, buf_inout); + const VF v1 = hn::Load(df, buf_inout + NF); + const VF v10 = hn::Load(df, buf1); + const VF v11 = hn::Load(df, buf1 + NF); + const VF v20 = hn::Load(df, buf2); + const VF v21 = hn::Load(df, buf2 + NF); + const VF out0 = func(df, v0, v10, v20); + const VF out1 = func(df, v1, v11, v21); + Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0); + // Clang generates incorrect code for CopyBytes if num = 2. + for (size_t j = 0; j < remaining; ++j) { + inout[i + j] = hwy::ConvertScalarTo(buf_inout[j]); + } + } +} + // Single input, separate output. `DF` is the decompressed type, typically -// `float`. +// `float`. Calls `func(df, v1)`. template HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, const T1* HWY_RESTRICT p1, @@ -863,7 +927,8 @@ HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, } } -// Two inputs. `DF` is the decompressed type, typically `float`. +// Two inputs, separate output. `DF` is the decompressed type, typically +// `float`. Calls `func(df, v1, v2)`. template HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, const T1* HWY_RESTRICT p1, @@ -912,7 +977,8 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, } } -// Three inputs. `DF` is the decompressed type, typically `float`. +// Three inputs, separate output. `DF` is the decompressed type, typically +// `float`. Calls `func(df, v1, v2, v3)`. template HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num, diff --git a/compression/compress_test.cc b/compression/compress_test.cc index 987f409..421492e 100644 --- a/compression/compress_test.cc +++ b/compression/compress_test.cc @@ -259,6 +259,13 @@ class TestDecompressAndCompress { [](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); }); HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num); + // `out` already contains v + v1. + Decompress2AndCompressInplace( + df, out.get(), num, p1.get(), p2.get(), /*p2_ofs=*/0, + [](DF, VF v, VF /*v1*/, VF v2) + HWY_ATTR -> VF { return hn::Add(v, v2); }); + HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num); + Decompress1AndCompressTo(df, out.get(), num, p.get(), [](DF, VF v) HWY_ATTR -> VF { return v; }); HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);