diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h
index 14e3849..5091946 100644
--- a/compression/nuq-inl.h
+++ b/compression/nuq-inl.h
@@ -394,7 +394,7 @@ class NibbleCodec {
     const V16 u8_2 = combine_u16_pair_to_8(in2);
     const V16 u8_3 = combine_u16_pair_to_8(in3);
     V8 packed;
-    if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
+    if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
       // 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes
       // of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and
       // again with the second x2_1 gives 7654 3210.
@@ -439,13 +439,13 @@ class NibbleCodec {
    // it may trigger asan errors from overrunning the end. We thus special-case
    // vector lengths, handling any non-constexpr, and constexpr <= 512 bit.
     V8 rep4;
-    if (HWY_HAVE_SCALABLE) {
+    if constexpr (HWY_HAVE_SCALABLE) {
       // Non constexpr length: 4 per whole block equals size/4.
       const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4);
       const V8 bytes = hn::LoadN(d8, packed, num_bytes);
       // Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc.
-      const V8 idx = hn::And(hn::Iota(d8, 0), hn::Set(d8, 0xFCu));
-      rep4 = hn::TableLookupBytes(bytes, idx);
+      const V8 idx = hn::ShiftRight<2>(hn::Iota(d8, 0));
+      rep4 = hn::TableLookupLanes(bytes, hn::IndicesFromVec(d8, idx));
     } else if (hn::MaxLanes(d16) <= 8) {  // <= 128-bit
       const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
       alignas(16) static constexpr uint8_t kRep4[16] = {
diff --git a/compression/nuq_test.cc b/compression/nuq_test.cc
index 70c9119..fe6e4f1 100644
--- a/compression/nuq_test.cc
+++ b/compression/nuq_test.cc
@@ -201,7 +201,7 @@ struct TestNormal {
     float centers[kClusters];
     uint16_t indices[kGroupSize];
     double elapsed = hwy::HighestValue();
-    for (size_t rep = 0; rep < 100; ++rep) {
+    for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
       const double t0 = hwy::platform::Now();
       const size_t unused_clusters = NuqClustering::ClusterExactL2(
           df, in.get(), kGroupSize, buf, centers, indices);
@@ -278,6 +278,35 @@ void TestAllOffsetBF16() {
   test(hwy::bfloat16_t());
 }
 
+struct TestNibble {
+  template <typename T, class D>
+  HWY_INLINE void operator()(T /*unused*/, D d) {
+    using V = hn::Vec<D>;
+    const size_t N = hn::Lanes(d);
+    const size_t num = 4 * N;
+    auto bytes = hwy::AllocateAligned<uint8_t>(num / 2);
+    HWY_ASSERT(bytes);
+    const V v0 = hn::And(hn::Iota(d, 0), hn::Set(d, 15));
+    const V v1 = hn::Set(d, 1);
+    const V v2 = hn::OddEven(v1, hn::Zero(d));
+    const V v3 = hn::Reverse(d, v0);
+    NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3, bytes.get());
+    const V out0 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 0 * N / 2);
+    const V out1 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 1 * N / 2);
+    const V out2 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 2 * N / 2);
+    const V out3 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 3 * N / 2);
+    HWY_ASSERT_VEC_EQ(d, v0, out0);
+    HWY_ASSERT_VEC_EQ(d, v1, out1);
+    HWY_ASSERT_VEC_EQ(d, v2, out2);
+    HWY_ASSERT_VEC_EQ(d, v3, out3);
+  }
+};
+
+void TestAllNibble() {
+  const hn::ForGEVectors<128, TestNibble> test;
+  test(uint16_t());
+}
+
 struct TestStream {
   template <typename T, class D>
   HWY_INLINE void operator()(T /*unused*/, D d) {
@@ -298,7 +327,7 @@ struct TestStream {
     ClusterBuf buf;
 
     double elapsed = hwy::HighestValue();
-    for (size_t rep = 0; rep < 100; ++rep) {
+    for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
       const double t0 = hwy::platform::Now();
       const size_t unused_clusters =
           NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
@@ -310,7 +339,7 @@ struct TestStream {
             num * sizeof(float) * 1E-6 / elapsed);
 
     elapsed = hwy::HighestValue();
-    for (size_t rep = 0; rep < 100; ++rep) {
+    for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
       const double t0 = hwy::platform::Now();
       NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
       const double t1 = hwy::platform::Now();
@@ -379,7 +408,7 @@ struct TestDot {
     // Compute dot product without decompression.
     float actual = 0.0f;
     double elapsed = hwy::HighestValue();
-    for (size_t rep = 0; rep < 20; ++rep) {
+    for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
       hn::Vec<decltype(df)> sum0 = hn::Zero(df);
       hn::Vec<decltype(df)> sum1 = hn::Zero(df);
       hn::Vec<decltype(df)> sum2 = hn::Zero(df);
@@ -475,6 +504,7 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16);
+HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
 HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);
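
Note (not part of the patch): the new TestNibble exercises the contract that OrderedPackU16 stores 4*N u16 lanes, each holding a value in [0, 16), into 2*N bytes (two nibbles per byte), and that OrderedUnpackU16 reconstructs the original vectors from consecutive N/2-byte slices. The standalone scalar sketch below illustrates that round trip only; the nibble order within a byte (even index in the low nibble) is an assumption for illustration, since the test relies solely on pack and unpack being inverses.

  // Scalar sketch of a nibble round trip; not the SIMD NibbleCodec.
  #include <cstddef>
  #include <cstdint>
  #include <cstdio>
  #include <vector>

  int main() {
    const size_t num = 16;  // e.g. 4 "vectors" of 4 lanes each
    std::vector<uint16_t> in(num), out(num);
    for (size_t i = 0; i < num; ++i) in[i] = static_cast<uint16_t>(i & 15);

    // Pack: two 4-bit values per byte (low-nibble-first is an assumption).
    std::vector<uint8_t> packed(num / 2);
    for (size_t i = 0; i < num; i += 2) {
      packed[i / 2] =
          static_cast<uint8_t>((in[i] & 0xF) | ((in[i + 1] & 0xF) << 4));
    }
    // Unpack and verify the round trip, as TestNibble does for the SIMD path.
    for (size_t i = 0; i < num; i += 2) {
      out[i] = packed[i / 2] & 0xF;
      out[i + 1] = packed[i / 2] >> 4;
    }
    for (size_t i = 0; i < num; ++i) {
      if (out[i] != in[i]) {
        fprintf(stderr, "mismatch at %zu\n", i);
        return 1;
      }
    }
    return 0;
  }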