Enable even/odd for SFP. Refs #166

Keep even/odd disabled for SFP matrices with float32 vectors because the benefit does not recoup the cost of ToEvenOddF32.

PiperOrigin-RevId: 631788326
This commit is contained in:
Jan Wassenberg 2024-05-08 07:08:25 -07:00 committed by Copybara-Service
parent bacba351d4
commit c5c9fc300c
4 changed files with 183 additions and 26 deletions

View File

@ -58,7 +58,7 @@ struct CompressTraits {};
template <>
struct CompressTraits<float> {
using MatT = float;
static constexpr bool kSupportsEvenOdd = false;
static constexpr bool kSupportsEvenOdd = false; // unnecessary
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
@ -230,7 +230,8 @@ struct CompressTraits<hwy::bfloat16_t> {
static HWY_INLINE float DotEO(
const DF df32, const hwy::bfloat16_t* HWY_RESTRICT in, size_t in_ofs,
const float* HWY_RESTRICT vec_aligned, size_t num) {
HWY_DASSERT(num >= (hn::Lanes(df32) * 2) && (num % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT(num >= (hn::Lanes(df32) * 2) &&
(num % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT((in_ofs % (hn::Lanes(df32) * 2)) == 0);
HWY_DASSERT(hn::IsAligned(df32, vec_aligned));
@ -273,13 +274,13 @@ struct CompressTraits<hwy::bfloat16_t> {
template <>
struct CompressTraits<SfpStream> {
using MatT = SfpStream;
static constexpr bool kSupportsEvenOdd = false;
static constexpr bool kSupportsEvenOdd = true;
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Compress(DF df, const float* in, size_t num,
CompressPerThread& tls,
size_t /*out_capacity*/, MatT* out,
size_t out_ofs) {
static HWY_INLINE void Compress(DF df, const float* HWY_RESTRICT in,
size_t num, CompressPerThread& tls,
size_t /*out_capacity*/,
MatT* HWY_RESTRICT out, size_t out_ofs) {
SfpCodec::Enc(df, in, num, out + out_ofs);
if (COMPRESS_STATS) {
@ -295,15 +296,21 @@ struct CompressTraits<SfpStream> {
}
template <class D, typename OutT>
static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/, const MatT* in,
size_t in_ofs, OutT* out, size_t num) {
static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
OutT* HWY_RESTRICT out, size_t num) {
SfpCodec::Dec(d, in + in_ofs, num, out);
}
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/, const MatT* in,
size_t in_ofs, const VecT* vec_aligned,
static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
const VecT* HWY_RESTRICT vec_aligned,
size_t num) {
HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
HWY_DASSERT((in_ofs % hn::Lanes(df)) == 0);
HWY_DASSERT(hn::IsAligned(df, vec_aligned));
using VF = hn::Vec<decltype(df)>;
VF sum0 = hn::Zero(df);
VF sum1 = hn::Zero(df);
@ -318,6 +325,34 @@ struct CompressTraits<SfpStream> {
sum0 = hn::Add(sum0, sum2);
return hn::ReduceSum(df, sum0);
}
// Returns the dot product of `num` SFP-encoded elements of the column-major
// matrix `in`, starting at element `in_ofs`, with the f32 or bf16 vector
// `vec_aligned`. `vec_aligned` must be aligned and stored in even/odd
// deinterleaved order: one vector's worth of even-indexed elements followed by
// one vector's worth of odd-indexed elements, repeated. Decoding and
// multiply-accumulate are delegated to `SfpCodec::DotEO`, which additionally
// requires `num` to cover whole SFP vectors.
template <class DF, typename VecT, HWY_IF_F32_D(DF)>
static HWY_INLINE float DotEO(const DF df, const MatT* HWY_RESTRICT in,
                              size_t in_ofs,
                              const VecT* HWY_RESTRICT vec_aligned,
                              size_t num) {
  HWY_DASSERT(num >= (hn::Lanes(df) * 2) && (num % (hn::Lanes(df) * 2)) == 0);
  HWY_DASSERT((in_ofs % (hn::Lanes(df) * 2)) == 0);
  HWY_DASSERT(hn::IsAligned(df, vec_aligned));

  // Four independent f32 accumulators updated by the fused decode + multiply.
  hn::Vec<DF> acc0 = hn::Zero(df);
  hn::Vec<DF> acc1 = hn::Zero(df);
  hn::Vec<DF> acc2 = hn::Zero(df);
  hn::Vec<DF> acc3 = hn::Zero(df);

  SfpCodec::DotEO(df, in + in_ofs, num, vec_aligned, acc0, acc1, acc2, acc3);

  // Pairwise reduction of the accumulators, followed by a horizontal sum of
  // the lanes of the result.
  const hn::Vec<DF> total =
      hn::Add(hn::Add(acc0, acc1), hn::Add(acc2, acc3));
  return hn::ReduceSum(df, total);
}
};
template <>

View File

@ -475,6 +475,61 @@ class SfpCodec {
}
}
// Fused SFP decode and dot product with a bf16 vector in even/odd order.
// Each iteration consumes one full vector of packed SFP bytes (2 * N16
// elements) plus two bf16 vectors from `vec_aligned`: first the even-indexed
// elements, then the odd-indexed ones. Products are accumulated into the four
// caller-provided f32 accumulators; `num` must be a multiple of 2 * N16.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
                             size_t num,
                             const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                             hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                             hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
  const hn::Repartition<hwy::bfloat16_t, DF> dbf;
  const hn::Repartition<uint8_t, DF> d8;
  const size_t N16 = hn::Lanes(dbf);
  HWY_DASSERT(num % (2 * N16) == 0);  // whole SFP vector -> 2x bf16

  const uint8_t* const bytes = &in_packed->byte;
  const size_t stride = 2 * N16;  // elements consumed per iteration

  HWY_UNROLL(1)
  for (size_t ofs = 0; ofs < num; ofs += stride) {
    const hn::Vec<decltype(d8)> sfp_bytes = hn::LoadU(d8, bytes + ofs);
    const hn::Vec<decltype(dbf)> vec_even = hn::LoadU(dbf, vec_aligned + ofs);
    const hn::Vec<decltype(dbf)> vec_odd =
        hn::LoadU(dbf, vec_aligned + ofs + N16);

    // Decode straight into even/odd order so it lines up with the vector.
    hn::Vec<decltype(dbf)> mat_even, mat_odd;
    DecEvenOdd(dbf, sfp_bytes, mat_even, mat_odd);

    // Each call updates a pair of accumulators (the second one by reference).
    sum0 = hn::ReorderWidenMulAccumulate(df, mat_even, vec_even, sum0, sum1);
    sum2 = hn::ReorderWidenMulAccumulate(df, mat_odd, vec_odd, sum2, sum3);
  }
}
// Fused SFP decode and dot product with an f32 vector in even/odd order.
// Each iteration consumes one full vector of packed SFP bytes (4 * NF
// elements) plus four f32 vectors from `vec_aligned`, laid out as
// even, odd, even, odd. Products are accumulated into the four
// caller-provided f32 accumulators; `num` must be a multiple of 4 * NF.
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
                             size_t num,
                             const float* HWY_RESTRICT vec_aligned,
                             hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                             hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
  const hn::Repartition<uint8_t, DF> d8;
  const size_t NF = hn::Lanes(df);
  HWY_DASSERT(num % (4 * NF) == 0);  // whole SFP vector -> 4x f32

  const uint8_t* const bytes = &in_packed->byte;
  const size_t stride = 4 * NF;  // elements consumed per iteration

  HWY_UNROLL(1)
  for (size_t ofs = 0; ofs < num; ofs += stride) {
    const hn::Vec<decltype(d8)> sfp_bytes = hn::LoadU(d8, bytes + ofs);

    // Vector operands: (even, odd) for the first half of this group of
    // elements, then (even, odd) for the second half.
    const float* HWY_RESTRICT v = vec_aligned + ofs;
    const hn::Vec<DF> vec_even0 = hn::LoadU(df, v);
    const hn::Vec<DF> vec_odd0 = hn::LoadU(df, v + NF);
    const hn::Vec<DF> vec_even1 = hn::LoadU(df, v + 2 * NF);
    const hn::Vec<DF> vec_odd1 = hn::LoadU(df, v + 3 * NF);

    hn::Vec<DF> mat_even0, mat_odd0, mat_even1, mat_odd1;
    DecEvenOddF(df, sfp_bytes, mat_even0, mat_odd0, mat_even1, mat_odd1);

    // One accumulator per decoded vector.
    sum0 = hn::MulAdd(mat_even0, vec_even0, sum0);
    sum1 = hn::MulAdd(mat_odd0, vec_odd0, sum1);
    sum2 = hn::MulAdd(mat_even1, vec_even1, sum2);
    sum3 = hn::MulAdd(mat_odd1, vec_odd1, sum3);
  }
}
private:
// Wrappers to avoid code duplication across float/bf16 input types and
// the main loop/remainder.
@ -574,6 +629,37 @@ class SfpCodec {
f2 = hn::PromoteLowerTo(df, bf1);
f3 = hn::PromoteUpperTo(df, bf1);
}
// Decodes one vector of packed SFP bytes into two bf16 vectors: `even`
// receives the elements at even indices and `odd` those at odd indices.
// `DecBytes` yields per-element byte vectors `lo` and `hi` (presumably the
// low and high byte of each bf16 value - confirm against DecBytes); the bf16
// lanes are formed by pairing lo[i] with hi[i] for the even resp. odd i.
template <class DBF, HWY_IF_BF16_D(DBF),
          class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
static HWY_INLINE void DecEvenOdd(DBF dbf, V8 packed, hn::Vec<DBF>& even,
                                  hn::Vec<DBF>& odd) {
  const hn::Repartition<uint8_t, DBF> d8;
  V8 lo, hi;
  DecBytes(d8, packed, lo, hi);
  // Use InterleaveEven/InterleaveOdd only on Highway >= 1.2. The minor version
  // must be compared only when the major version is exactly 1; otherwise 0.x
  // releases with a large minor number would wrongly take this path.
#if HWY_MAJOR > 1 || (HWY_MAJOR == 1 && HWY_MINOR >= 2)
  even = hn::BitCast(dbf, hn::InterleaveEven(d8, lo, hi));
  odd = hn::BitCast(dbf, hn::InterleaveOdd(d8, lo, hi));
#else
  // Fallback: OddEven(a, b) takes odd lanes from a and even lanes from b.
  // Duplicating the wanted bytes into the neighbouring lane first gives the
  // same lo/hi pairing as the interleave ops above.
  even = hn::BitCast(dbf, hn::OddEven(hn::DupEven(hi), lo));
  odd = hn::BitCast(dbf, hn::OddEven(hi, hn::DupOdd(lo)));
#endif
}
// Decodes one vector of packed SFP bytes into four f32 vectors in even/odd
// order. The bytes are first decoded to an even-indexed and an odd-indexed
// bf16 vector; the lower halves are then promoted to `even0`/`odd0` and the
// upper halves to `even1`/`odd1`.
template <class DF, HWY_IF_F32_D(DF),
          class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
static HWY_INLINE void DecEvenOddF(DF df, V8 packed, hn::Vec<DF>& even0,
                                   hn::Vec<DF>& odd0, hn::Vec<DF>& even1,
                                   hn::Vec<DF>& odd1) {
  const hn::Repartition<hwy::bfloat16_t, DF> dbf;
  hn::Vec<decltype(dbf)> bf_even, bf_odd;
  DecEvenOdd(dbf, packed, bf_even, bf_odd);

  // Even-indexed elements: lower half, then upper half.
  even0 = hn::PromoteLowerTo(df, bf_even);
  even1 = hn::PromoteUpperTo(df, bf_even);
  // Odd-indexed elements: lower half, then upper half.
  odd0 = hn::PromoteLowerTo(df, bf_odd);
  odd1 = hn::PromoteUpperTo(df, bf_odd);
}
}; // SfpCodec
// NOLINTNEXTLINE(google-readability-namespace-comments)

View File

@ -401,11 +401,13 @@ struct TestDot {
HWY_INLINE void operator()(T /*unused*/, D d) {
const hn::Repartition<float, D> df;
const size_t num = 1024; // not too many for GeometricMean overflow.
const size_t N = hn::Lanes(d);
auto in = hwy::AllocateAligned<T>(num);
auto dec = hwy::AllocateAligned<T>(num);
auto vec = hwy::AllocateAligned<T>(num);
auto vec_eo = hwy::AllocateAligned<T>(num);
auto sfp = hwy::AllocateAligned<SfpStream>(num);
HWY_ASSERT(in && dec && vec && sfp);
HWY_ASSERT(in && dec && vec && vec_eo && sfp);
// Generate inputs and verify their distribution.
hwy::RandomState rng;
@ -422,12 +424,23 @@ struct TestDot {
}
VerifyGaussian(in_stats);
// Convert vec to even/odd for DotEO
for (size_t i = 0; i < num; i += 2 * N) {
hn::Vec<D> ve, vo;
hn::LoadInterleaved2(d, vec.get() + i, ve, vo);
hn::Store(ve, d, vec_eo.get() + i + 0);
hn::Store(vo, d, vec_eo.get() + i + N);
}
SfpCodec::Enc(d, in.get(), num, sfp.get());
// Compute dot product without decompression.
float actual = 0.0f;
float actual_eo = 0.0f;
double elapsed = hwy::HighestValue<double>();
double elapsed_eo = hwy::HighestValue<double>();
for (size_t rep = 0; rep < 200; ++rep) {
{
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
@ -439,10 +452,26 @@ struct TestDot {
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
actual = hn::ReduceSum(df, sum0);
}
{
hn::Vec<decltype(df)> sum0 = hn::Zero(df);
hn::Vec<decltype(df)> sum1 = hn::Zero(df);
hn::Vec<decltype(df)> sum2 = hn::Zero(df);
hn::Vec<decltype(df)> sum3 = hn::Zero(df);
const double t0 = hwy::platform::Now();
SfpCodec::DotEO(df, sfp.get(), num, vec_eo.get(), sum0, sum1, sum2,
sum3);
const double t1 = hwy::platform::Now();
elapsed_eo = HWY_MIN(elapsed_eo, t1 - t0);
sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
actual_eo = hn::ReduceSum(df, sum0);
}
}
SfpCodec::Dec(d, sfp.get(), num, dec.get());
fprintf(stderr, "Vec %zu Dot %zu-bit %.2f MB/s\n", Lanes(d) * sizeof(T),
sizeof(T) * 8, num * sizeof(T) * 1E-6 / elapsed);
fprintf(stderr, "Vec %zu Dot %zu-bit %.2f ; %.2f MB/s\n",
Lanes(d) * sizeof(T), sizeof(T) * 8,
num * sizeof(T) * 1E-6 / elapsed,
num * sizeof(T) * 1E-6 / elapsed_eo);
// Exact and decompressed dot products for comparison.
float exact = 0.0f; // using original input
@ -479,6 +508,8 @@ struct TestDot {
HWY_ASSERT(gcpp::IsInside(0.87f, 1.0f, final_ratio));
// Decompressed and uncompressed dot should match exactly.
HWY_ASSERT(gcpp::IsNear(expected, actual, 1E-4f));
// Even/odd dot should also match
HWY_ASSERT(gcpp::IsNear(actual, actual_eo, 1E-4f));
// Geomean of ratios for each i should be very close to one.
HWY_ASSERT(dot_snr >= (isBF ? 70.0 : 1000.0));

View File

@ -25,6 +25,7 @@
#include <random>
#include <type_traits> // std::enable_if_t
#include "compression/sfp.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
@ -286,8 +287,12 @@ HWY_INLINE void MatVecT(const ArrayT& mat, const size_t mat_ofs,
PROFILER_ZONE("MatVecAdd");
#if !defined(HWY_NATIVE_DOT_BF16) || !HWY_NATIVE_DOT_BF16
if constexpr (CompressTraits<typename ArrayT::value_type>::kSupportsEvenOdd &&
hwy::IsSameEither<VecT, float, hwy::bfloat16_t>()) {
using MatT = typename ArrayT::value_type;
// Sfp -> float does not benefit enough to recoup the cost of ToEvenOddF32.
if constexpr (CompressTraits<MatT>::kSupportsEvenOdd &&
hwy::IsSameEither<VecT, float, hwy::bfloat16_t>() &&
!(hwy::IsSame<MatT, SfpStream>() &&
hwy::IsSame<VecT, float>())) {
ToEvenOddF32(vec_aligned, kInner, even_odd);
detail::MatVecAddInner</*kVecIsEvenOdd=*/true, kAdd, kOuter, kInner>(
mat, mat_ofs, even_odd, add, out, pool);