Add Decompress2AndCompressInplace helper

PiperOrigin-RevId: 825966142
This commit is contained in:
Jan Wassenberg 2025-10-30 04:04:08 -07:00 committed by Copybara-Service
parent 006999063c
commit ee7d79c0a6
2 changed files with 78 additions and 5 deletions

View File

@ -604,6 +604,13 @@ HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
}
// NOTE: the following are the recommended way to iterate over arrays of
// potentially compressed elements, including remainder handling. Prefer them
// over calling `Decompress2` directly, which does not handle remainders.
// `DecompressAndCall` is for algorithms expressed as `Kernel` objects, such as
// `Dot`. `Decompress*AndCompress*` are for varying numbers of input arrays and
// user code expressed as lambdas.
// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
@ -733,8 +740,8 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
comp3);
}
// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h.
// `DF` is the decompressed type, typically `float`.
// Similar to `hn::Transform*`, but for compressed `T`. Used by `ops-inl.h`.
// `DF` is the decompressed type, typically `float`. Calls `func(df, v_inout)`.
template <class DF, typename T, class Func>
HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
size_t num, Func&& func) {
@ -773,6 +780,7 @@ HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
}
// One extra argument. `DF` is the decompressed type, typically `float`.
// Calls `func(df, v_inout, v1)`.
template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
size_t num,
@ -821,8 +829,64 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
}
}
// In-place update of `inout` using two additional read-only inputs. `DF` is
// the decompressed type, typically `float`. Decompresses `num` elements from
// `inout`, `p1` and `p2` (the latter starting at element `p2_ofs`), then
// overwrites `inout` with the compressed result of
// `func(df, v_inout, v1, v2)`.
template <class DF, typename T, typename T1, typename T2, class Func>
HWY_INLINE void Decompress2AndCompressInplace(
    DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1,
    const T2* HWY_RESTRICT p2, const size_t p2_ofs, Func&& func) {
  using VF = hn::Vec<decltype(df)>;
  HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);

  // `span2` also spans the `p2_ofs` elements that precede the first read.
  const auto span_inout = MakeSpan(inout, num);
  const auto span1 = MakeSpan(p1, num);
  const auto span2 = MakeSpan(p2, p2_ofs + num);

  // Main loop: two vectors per input and iteration. The guard avoids unsigned
  // wraparound when computing the start of the last full group.
  size_t pos = 0;
  if (num >= 2 * NF) {
    const size_t last_start = num - 2 * NF;
    while (pos <= last_start) {
      VF io0, io1;
      Decompress2(df, span_inout, pos, io0, io1);
      VF a0, a1;
      Decompress2(df, span1, pos, a0, a1);
      VF b0, b1;
      Decompress2(df, span2, p2_ofs + pos, b0, b1);

      const VF r0 = func(df, io0, a0, b0);
      const VF r1 = func(df, io1, a1, b1);
      Compress2(df, r0, r1, span_inout, pos);

      pos += 2 * NF;
    }
  }

  const size_t tail = num - pos;
  HWY_DASSERT(tail < 2 * NF);
  if (HWY_UNLIKELY(tail != 0)) {
    // Staging buffers holding two vectors each.
    HWY_ALIGN float stage_inout[2 * hn::MaxLanes(df)];
    HWY_ALIGN float stage1[2 * hn::MaxLanes(df)];
    HWY_ALIGN float stage2[2 * hn::MaxLanes(df)];

    // Clear the second vector of each buffer first, so that it is zero even
    // if tail <= NF; then decompress the leftover elements into the front.
    const VF zero = hn::Zero(df);
    hn::Store(zero, df, stage_inout + NF);
    hn::Store(zero, df, stage1 + NF);
    hn::Store(zero, df, stage2 + NF);
    DecompressAndZeroPad(df, span_inout, pos, stage_inout, tail);
    DecompressAndZeroPad(df, span1, pos, stage1, tail);
    DecompressAndZeroPad(df, span2, p2_ofs + pos, stage2, tail);

    const VF io0 = hn::Load(df, stage_inout);
    const VF io1 = hn::Load(df, stage_inout + NF);
    const VF a0 = hn::Load(df, stage1);
    const VF a1 = hn::Load(df, stage1 + NF);
    const VF b0 = hn::Load(df, stage2);
    const VF b1 = hn::Load(df, stage2 + NF);

    // `func` runs on both vectors, including any padding lanes; only the
    // first `tail` results are written back below.
    const VF r0 = func(df, io0, a0, b0);
    const VF r1 = func(df, io1, a1, b1);
    Compress2(df, r0, r1, MakeSpan(stage_inout, 2 * NF), 0);

    // Element-wise copy: Clang generates incorrect code for CopyBytes if
    // num = 2.
    for (size_t k = 0; k < tail; ++k) {
      inout[pos + k] = hwy::ConvertScalarTo<T>(stage_inout[k]);
    }
  }
}
// Single input, separate output. `DF` is the decompressed type, typically
// `float`.
// `float`. Calls `func(df, v1)`.
template <class DF, typename T, typename T1, class Func>
HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
const T1* HWY_RESTRICT p1,
@ -863,7 +927,8 @@ HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
}
}
// Two inputs. `DF` is the decompressed type, typically `float`.
// Two inputs, separate output. `DF` is the decompressed type, typically
// `float`. Calls `func(df, v1, v2)`.
template <class DF, typename T, typename T1, typename T2, class Func>
HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
const T1* HWY_RESTRICT p1,
@ -912,7 +977,8 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
}
}
// Three inputs. `DF` is the decompressed type, typically `float`.
// Three inputs, separate output. `DF` is the decompressed type, typically
// `float`. Calls `func(df, v1, v2, v3)`.
template <class DF, typename T, typename T1, typename T2, typename T3,
class Func>
HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,

View File

@ -259,6 +259,13 @@ class TestDecompressAndCompress {
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
// `out` already contains v + v1.
Decompress2AndCompressInplace(
df, out.get(), num, p1.get(), p2.get(), /*p2_ofs=*/0,
[](DF, VF v, VF /*v1*/, VF v2)
HWY_ATTR -> VF { return hn::Add(v, v2); });
HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num);
Decompress1AndCompressTo(df, out.get(), num, p.get(),
[](DF, VF v) HWY_ATTR -> VF { return v; });
HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);