mirror of https://github.com/google/gemma.cpp.git
Fix NUQ for SVE: incorrect nibble packing.
Also speed up the test.
PiperOrigin-RevId: 670625545
This commit is contained in:
parent
aa11ddf5fc
commit
9661b81c4b
|
|
@ -394,7 +394,7 @@ class NibbleCodec {
|
|||
const V16 u8_2 = combine_u16_pair_to_8(in2);
|
||||
const V16 u8_3 = combine_u16_pair_to_8(in3);
|
||||
V8 packed;
|
||||
if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
|
||||
if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
|
||||
// 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes
|
||||
// of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and
|
||||
// again with the second x2_1 gives 7654 3210.
|
||||
|
|
@ -439,13 +439,13 @@ class NibbleCodec {
|
|||
// it may trigger asan errors from overrunning the end. We thus special-case
|
||||
// vector lengths, handling any non-constexpr, and constexpr <= 512 bit.
|
||||
V8 rep4;
|
||||
if (HWY_HAVE_SCALABLE) {
|
||||
if constexpr (HWY_HAVE_SCALABLE) {
|
||||
// Non constexpr length: 4 per whole block equals size/4.
|
||||
const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4);
|
||||
const V8 bytes = hn::LoadN(d8, packed, num_bytes);
|
||||
// Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc.
|
||||
const V8 idx = hn::And(hn::Iota(d8, 0), hn::Set(d8, 0xFCu));
|
||||
rep4 = hn::TableLookupBytes(bytes, idx);
|
||||
const V8 idx = hn::ShiftRight<2>(hn::Iota(d8, 0));
|
||||
rep4 = hn::TableLookupLanes(bytes, hn::IndicesFromVec(d8, idx));
|
||||
} else if (hn::MaxLanes(d16) <= 8) { // <= 128-bit
|
||||
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
|
||||
alignas(16) static constexpr uint8_t kRep4[16] = {
|
||||
|
|
|
|||
|
|
@ -201,7 +201,7 @@ struct TestNormal {
|
|||
float centers[kClusters];
|
||||
uint16_t indices[kGroupSize];
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 100; ++rep) {
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
||||
df, in.get(), kGroupSize, buf, centers, indices);
|
||||
|
|
@ -278,6 +278,35 @@ void TestAllOffsetBF16() {
|
|||
test(hwy::bfloat16_t());
|
||||
}
|
||||
|
||||
struct TestNibble {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
using V = hn::Vec<decltype(d)>;
|
||||
const size_t N = hn::Lanes(d);
|
||||
const size_t num = 4 * N;
|
||||
auto bytes = hwy::AllocateAligned<uint8_t>(num / 2);
|
||||
HWY_ASSERT(bytes);
|
||||
const V v0 = hn::And(hn::Iota(d, 0), hn::Set(d, 15));
|
||||
const V v1 = hn::Set(d, 1);
|
||||
const V v2 = hn::OddEven(v1, hn::Zero(d));
|
||||
const V v3 = hn::Reverse(d, v0);
|
||||
NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3, bytes.get());
|
||||
const V out0 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 0 * N / 2);
|
||||
const V out1 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 1 * N / 2);
|
||||
const V out2 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 2 * N / 2);
|
||||
const V out3 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 3 * N / 2);
|
||||
HWY_ASSERT_VEC_EQ(d, v0, out0);
|
||||
HWY_ASSERT_VEC_EQ(d, v1, out1);
|
||||
HWY_ASSERT_VEC_EQ(d, v2, out2);
|
||||
HWY_ASSERT_VEC_EQ(d, v3, out3);
|
||||
}
|
||||
};
|
||||
|
||||
void TestAllNibble() {
|
||||
const hn::ForGEVectors<128, TestNibble> test;
|
||||
test(uint16_t());
|
||||
}
|
||||
|
||||
struct TestStream {
|
||||
template <typename T, class D>
|
||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||
|
|
@ -298,7 +327,7 @@ struct TestStream {
|
|||
|
||||
ClusterBuf buf;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 100; ++rep) {
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
const size_t unused_clusters =
|
||||
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
||||
|
|
@ -310,7 +339,7 @@ struct TestStream {
|
|||
num * sizeof(float) * 1E-6 / elapsed);
|
||||
|
||||
elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 100; ++rep) {
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||
const double t0 = hwy::platform::Now();
|
||||
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
|
||||
const double t1 = hwy::platform::Now();
|
||||
|
|
@ -379,7 +408,7 @@ struct TestDot {
|
|||
// Compute dot product without decompression.
|
||||
float actual = 0.0f;
|
||||
double elapsed = hwy::HighestValue<double>();
|
||||
for (size_t rep = 0; rep < 20; ++rep) {
|
||||
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
||||
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
||||
|
|
@ -475,6 +504,7 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
|
|||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
|
||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);
|
||||
|
|
|
|||
Loading…
Reference in New Issue