Fix NUQ for SVE - incorrect nibble packing

Also speed up test

PiperOrigin-RevId: 670625545
This commit is contained in:
Jan Wassenberg 2024-09-03 10:58:27 -07:00 committed by Copybara-Service
parent aa11ddf5fc
commit 9661b81c4b
2 changed files with 38 additions and 8 deletions

View File

@ -394,7 +394,7 @@ class NibbleCodec {
const V16 u8_2 = combine_u16_pair_to_8(in2);
const V16 u8_3 = combine_u16_pair_to_8(in3);
V8 packed;
if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
// 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes
// of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and
// again with the second x2_1 gives 7654 3210.
@ -439,13 +439,13 @@ class NibbleCodec {
// it may trigger asan errors from overrunning the end. We thus special-case
// vector lengths, handling any non-constexpr, and constexpr <= 512 bit.
V8 rep4;
if (HWY_HAVE_SCALABLE) {
if constexpr (HWY_HAVE_SCALABLE) {
// Non constexpr length: 4 per whole block equals size/4.
const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4);
const V8 bytes = hn::LoadN(d8, packed, num_bytes);
// Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc.
const V8 idx = hn::And(hn::Iota(d8, 0), hn::Set(d8, 0xFCu));
rep4 = hn::TableLookupBytes(bytes, idx);
const V8 idx = hn::ShiftRight<2>(hn::Iota(d8, 0));
rep4 = hn::TableLookupLanes(bytes, hn::IndicesFromVec(d8, idx));
} else if (hn::MaxLanes(d16) <= 8) { // <= 128-bit
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
alignas(16) static constexpr uint8_t kRep4[16] = {

View File

@ -201,7 +201,7 @@ struct TestNormal {
float centers[kClusters];
uint16_t indices[kGroupSize];
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 100; ++rep) {
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
const double t0 = hwy::platform::Now();
const size_t unused_clusters = NuqClustering::ClusterExactL2(
df, in.get(), kGroupSize, buf, centers, indices);
@ -278,6 +278,35 @@ void TestAllOffsetBF16() {
test(hwy::bfloat16_t());
}
// Round-trip test for NibbleCodec: hands four u16 vectors to OrderedPackU16,
// reads each one back with OrderedUnpackU16 and requires an exact match.
struct TestNibble {
  template <typename T, class D>
  HWY_INLINE void operator()(T /*unused*/, D d) {
    using V = hn::Vec<decltype(d)>;
    const size_t N = hn::Lanes(d);

    // Four vectors of N lanes each; the buffer provides half a byte per lane.
    const size_t total_lanes = 4 * N;
    auto storage = hwy::AllocateAligned<uint8_t>(total_lanes / 2);
    HWY_ASSERT(storage);
    uint8_t* const packed = storage.get();

    // Test patterns; every lane value is at most 15.
    const V ramp = hn::And(hn::Iota(d, 0), hn::Set(d, 15));  // 0..15 repeating
    const V ones = hn::Set(d, 1);
    const V alternating = hn::OddEven(ones, hn::Zero(d));  // odd=1, even=0
    const V reversed_ramp = hn::Reverse(d, ramp);

    NibbleCodec::OrderedPackU16(d, ramp, ones, alternating, reversed_ramp,
                                packed);

    // The k-th vector is read back from byte offset k * N / 2.
    const V got_ramp = NibbleCodec::OrderedUnpackU16(d, packed);
    const V got_ones = NibbleCodec::OrderedUnpackU16(d, packed + N / 2);
    const V got_alternating = NibbleCodec::OrderedUnpackU16(d, packed + N);
    const V got_reversed_ramp =
        NibbleCodec::OrderedUnpackU16(d, packed + 3 * N / 2);

    HWY_ASSERT_VEC_EQ(d, ramp, got_ramp);
    HWY_ASSERT_VEC_EQ(d, ones, got_ones);
    HWY_ASSERT_VEC_EQ(d, alternating, got_alternating);
    HWY_ASSERT_VEC_EQ(d, reversed_ramp, got_reversed_ramp);
  }
};
void TestAllNibble() {
const hn::ForGEVectors<128, TestNibble> test;
test(uint16_t());
}
struct TestStream {
template <typename T, class D>
HWY_INLINE void operator()(T /*unused*/, D d) {
@ -298,7 +327,7 @@ struct TestStream {
ClusterBuf buf;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 100; ++rep) {
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
const double t0 = hwy::platform::Now();
const size_t unused_clusters =
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
@ -310,7 +339,7 @@ struct TestStream {
num * sizeof(float) * 1E-6 / elapsed);
elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 100; ++rep) {
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
const double t0 = hwy::platform::Now();
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
const double t1 = hwy::platform::Now();
@ -379,7 +408,7 @@ struct TestDot {
// Compute dot product without decompression.
float actual = 0.0f;
double elapsed = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 20; ++rep) {
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
@ -475,6 +504,7 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);