2x speedup of SFP decode (1.4x overall) on AVX3_DL+.

Thanks @nzmichaelh for suggesting table lookups!

PiperOrigin-RevId: 631337524
This commit is contained in:
Jan Wassenberg 2024-05-07 01:46:07 -07:00 committed by Copybara-Service
parent 18f6d43fcc
commit b5a9ade75f
2 changed files with 80 additions and 2 deletions

View File

@ -33,6 +33,7 @@
#define THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#endif
#include "hwy/detect_targets.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
@ -157,9 +158,65 @@ class SfpCodec {
return hn::IfThenZeroElse(is_zero, encoded);
}
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 3 ops (AVX-512).
#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE
template <class D, HWY_IF_U8_D(D), HWY_IF_V_SIZE_D(D, 64)>
static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
                                hn::Vec<D>& hi) {
  // Table-driven decoder for 64-byte vectors: each encoded byte selects one
  // of 128 precomputed entries for the lower and for the upper bf16 byte;
  // the sign bit (0x80) is then copied from `encoded` into `hi`.
  const hn::Vec<D> k80 = hn::Set(d, 0x80u);
  // -0 is reserved
  HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80)));

  // Each 128-entry table is stored as two 64-byte halves, because a single
  // vector holds 64 table bytes. Rows below hold eight entries each.

  // Lower bf16 byte, entries 0..63: 0x00..0xE0 in steps of 0x20, repeating.
  alignas(64) static constexpr uint8_t kLoBytes0[64] = {
      0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0,
      0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0,
      0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0,
      0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0,
      0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0,
      0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0,
      0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0,
      0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0};
  // Lower bf16 byte, entries 64..127: 0x00..0xF0 in steps of 0x10, repeating.
  alignas(64) static constexpr uint8_t kLoBytes1[64] = {
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70,
      0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70,
      0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70,
      0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0,
      0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70,
      0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
  // Upper bf16 byte (sign cleared), entries 0..63. Entry 0 is 0x00, so code 0
  // yields an all-zero bf16; the remaining entries run from 0x34 to 0x3B.
  alignas(64) static constexpr uint8_t kHiBytes0[64] = {
      0x00, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34,
      0x35, 0x35, 0x35, 0x35, 0x35, 0x35, 0x35, 0x35,
      0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
      0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37,
      0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x38,
      0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39, 0x39,
      0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A,
      0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B};
  // Upper bf16 byte (sign cleared), entries 64..127: 0x3C..0x3F, 16 of each.
  alignas(64) static constexpr uint8_t kHiBytes1[64] = {
      0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C,
      0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C,
      0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D,
      0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D,
      0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E,
      0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E,
      0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F,
      0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F};

  const auto hi_first_half = hn::LoadU(d, kHiBytes0);
  const auto hi_second_half = hn::LoadU(d, kHiBytes1);
  const auto lo_first_half = hn::LoadU(d, kLoBytes0);
  const auto lo_second_half = hn::LoadU(d, kLoBytes1);

  // The encoded bytes serve directly as table indices. AVX-512 ignores the
  // index MSB (here: the sign bit), so it does not need to be masked off.
  const hn::Indices512<uint8_t> indices{encoded.raw};
  hi = hn::TwoTablesLookupLanes(d, hi_first_half, hi_second_half, indices);
  lo = hn::TwoTablesLookupLanes(d, lo_first_half, lo_second_half, indices);
  // Insert sign bit: hi |= (encoded & 0x80).
  hi = hn::OrAnd(hi, encoded, k80);
}
// Generic is only required for partial vectors (too small for tables).
// The table-based DecBytes above is restricted to 64-byte vectors, so the
// generic overload is limited to vectors of at most 32 bytes here; presumably
// this keeps the two overloads from both matching a 64-byte vector.
#undef SFP_IF_GENERIC_DEC
#define SFP_IF_GENERIC_DEC(D) HWY_IF_V_SIZE_LE_D(D, 32)
#else
// Always enable the generic decoder.
// Expands to an extra defaulted template parameter that places no constraint
// on D, so the generic DecBytes template keeps the same parameter count in
// both configurations.
#undef SFP_IF_GENERIC_DEC
#define SFP_IF_GENERIC_DEC(D) void* yes = nullptr
#endif
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops.
// Implementation detail, public because called by test.
template <class D, HWY_IF_U8_D(D)>
template <class D, HWY_IF_U8_D(D), SFP_IF_GENERIC_DEC(D)>
static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
hn::Vec<D>& hi) {
const hn::Vec<D> k0 = hn::Zero(d);

View File

@ -67,6 +67,26 @@ float F32FromSFP8(uint32_t sfp) {
return result;
}
// Prints, to stderr, bits 16..23 ("lo") and bits 24..31 ("hi") of the f32
// returned by F32FromSFP8 for each code in [0, 128), as comma-separated hex
// bytes. Used for HWY_AVX3_DL and newer; presumably the output is what gets
// pasted into the decoder's lookup tables. The body is compiled but disabled
// by the `&& false` below.
void PrintTables() {
  if (HWY_ONCE && false) {
    constexpr uint32_t kNumCodes = 128;
    // Upper bytes are buffered so they can be printed after all lower bytes.
    uint8_t upper_bytes[kNumCodes];

    fprintf(stderr, "lo\n");
    for (uint32_t code = 0; code < kNumCodes; ++code) {
      const float decoded = F32FromSFP8(code);
      const uint32_t u = hwy::BitCastScalar<uint32_t>(decoded);
      // Lower bits are zero, hence we can truncate instead of rounding to bf16.
      HWY_ASSERT((u & 0xFFFF) == 0);
      fprintf(stderr, "0x%02X,", (u >> 16) & 0xFF);
      upper_bytes[code] = u >> 24;
    }

    fprintf(stderr, "\nhi\n");
    for (const uint8_t upper : upper_bytes) {
      fprintf(stderr, "0x%02X,", upper);
    }
    fprintf(stderr, "\n");
  }
}
void TestAllUnique() {
std::set<float> unique;
for (uint32_t sfp = 0; sfp < 256; ++sfp) {
@ -497,6 +517,7 @@ HWY_AFTER_NAMESPACE();
namespace gcpp {
HWY_BEFORE_TEST(SfpTest);
HWY_EXPORT_AND_TEST_P(SfpTest, PrintTables);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc);
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden);