mirror of https://github.com/google/gemma.cpp.git
2x speedup of SFP decode (1.4x overall) on AVX3_DL+.
Thanks @nzmichaelh for suggesting table lookups! PiperOrigin-RevId: 631337524
This commit is contained in:
parent
18f6d43fcc
commit
b5a9ade75f
|
|
@ -33,6 +33,7 @@
|
|||
#define THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
|
||||
#endif
|
||||
|
||||
#include "hwy/detect_targets.h"
|
||||
#include "hwy/highway.h"
|
||||
|
||||
HWY_BEFORE_NAMESPACE();
|
||||
|
|
@ -157,9 +158,65 @@ class SfpCodec {
|
|||
return hn::IfThenZeroElse(is_zero, encoded);
|
||||
}
|
||||
|
||||
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 3 ops (AVX-512).
|
||||
#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE
|
||||
template <class D, HWY_IF_U8_D(D), HWY_IF_V_SIZE_D(D, 64)>
|
||||
static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
|
||||
hn::Vec<D>& hi) {
|
||||
const hn::Vec<D> k80 = hn::Set(d, 0x80u);
|
||||
HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80))); // -0 is reserved
|
||||
|
||||
// Two 2x64 table lookups for lo/hi.
|
||||
alignas(64) static constexpr uint8_t kTblL0[64] = {
|
||||
0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40,
|
||||
0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0,
|
||||
0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00,
|
||||
0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60,
|
||||
0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0,
|
||||
0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0};
|
||||
alignas(64) static constexpr uint8_t kTblL1[64] = {
|
||||
0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0,
|
||||
0xB0, 0xC0, 0xD0, 0xE0, 0xF0, 0x00, 0x10, 0x20, 0x30, 0x40, 0x50,
|
||||
0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0, 0x00,
|
||||
0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0,
|
||||
0xC0, 0xD0, 0xE0, 0xF0, 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60,
|
||||
0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
|
||||
alignas(64) static constexpr uint8_t kTblH0[64] = {
|
||||
0x00, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x35, 0x35, 0x35,
|
||||
0x35, 0x35, 0x35, 0x35, 0x35, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
|
||||
0x36, 0x36, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x38,
|
||||
0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x39, 0x39, 0x39, 0x39,
|
||||
0x39, 0x39, 0x39, 0x39, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A,
|
||||
0x3A, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B};
|
||||
alignas(64) static constexpr uint8_t kTblH1[64] = {
|
||||
0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C,
|
||||
0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D,
|
||||
0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3E,
|
||||
0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E,
|
||||
0x3E, 0x3E, 0x3E, 0x3E, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F,
|
||||
0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F};
|
||||
const hn::Vec<D> tblL0 = hn::LoadU(d, kTblL0);
|
||||
const hn::Vec<D> tblL1 = hn::LoadU(d, kTblL1);
|
||||
const hn::Vec<D> tblH0 = hn::LoadU(d, kTblH0);
|
||||
const hn::Vec<D> tblH1 = hn::LoadU(d, kTblH1);
|
||||
// AVX-512 ignores the index MSB, no need to clear.
|
||||
const hn::Indices512<uint8_t> idx{encoded.raw};
|
||||
hi = hn::TwoTablesLookupLanes(d, tblH0, tblH1, idx);
|
||||
lo = hn::TwoTablesLookupLanes(d, tblL0, tblL1, idx);
|
||||
hi = hn::OrAnd(hi, encoded, k80); // Insert sign bit
|
||||
}
|
||||
|
||||
// Generic is only required for partial vectors (too small for tables).
// SFP_IF_GENERIC_DEC(D) is used as an extra template parameter of the generic
// DecBytes. In this branch the 64-byte overload above already covers full
// vectors, so the generic overload is restricted to vectors of at most 32
// bytes.
|
||||
// NOTE(review): the #undef presumably guards against redefinition when this
// header is compiled once per target - confirm against the include guard.
#undef SFP_IF_GENERIC_DEC
|
||||
#define SFP_IF_GENERIC_DEC(D) HWY_IF_V_SIZE_LE_D(D, 32)
|
||||
#else
|
||||
// Always enable the generic decoder.
// There is no table-lookup overload in this branch, so the macro expands to a
// defaulted non-type template parameter that places no constraint on D.
|
||||
#undef SFP_IF_GENERIC_DEC
|
||||
#define SFP_IF_GENERIC_DEC(D) void* yes = nullptr
|
||||
#endif
|
||||
|
||||
// Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops.
|
||||
// Implementation detail, public because called by test.
|
||||
template <class D, HWY_IF_U8_D(D)>
|
||||
template <class D, HWY_IF_U8_D(D), SFP_IF_GENERIC_DEC(D)>
|
||||
static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
|
||||
hn::Vec<D>& hi) {
|
||||
const hn::Vec<D> k0 = hn::Zero(d);
|
||||
|
|
|
|||
|
|
@ -67,6 +67,26 @@ float F32FromSFP8(uint32_t sfp) {
|
|||
return result;
|
||||
}
|
||||
|
||||
// Used for HWY_AVX3_DL and newer.
|
||||
void PrintTables() {
|
||||
if (HWY_ONCE && false) {
|
||||
uint8_t hi[128];
|
||||
fprintf(stderr, "lo\n");
|
||||
for (uint32_t sfp = 0; sfp < 128; ++sfp) {
|
||||
const uint32_t u = hwy::BitCastScalar<uint32_t>(F32FromSFP8(sfp));
|
||||
// Lower bits are zero, hence we can truncate instead of rounding to bf16.
|
||||
HWY_ASSERT((u & 0xFFFF) == 0);
|
||||
fprintf(stderr, "0x%02X,", (u >> 16) & 0xFF);
|
||||
hi[sfp] = u >> 24;
|
||||
}
|
||||
fprintf(stderr, "\nhi\n");
|
||||
for (uint32_t sfp = 0; sfp < 128; ++sfp) {
|
||||
fprintf(stderr, "0x%02X,", hi[sfp]);
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
}
|
||||
|
||||
void TestAllUnique() {
|
||||
std::set<float> unique;
|
||||
for (uint32_t sfp = 0; sfp < 256; ++sfp) {
|
||||
|
|
@ -497,6 +517,7 @@ HWY_AFTER_NAMESPACE();
|
|||
|
||||
namespace gcpp {
|
||||
HWY_BEFORE_TEST(SfpTest);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, PrintTables);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllUnique);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllDecEnc);
|
||||
HWY_EXPORT_AND_TEST_P(SfpTest, TestAllGolden);
|
||||
|
|
|
|||
Loading…
Reference in New Issue