SFP speedup: 1.14x f32, 1.19x bf16 dot = 1.02x prefill

12->9 ops by recognizing the upper/lower bytes are simply shifted.

PiperOrigin-RevId: 659609241
This commit is contained in:
Jan Wassenberg 2024-08-05 10:58:34 -07:00 committed by Copybara-Service
parent 1982a6ba00
commit 1617e1a33d
2 changed files with 44 additions and 24 deletions

View File

@ -218,7 +218,7 @@ class SfpCodec {
#define SFP_IF_GENERIC_DEC(D) void* yes = nullptr
#endif
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops.
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 9 ops.
template <class D, HWY_IF_U8_D(D), SFP_IF_GENERIC_DEC(D)>
static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
hn::Vec<D>& hi) {
@ -233,31 +233,27 @@ class SfpCodec {
// Special-case zero, negated so we can use MaskedAddOr. Signed comparison
// is fine because we have cleared the sign bit.
const hn::Mask<D> is_nonzero = SignedGt(d, encoded, k0);
// If MSB is clear, we have two mantissa bits, otherwise three.
// If bit 6 is clear, we have two mantissa bits, otherwise three.
const hn::Mask<D> is_small_e = SignedLt(d, encoded, hn::Set(d, 64));
// If is_small_e, add/left-shift 0xxxx.mm to 0xxxx.mm0; else keep 1xxx.mmm.
const hn::Vec<D> e4m3 =
hn::MaskedAddOr(encoded, is_small_e, encoded, encoded);
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e4m3, k80)));
const hn::Vec<D> e = hn::ShiftRight<3>(e4m3); // 4-bit exponent only
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e, Set(d, 16u))));
// The encoded exponent for 2^0 is 15, so subtract 15. Add 127 for the
// binary32/bf16 bias. Subtract another 8 if is_small_e because its lowest
// encoded value (0) should be less than the lowest 'large' exponent 2^-7.
const hn::Vec<D> e_bias = hn::IfThenElse(
is_small_e, hn::Set(d, 127u - 15u - 8u), hn::Set(d, 127u - 15u));
// Special-case zero or add e_bias. If encoded=0, e and e4m3 are zero, but
// we must zero e_bias to get the desired all-zero bf16.
const hn::Vec<D> biased_e = hn::MaskedAddOr(k0, is_nonzero, e_bias, e);
// The decoded binary32 exponent should be at most 2^0.
HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80)));
// For encoded in [1, 8), hi = 0x34; encoded = 0x40 => hi = 0x3C including
// (encoded >> 4) == 4, so add 0x38.
const hn::Vec<D> e_bias =
hn::IfThenElse(is_small_e, hn::Set(d, 0x34), hn::Set(d, 0x38));
// Shift the MSB of e4m3's mantissa into the MSB of the bf16 mantissa.
const hn::Vec<D> m7 = hn::ShiftLeft<4>(e4m3);
// Lower byte of bf16 = exponent LSB || mantissa.
lo = hn::BitwiseIfThenElse(k80, hn::ShiftLeft<7>(biased_e), m7);
// Upper byte of bf16 = sign || lower 7 bits of exponent.
hi = hn::BitwiseIfThenElse(k80, sign_in_msb, hn::ShiftRight<1>(biased_e));
// The low byte of bf16 is encoded << (is_small_e ? 5 : 4).
const hn::Vec<D> shl1_if_small =
hn::MaskedAddOr(encoded, is_small_e, encoded, encoded);
lo = hn::ShiftLeft<4>(shl1_if_small);
// Lower 4 bits always zero.
HWY_DASSERT(hn::AllTrue(d, hn::Eq(hn::And(lo, Set(d, 15u)), hn::Zero(d))));
// The upper byte of bf16 is e_bias + (encoded >> (is_small_e ? 3 : 4)).
const hn::Vec<D> shr_3_or_4 = hn::ShiftRight<4>(shl1_if_small);
// .. except when encoded=0: hi = 0, and lo is already 0.
const hn::Vec<D> e7 = hn::MaskedAddOr(k0, is_nonzero, e_bias, shr_3_or_4);
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e7, Set(d, 64u)))); // <= 0x3F
// .. also insert the sign bit.
hi = hn::BitwiseIfThenElse(k80, sign_in_msb, e7);
}
// Encodes `num` bf16 values from `in_bf` to `out_packed`. Their magnitude

View File

@ -101,6 +101,29 @@ void TestAllUnique() {
}
}
// For deriving the new shift-based decoder, which is 3 ops faster than the
// previous "assemble from binary32 bits" method.
void TestAllFastDecode() {
for (size_t sfp = 0; sfp < 128; ++sfp) {
const float f = F32FromSFP8(sfp);
const uint32_t u = hwy::BitCastScalar<uint32_t>(f);
const uint32_t lo = (u >> 16) & 0xFF;
const uint32_t hi = u >> 24;
const bool is_small = sfp < 0x40;
const uint32_t base = is_small ? 0x34 : 0x38;
const uint32_t fast_lo = (sfp << (is_small ? 5 : 4)) & 0xFF;
uint32_t fast_hi = base + (sfp >> (is_small ? 3 : 4));
if (sfp == 0) fast_hi = 0;
// fprintf(stderr, "sfp %2zx -> %6.3E %x %x\n", sfp, f, lo, hi);
if (fast_lo != lo || fast_hi != hi) {
HWY_ABORT(
"mismatch sfp %2zx -> %6.3E lo %2x fastLo %2x hi %2x fastHi %2x\n",
sfp, f, lo, fast_lo, hi, fast_hi);
}
}
}
// ------------------------------ Foreach compressed representation
// Encode
@ -550,6 +573,7 @@ namespace gcpp {
HWY_BEFORE_TEST(SfpTest);
HWY_EXPORT_AND_TEST_P(SfpTest, PrintTables);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllFastDecode);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec);