diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h index 3be36ec..7ca877b 100644 --- a/compression/sfp-inl.h +++ b/compression/sfp-inl.h @@ -218,7 +218,7 @@ class SfpCodec { #define SFP_IF_GENERIC_DEC(D) void* yes = nullptr #endif - // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops. + // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 9 ops. template static HWY_INLINE void DecBytes(D d, hn::Vec encoded, hn::Vec& lo, hn::Vec& hi) { @@ -233,31 +233,27 @@ class SfpCodec { // Special-case zero, negated so we can use MaskedAddOr. Signed comparison // is fine because we have cleared the sign bit. const hn::Mask is_nonzero = SignedGt(d, encoded, k0); - // If MSB is clear, we have two mantissa bits, otherwise three. + // If bit 6 is clear, we have two mantissa bits, otherwise three. const hn::Mask is_small_e = SignedLt(d, encoded, hn::Set(d, 64)); - // If is_small_e, add/left-shift 0xxxx.mm to 0xxxx.mm0; else keep 1xxx.mmm. - const hn::Vec e4m3 = - hn::MaskedAddOr(encoded, is_small_e, encoded, encoded); - HWY_DASSERT(hn::AllTrue(d, hn::Lt(e4m3, k80))); - const hn::Vec e = hn::ShiftRight<3>(e4m3); // 4-bit exponent only - HWY_DASSERT(hn::AllTrue(d, hn::Lt(e, Set(d, 16u)))); - // The encoded exponent for 2^0 is 15, so subtract 15. Add 127 for the - // binary32/bf16 bias. Subtract another 8 if is_small_e because its lowest - // encoded value (0) should be less than the lowest 'large' exponent 2^-7. - const hn::Vec e_bias = hn::IfThenElse( - is_small_e, hn::Set(d, 127u - 15u - 8u), hn::Set(d, 127u - 15u)); - // Special-case zero or add e_bias. If encoded=0, e and e4m3 are zero, but - // we must zero e_bias to get the desired all-zero bf16. - const hn::Vec biased_e = hn::MaskedAddOr(k0, is_nonzero, e_bias, e); - // The decoded binary32 exponent should be at most 2^0. - HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80))); + // For encoded in [1, 8), hi = 0x34; encoded = 0x40 => hi = 0x3C including + // (encoded >> 4) == 4, so add 0x38. + const hn::Vec e_bias = + hn::IfThenElse(is_small_e, hn::Set(d, 0x34), hn::Set(d, 0x38)); - // Shift the MSB of e4m3's mantissa into the MSB of the bf16 mantissa. - const hn::Vec m7 = hn::ShiftLeft<4>(e4m3); - // Lower byte of bf16 = exponent LSB || mantissa. - lo = hn::BitwiseIfThenElse(k80, hn::ShiftLeft<7>(biased_e), m7); - // Upper byte of bf16 = sign || lower 7 bits of exponent. - hi = hn::BitwiseIfThenElse(k80, sign_in_msb, hn::ShiftRight<1>(biased_e)); + // The low byte of bf16 is encoded << (is_small_e ? 5 : 4). + const hn::Vec shl1_if_small = + hn::MaskedAddOr(encoded, is_small_e, encoded, encoded); + lo = hn::ShiftLeft<4>(shl1_if_small); + // Lower 4 bits always zero. + HWY_DASSERT(hn::AllTrue(d, hn::Eq(hn::And(lo, Set(d, 15u)), hn::Zero(d)))); + + // The upper byte of bf16 is e_bias + (encoded >> (is_small_e ? 3 : 4)). + const hn::Vec shr_3_or_4 = hn::ShiftRight<4>(shl1_if_small); + // .. except when encoded=0: hi = 0, and lo is already 0. + const hn::Vec e7 = hn::MaskedAddOr(k0, is_nonzero, e_bias, shr_3_or_4); + HWY_DASSERT(hn::AllTrue(d, hn::Lt(e7, Set(d, 64u)))); // <= 0x3F + // .. also insert the sign bit. + hi = hn::BitwiseIfThenElse(k80, sign_in_msb, e7); } // Encodes `num` bf16 values from `in_bf` to `out_packed`. Their magnitude diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index 150df56..68f3e14 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -101,6 +101,29 @@ void TestAllUnique() { } } +// For deriving the new shift-based decoder, which is 3 ops faster than the +// previous "assemble from binary32 bits" method. +void TestAllFastDecode() { + for (size_t sfp = 0; sfp < 128; ++sfp) { + const float f = F32FromSFP8(sfp); + const uint32_t u = hwy::BitCastScalar(f); + const uint32_t lo = (u >> 16) & 0xFF; + const uint32_t hi = u >> 24; + const bool is_small = sfp < 0x40; + const uint32_t base = is_small ? 0x34 : 0x38; + const uint32_t fast_lo = (sfp << (is_small ? 5 : 4)) & 0xFF; + uint32_t fast_hi = base + (sfp >> (is_small ? 3 : 4)); + if (sfp == 0) fast_hi = 0; + + // fprintf(stderr, "sfp %2zx -> %6.3E %x %x\n", sfp, f, lo, hi); + if (fast_lo != lo || fast_hi != hi) { + HWY_ABORT( + "mismatch sfp %2zx -> %6.3E lo %2x fastLo %2x hi %2x fastHi %2x\n", + sfp, f, lo, fast_lo, hi, fast_hi); + } + } +} + // ------------------------------ Foreach compressed representation // Encode @@ -550,6 +573,7 @@ namespace gcpp { HWY_BEFORE_TEST(SfpTest); HWY_EXPORT_AND_TEST_P(SfpTest, PrintTables); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique); +HWY_EXPORT_AND_TEST_P(SfpTest, TestAllFastDecode); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden); HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec);