mirror of https://github.com/google/gemma.cpp.git
Add Decompress2AndCompressInplace helper
PiperOrigin-RevId: 825966142
This commit is contained in:
parent
006999063c
commit
ee7d79c0a6
|
|
@ -604,6 +604,13 @@ HWY_INLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
|
|||
Traits::DecompressAndZeroPad(d, MakeConst(packed), packed_ofs, raw, num);
|
||||
}
|
||||
|
||||
// NOTE: the following are the recommended way to iterate over arrays of
|
||||
// potentially compressed elements, including remainder handling. Prefer them
|
||||
// over calling `Decompress2` directly, which does not handle remainders.
|
||||
// `DecompressAndCall` is for algorithms expressed as `Kernel` objects, such as
|
||||
// `Dot`. `Decompress*AndCompress*` are for varying numbers of input arrays and
|
||||
// user code expressed as lambdas.
|
||||
|
||||
// Invokes `kernel` for the `v.num` elements of `w` and `v`. Decompresses from
|
||||
// both into groups of four vectors with lane type `Kernel::Raw`, passes them to
|
||||
// `kernel.Update4`; loads the final vector(s) with zero-padding, then passes
|
||||
|
|
@ -733,8 +740,8 @@ HWY_INLINE float DecompressAndCall(D, const PackedSpan<const VT> v,
|
|||
comp3);
|
||||
}
|
||||
|
||||
// Similar to `hn::Transform*`, but for compressed `T`. Used by ops-inl.h.
|
||||
// `DF` is the decompressed type, typically `float`.
|
||||
// Similar to `hn::Transform*`, but for compressed `T`. Used by `ops-inl.h`.
|
||||
// `DF` is the decompressed type, typically `float`. Calls `func(df, v_inout)`.
|
||||
template <class DF, typename T, class Func>
|
||||
HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
|
||||
size_t num, Func&& func) {
|
||||
|
|
@ -773,6 +780,7 @@ HWY_INLINE void DecompressAndCompressInplace(DF df, T* HWY_RESTRICT inout,
|
|||
}
|
||||
|
||||
// One extra argument. `DF` is the decompressed type, typically `float`.
|
||||
// Calls `func(df, v_inout, v1)`.
|
||||
template <class DF, typename T, typename T1, class Func>
|
||||
HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
|
||||
size_t num,
|
||||
|
|
@ -821,8 +829,64 @@ HWY_INLINE void Decompress1AndCompressInplace(DF df, T* HWY_RESTRICT inout,
|
|||
}
|
||||
}
|
||||
|
||||
// Two extra arguments. `DF` is the decompressed type, typically `float`.
|
||||
// Calls `func(df, v_inout, v1, v2)`.
|
||||
template <class DF, typename T, typename T1, typename T2, class Func>
|
||||
HWY_INLINE void Decompress2AndCompressInplace(
|
||||
DF df, T* HWY_RESTRICT inout, size_t num, const T1* HWY_RESTRICT p1,
|
||||
const T2* HWY_RESTRICT p2, const size_t p2_ofs, Func&& func) {
|
||||
const auto packed_inout = MakeSpan(inout, num);
|
||||
const auto packed1 = MakeSpan(p1, num);
|
||||
const auto packed2 = MakeSpan(p2, p2_ofs + num);
|
||||
|
||||
using VF = hn::Vec<decltype(df)>;
|
||||
HWY_LANES_CONSTEXPR const size_t NF = hn::Lanes(df);
|
||||
size_t i = 0;
|
||||
if (num >= 2 * NF) {
|
||||
for (; i <= num - 2 * NF; i += 2 * NF) {
|
||||
VF v0, v1;
|
||||
Decompress2(df, packed_inout, i, v0, v1);
|
||||
VF v10, v11;
|
||||
Decompress2(df, packed1, i, v10, v11);
|
||||
VF v20, v21;
|
||||
Decompress2(df, packed2, p2_ofs + i, v20, v21);
|
||||
const VF out0 = func(df, v0, v10, v20);
|
||||
const VF out1 = func(df, v1, v11, v21);
|
||||
Compress2(df, out0, out1, packed_inout, i);
|
||||
}
|
||||
}
|
||||
|
||||
const size_t remaining = num - i;
|
||||
HWY_DASSERT(remaining < 2 * NF);
|
||||
if (HWY_UNLIKELY(remaining != 0)) {
|
||||
HWY_ALIGN float buf_inout[2 * hn::MaxLanes(df)];
|
||||
HWY_ALIGN float buf1[2 * hn::MaxLanes(df)];
|
||||
HWY_ALIGN float buf2[2 * hn::MaxLanes(df)];
|
||||
// Ensure the second vector is zeroed even if remaining <= NF.
|
||||
hn::Store(hn::Zero(df), df, buf_inout + NF);
|
||||
hn::Store(hn::Zero(df), df, buf1 + NF);
|
||||
hn::Store(hn::Zero(df), df, buf2 + NF);
|
||||
DecompressAndZeroPad(df, packed_inout, i, buf_inout, remaining);
|
||||
DecompressAndZeroPad(df, packed1, i, buf1, remaining);
|
||||
DecompressAndZeroPad(df, packed2, p2_ofs + i, buf2, remaining);
|
||||
const VF v0 = hn::Load(df, buf_inout);
|
||||
const VF v1 = hn::Load(df, buf_inout + NF);
|
||||
const VF v10 = hn::Load(df, buf1);
|
||||
const VF v11 = hn::Load(df, buf1 + NF);
|
||||
const VF v20 = hn::Load(df, buf2);
|
||||
const VF v21 = hn::Load(df, buf2 + NF);
|
||||
const VF out0 = func(df, v0, v10, v20);
|
||||
const VF out1 = func(df, v1, v11, v21);
|
||||
Compress2(df, out0, out1, MakeSpan(buf_inout, 2 * NF), 0);
|
||||
// Clang generates incorrect code for CopyBytes if num = 2.
|
||||
for (size_t j = 0; j < remaining; ++j) {
|
||||
inout[i + j] = hwy::ConvertScalarTo<T>(buf_inout[j]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Single input, separate output. `DF` is the decompressed type, typically
|
||||
// `float`.
|
||||
// `float`. Calls `func(df, v1)`.
|
||||
template <class DF, typename T, typename T1, class Func>
|
||||
HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
|
||||
const T1* HWY_RESTRICT p1,
|
||||
|
|
@ -863,7 +927,8 @@ HWY_INLINE void Decompress1AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
|
|||
}
|
||||
}
|
||||
|
||||
// Two inputs. `DF` is the decompressed type, typically `float`.
|
||||
// Two inputs, separate output. `DF` is the decompressed type, typically
|
||||
// `float`. Calls `func(df, v1, v2)`.
|
||||
template <class DF, typename T, typename T1, typename T2, class Func>
|
||||
HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
|
||||
const T1* HWY_RESTRICT p1,
|
||||
|
|
@ -912,7 +977,8 @@ HWY_INLINE void Decompress2AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
|
|||
}
|
||||
}
|
||||
|
||||
// Three inputs. `DF` is the decompressed type, typically `float`.
|
||||
// Three inputs, separate output. `DF` is the decompressed type, typically
|
||||
// `float`. Calls `func(df, v1, v2, v3)`.
|
||||
template <class DF, typename T, typename T1, typename T2, typename T3,
|
||||
class Func>
|
||||
HWY_INLINE void Decompress3AndCompressTo(DF df, T* HWY_RESTRICT out, size_t num,
|
||||
|
|
|
|||
|
|
@ -259,6 +259,13 @@ class TestDecompressAndCompress {
|
|||
[](DF, VF v, VF v1) HWY_ATTR -> VF { return hn::Add(v, v1); });
|
||||
HWY_ASSERT_ARRAY_EQ(expected2.get(), out.get(), num);
|
||||
|
||||
// `out` already contains v + v1.
|
||||
Decompress2AndCompressInplace(
|
||||
df, out.get(), num, p1.get(), p2.get(), /*p2_ofs=*/0,
|
||||
[](DF, VF v, VF /*v1*/, VF v2)
|
||||
HWY_ATTR -> VF { return hn::Add(v, v2); });
|
||||
HWY_ASSERT_ARRAY_EQ(expected3.get(), out.get(), num);
|
||||
|
||||
Decompress1AndCompressTo(df, out.get(), num, p.get(),
|
||||
[](DF, VF v) HWY_ATTR -> VF { return v; });
|
||||
HWY_ASSERT_ARRAY_EQ(expected1.get(), out.get(), num);
|
||||
|
|
|
|||
Loading…
Reference in New Issue