mirror of https://github.com/google/gemma.cpp.git
SFP speedup: 1.14x f32, 1.19x bf16 dot = 1.02x prefill
12->9 ops by recognizing the upper/lower bytes are simply shifted. PiperOrigin-RevId: 659609241
This commit is contained in:
parent
1982a6ba00
commit
1617e1a33d
|
|
@ -218,7 +218,7 @@ class SfpCodec {
|
|||
#define SFP_IF_GENERIC_DEC(D) void* yes = nullptr
|
||||
#endif
|
||||
|
||||
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops.
|
||||
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 9 ops.
|
||||
template <class D, HWY_IF_U8_D(D), SFP_IF_GENERIC_DEC(D)>
|
||||
static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
|
||||
hn::Vec<D>& hi) {
|
||||
|
|
@ -233,31 +233,27 @@ class SfpCodec {
|
|||
// Special-case zero, negated so we can use MaskedAddOr. Signed comparison
|
||||
// is fine because we have cleared the sign bit.
|
||||
const hn::Mask<D> is_nonzero = SignedGt(d, encoded, k0);
|
||||
// If MSB is clear, we have two mantissa bits, otherwise three.
|
||||
// If bit 6 is clear, we have two mantissa bits, otherwise three.
|
||||
const hn::Mask<D> is_small_e = SignedLt(d, encoded, hn::Set(d, 64));
|
||||
// If is_small_e, add/left-shift 0xxxx.mm to 0xxxx.mm0; else keep 1xxx.mmm.
|
||||
const hn::Vec<D> e4m3 =
|
||||
hn::MaskedAddOr(encoded, is_small_e, encoded, encoded);
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e4m3, k80)));
|
||||
const hn::Vec<D> e = hn::ShiftRight<3>(e4m3); // 4-bit exponent only
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e, Set(d, 16u))));
|
||||
// The encoded exponent for 2^0 is 15, so subtract 15. Add 127 for the
|
||||
// binary32/bf16 bias. Subtract another 8 if is_small_e because its lowest
|
||||
// encoded value (0) should be less than the lowest 'large' exponent 2^-7.
|
||||
const hn::Vec<D> e_bias = hn::IfThenElse(
|
||||
is_small_e, hn::Set(d, 127u - 15u - 8u), hn::Set(d, 127u - 15u));
|
||||
// Special-case zero or add e_bias. If encoded=0, e and e4m3 are zero, but
|
||||
// we must zero e_bias to get the desired all-zero bf16.
|
||||
const hn::Vec<D> biased_e = hn::MaskedAddOr(k0, is_nonzero, e_bias, e);
|
||||
// The decoded binary32 exponent should be at most 2^0.
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80)));
|
||||
// For encoded in [1, 8), hi = 0x34; encoded = 0x40 => hi = 0x3C including
|
||||
// (encoded >> 4) == 4, so add 0x38.
|
||||
const hn::Vec<D> e_bias =
|
||||
hn::IfThenElse(is_small_e, hn::Set(d, 0x34), hn::Set(d, 0x38));
|
||||
|
||||
// Shift the MSB of e4m3's mantissa into the MSB of the bf16 mantissa.
|
||||
const hn::Vec<D> m7 = hn::ShiftLeft<4>(e4m3);
|
||||
// Lower byte of bf16 = exponent LSB || mantissa.
|
||||
lo = hn::BitwiseIfThenElse(k80, hn::ShiftLeft<7>(biased_e), m7);
|
||||
// Upper byte of bf16 = sign || lower 7 bits of exponent.
|
||||
hi = hn::BitwiseIfThenElse(k80, sign_in_msb, hn::ShiftRight<1>(biased_e));
|
||||
// The low byte of bf16 is encoded << (is_small_e ? 5 : 4).
|
||||
const hn::Vec<D> shl1_if_small =
|
||||
hn::MaskedAddOr(encoded, is_small_e, encoded, encoded);
|
||||
lo = hn::ShiftLeft<4>(shl1_if_small);
|
||||
// Lower 4 bits always zero.
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Eq(hn::And(lo, Set(d, 15u)), hn::Zero(d))));
|
||||
|
||||
// The upper byte of bf16 is e_bias + (encoded >> (is_small_e ? 3 : 4)).
|
||||
const hn::Vec<D> shr_3_or_4 = hn::ShiftRight<4>(shl1_if_small);
|
||||
// .. except when encoded=0: hi = 0, and lo is already 0.
|
||||
const hn::Vec<D> e7 = hn::MaskedAddOr(k0, is_nonzero, e_bias, shr_3_or_4);
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Lt(e7, Set(d, 64u)))); // <= 0x3F
|
||||
// .. also insert the sign bit.
|
||||
hi = hn::BitwiseIfThenElse(k80, sign_in_msb, e7);
|
||||
}
|
||||
|
||||
// Encodes `num` bf16 values from `in_bf` to `out_packed`. Their magnitude
|
||||
|
|
|
|||
|
|
@ -101,6 +101,29 @@ void TestAllUnique() {
|
|||
}
|
||||
}
|
||||
|
||||
// For deriving the new shift-based decoder, which is 3 ops faster than the
|
||||
// previous "assemble from binary32 bits" method.
|
||||
void TestAllFastDecode() {
|
||||
for (size_t sfp = 0; sfp < 128; ++sfp) {
|
||||
const float f = F32FromSFP8(sfp);
|
||||
const uint32_t u = hwy::BitCastScalar<uint32_t>(f);
|
||||
const uint32_t lo = (u >> 16) & 0xFF;
|
||||
const uint32_t hi = u >> 24;
|
||||
const bool is_small = sfp < 0x40;
|
||||
const uint32_t base = is_small ? 0x34 : 0x38;
|
||||
const uint32_t fast_lo = (sfp << (is_small ? 5 : 4)) & 0xFF;
|
||||
uint32_t fast_hi = base + (sfp >> (is_small ? 3 : 4));
|
||||
if (sfp == 0) fast_hi = 0;
|
||||
|
||||
// fprintf(stderr, "sfp %2zx -> %6.3E %x %x\n", sfp, f, lo, hi);
|
||||
if (fast_lo != lo || fast_hi != hi) {
|
||||
HWY_ABORT(
|
||||
"mismatch sfp %2zx -> %6.3E lo %2x fastLo %2x hi %2x fastHi %2x\n",
|
||||
sfp, f, lo, fast_lo, hi, fast_hi);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ------------------------------ Foreach compressed representation
|
||||
|
||||
// Encode
|
||||
|
|
@ -550,6 +573,7 @@ namespace gcpp {
|
|||
HWY_BEFORE_TEST(SfpTest);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, PrintTables);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllFastDecode);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllEncDec);
|
||||
|
|
|
|||
Loading…
Reference in New Issue