diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h index 7152960..438a1cc 100644 --- a/compression/sfp-inl.h +++ b/compression/sfp-inl.h @@ -33,6 +33,7 @@ #define THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE #endif +#include "hwy/detect_targets.h" #include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); @@ -157,9 +158,65 @@ class SfpCodec { return hn::IfThenZeroElse(is_zero, encoded); } + // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 3 ops (AVX-512). +#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE + template + static HWY_INLINE void DecBytes(D d, hn::Vec encoded, hn::Vec& lo, + hn::Vec& hi) { + const hn::Vec k80 = hn::Set(d, 0x80u); + HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80))); // -0 is reserved + + // Two 2x64 table lookups for lo/hi. + alignas(64) static constexpr uint8_t kTblL0[64] = { + 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, + 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, + 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, + 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, + 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, + 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0}; + alignas(64) static constexpr uint8_t kTblL1[64] = { + 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0, + 0xB0, 0xC0, 0xD0, 0xE0, 0xF0, 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, + 0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0, 0x00, + 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0, + 0xC0, 0xD0, 0xE0, 0xF0, 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, + 0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0}; + alignas(64) static constexpr uint8_t kTblH0[64] = { + 0x00, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x35, 0x35, 0x35, + 0x35, 0x35, 0x35, 0x35, 0x35, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, + 0x36, 0x36, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x38, + 0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x39, 0x39, 0x39, 0x39, + 0x39, 0x39, 0x39, 0x39, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, + 0x3A, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B}; + alignas(64) static constexpr uint8_t kTblH1[64] = { + 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, + 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, + 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3E, + 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, + 0x3E, 0x3E, 0x3E, 0x3E, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, + 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F}; + const hn::Vec tblL0 = hn::LoadU(d, kTblL0); + const hn::Vec tblL1 = hn::LoadU(d, kTblL1); + const hn::Vec tblH0 = hn::LoadU(d, kTblH0); + const hn::Vec tblH1 = hn::LoadU(d, kTblH1); + // AVX-512 ignores the index MSB, no need to clear. + const hn::Indices512 idx{encoded.raw}; + hi = hn::TwoTablesLookupLanes(d, tblH0, tblH1, idx); + lo = hn::TwoTablesLookupLanes(d, tblL0, tblL1, idx); + hi = hn::OrAnd(hi, encoded, k80); // Insert sign bit + } + +// Generic is only required for partial vectors (too small for tables). +#undef SFP_IF_GENERIC_DEC +#define SFP_IF_GENERIC_DEC(D) HWY_IF_V_SIZE_LE_D(D, 32) +#else +// Always enable the generic decoder. +#undef SFP_IF_GENERIC_DEC +#define SFP_IF_GENERIC_DEC(D) void* yes = nullptr +#endif + // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops. - // Implementation detail, public because called by test. - template + template static HWY_INLINE void DecBytes(D d, hn::Vec encoded, hn::Vec& lo, hn::Vec& hi) { const hn::Vec k0 = hn::Zero(d); diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index c08945e..8e50098 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -67,6 +67,26 @@ float F32FromSFP8(uint32_t sfp) { return result; } +// Used for HWY_AVX3_DL and newer. +void PrintTables() { + if (HWY_ONCE && false) { + uint8_t hi[128]; + fprintf(stderr, "lo\n"); + for (uint32_t sfp = 0; sfp < 128; ++sfp) { + const uint32_t u = hwy::BitCastScalar(F32FromSFP8(sfp)); + // Lower bits are zero, hence we can truncate instead of rounding to bf16. + HWY_ASSERT((u & 0xFFFF) == 0); + fprintf(stderr, "0x%02X,", (u >> 16) & 0xFF); + hi[sfp] = u >> 24; + } + fprintf(stderr, "\nhi\n"); + for (uint32_t sfp = 0; sfp < 128; ++sfp) { + fprintf(stderr, "0x%02X,", hi[sfp]); + } + fprintf(stderr, "\n"); + } +} + void TestAllUnique() { std::set unique; for (uint32_t sfp = 0; sfp < 256; ++sfp) { @@ -497,6 +517,7 @@ HWY_AFTER_NAMESPACE(); namespace gcpp { HWY_BEFORE_TEST(SfpTest); +HWY_EXPORT_AND_TEST_P(SfpTest, PrintTables); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden);