mirror of https://github.com/google/gemma.cpp.git
Fix NUQ for SVE - incorrect nibble packing
Also speed up test PiperOrigin-RevId: 670625545
This commit is contained in:
parent
aa11ddf5fc
commit
9661b81c4b
|
|
@ -394,7 +394,7 @@ class NibbleCodec {
|
||||||
const V16 u8_2 = combine_u16_pair_to_8(in2);
|
const V16 u8_2 = combine_u16_pair_to_8(in2);
|
||||||
const V16 u8_3 = combine_u16_pair_to_8(in3);
|
const V16 u8_3 = combine_u16_pair_to_8(in3);
|
||||||
V8 packed;
|
V8 packed;
|
||||||
if (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
|
if constexpr (HWY_TARGET <= HWY_AVX3_DL || !HWY_ARCH_X86) {
|
||||||
// 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes
|
// 8-bit ConcatEven is efficient. Let digits denote eight u8 lanes
|
||||||
// of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and
|
// of u8_1/0: ?d?3 ?c?2 / ?b?1 ?a?0. 8-bit ConcatEven = d3c2 b1a0, and
|
||||||
// again with the second x2_1 gives 7654 3210.
|
// again with the second x2_1 gives 7654 3210.
|
||||||
|
|
@ -439,13 +439,13 @@ class NibbleCodec {
|
||||||
// it may trigger asan errors from overrunning the end. We thus special-case
|
// it may trigger asan errors from overrunning the end. We thus special-case
|
||||||
// vector lengths, handling any non-constexpr, and constexpr <= 512 bit.
|
// vector lengths, handling any non-constexpr, and constexpr <= 512 bit.
|
||||||
V8 rep4;
|
V8 rep4;
|
||||||
if (HWY_HAVE_SCALABLE) {
|
if constexpr (HWY_HAVE_SCALABLE) {
|
||||||
// Non constexpr length: 4 per whole block equals size/4.
|
// Non constexpr length: 4 per whole block equals size/4.
|
||||||
const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4);
|
const size_t num_bytes = HWY_MAX(1, hn::Lanes(d8) / 4);
|
||||||
const V8 bytes = hn::LoadN(d8, packed, num_bytes);
|
const V8 bytes = hn::LoadN(d8, packed, num_bytes);
|
||||||
// Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc.
|
// Replicate bytes 4x: lowest 4 = 0, next 4 = 1 etc.
|
||||||
const V8 idx = hn::And(hn::Iota(d8, 0), hn::Set(d8, 0xFCu));
|
const V8 idx = hn::ShiftRight<2>(hn::Iota(d8, 0));
|
||||||
rep4 = hn::TableLookupBytes(bytes, idx);
|
rep4 = hn::TableLookupLanes(bytes, hn::IndicesFromVec(d8, idx));
|
||||||
} else if (hn::MaxLanes(d16) <= 8) { // <= 128-bit
|
} else if (hn::MaxLanes(d16) <= 8) { // <= 128-bit
|
||||||
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
|
const V8 bytes = hn::ResizeBitCast(d8, hn::LoadU(d_load, packed));
|
||||||
alignas(16) static constexpr uint8_t kRep4[16] = {
|
alignas(16) static constexpr uint8_t kRep4[16] = {
|
||||||
|
|
|
||||||
|
|
@ -201,7 +201,7 @@ struct TestNormal {
|
||||||
float centers[kClusters];
|
float centers[kClusters];
|
||||||
uint16_t indices[kGroupSize];
|
uint16_t indices[kGroupSize];
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < 100; ++rep) {
|
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||||
const double t0 = hwy::platform::Now();
|
const double t0 = hwy::platform::Now();
|
||||||
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
const size_t unused_clusters = NuqClustering::ClusterExactL2(
|
||||||
df, in.get(), kGroupSize, buf, centers, indices);
|
df, in.get(), kGroupSize, buf, centers, indices);
|
||||||
|
|
@ -278,6 +278,35 @@ void TestAllOffsetBF16() {
|
||||||
test(hwy::bfloat16_t());
|
test(hwy::bfloat16_t());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct TestNibble {
|
||||||
|
template <typename T, class D>
|
||||||
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
using V = hn::Vec<decltype(d)>;
|
||||||
|
const size_t N = hn::Lanes(d);
|
||||||
|
const size_t num = 4 * N;
|
||||||
|
auto bytes = hwy::AllocateAligned<uint8_t>(num / 2);
|
||||||
|
HWY_ASSERT(bytes);
|
||||||
|
const V v0 = hn::And(hn::Iota(d, 0), hn::Set(d, 15));
|
||||||
|
const V v1 = hn::Set(d, 1);
|
||||||
|
const V v2 = hn::OddEven(v1, hn::Zero(d));
|
||||||
|
const V v3 = hn::Reverse(d, v0);
|
||||||
|
NibbleCodec::OrderedPackU16(d, v0, v1, v2, v3, bytes.get());
|
||||||
|
const V out0 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 0 * N / 2);
|
||||||
|
const V out1 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 1 * N / 2);
|
||||||
|
const V out2 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 2 * N / 2);
|
||||||
|
const V out3 = NibbleCodec::OrderedUnpackU16(d, bytes.get() + 3 * N / 2);
|
||||||
|
HWY_ASSERT_VEC_EQ(d, v0, out0);
|
||||||
|
HWY_ASSERT_VEC_EQ(d, v1, out1);
|
||||||
|
HWY_ASSERT_VEC_EQ(d, v2, out2);
|
||||||
|
HWY_ASSERT_VEC_EQ(d, v3, out3);
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
void TestAllNibble() {
|
||||||
|
const hn::ForGEVectors<128, TestNibble> test;
|
||||||
|
test(uint16_t());
|
||||||
|
}
|
||||||
|
|
||||||
struct TestStream {
|
struct TestStream {
|
||||||
template <typename T, class D>
|
template <typename T, class D>
|
||||||
HWY_INLINE void operator()(T /*unused*/, D d) {
|
HWY_INLINE void operator()(T /*unused*/, D d) {
|
||||||
|
|
@ -298,7 +327,7 @@ struct TestStream {
|
||||||
|
|
||||||
ClusterBuf buf;
|
ClusterBuf buf;
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < 100; ++rep) {
|
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||||
const double t0 = hwy::platform::Now();
|
const double t0 = hwy::platform::Now();
|
||||||
const size_t unused_clusters =
|
const size_t unused_clusters =
|
||||||
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
NuqCodec::Enc(df, in.get(), num, buf, num, nuq.get(), 0);
|
||||||
|
|
@ -310,7 +339,7 @@ struct TestStream {
|
||||||
num * sizeof(float) * 1E-6 / elapsed);
|
num * sizeof(float) * 1E-6 / elapsed);
|
||||||
|
|
||||||
elapsed = hwy::HighestValue<double>();
|
elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < 100; ++rep) {
|
for (size_t rep = 0; rep < hn::AdjustedReps(40); ++rep) {
|
||||||
const double t0 = hwy::platform::Now();
|
const double t0 = hwy::platform::Now();
|
||||||
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
|
NuqCodec::Dec(d, num, nuq.get(), 0, out.get(), num);
|
||||||
const double t1 = hwy::platform::Now();
|
const double t1 = hwy::platform::Now();
|
||||||
|
|
@ -379,7 +408,7 @@ struct TestDot {
|
||||||
// Compute dot product without decompression.
|
// Compute dot product without decompression.
|
||||||
float actual = 0.0f;
|
float actual = 0.0f;
|
||||||
double elapsed = hwy::HighestValue<double>();
|
double elapsed = hwy::HighestValue<double>();
|
||||||
for (size_t rep = 0; rep < 20; ++rep) {
|
for (size_t rep = 0; rep < hn::AdjustedReps(20); ++rep) {
|
||||||
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
|
||||||
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
|
||||||
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
|
||||||
|
|
@ -475,6 +504,7 @@ HWY_EXPORT_AND_TEST_P(NuqTest, TestAllRamp);
|
||||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
|
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNormal);
|
||||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32);
|
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetF32);
|
||||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16);
|
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllOffsetBF16);
|
||||||
|
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllNibble);
|
||||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
|
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamF32);
|
||||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
|
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllStreamBF16);
|
||||||
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);
|
HWY_EXPORT_AND_TEST_P(NuqTest, TestAllDotF32);
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue