gemma.cpp/compression/sfp-inl.h

684 lines
29 KiB
C++

// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Normal include guard to placate lint.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
#include <stddef.h>
#include <stdint.h>
#include "compression/sfp.h"
#include "hwy/base.h"
#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#endif
#include "hwy/detect_targets.h"
#include "hwy/highway.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;
// Returns the mask of lanes where `a > b`, comparing lanes as signed integers.
// Intended for unsigned lanes whose MSB is zero: for those, the signed result
// equals the unsigned one, and signed comparison is faster on x86.
template <class DU>
HWY_INLINE hn::Mask<DU> SignedGt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
  const hn::RebindToSigned<DU> d_signed;
  const hn::Vec<decltype(d_signed)> a_signed = hn::BitCast(d_signed, a);
  const hn::Vec<decltype(d_signed)> b_signed = hn::BitCast(d_signed, b);
  // Convert the signed-domain mask back to the caller's unsigned descriptor.
  return hn::RebindMask(du, hn::Gt(a_signed, b_signed));
}
// Returns the mask of lanes where `a < b`, comparing lanes as signed integers.
// Intended for unsigned lanes whose MSB is zero, for which this matches the
// unsigned comparison while being cheaper on x86.
template <class DU>
HWY_INLINE hn::Mask<DU> SignedLt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
  const hn::RebindToSigned<DU> d_signed;
  return hn::RebindMask(
      du, hn::Lt(hn::BitCast(d_signed, a), hn::BitCast(d_signed, b)));
}
// Encode/decode functions.
// Encode/decode functions for the 8-bit SFP representation stored in
// `SfpStream`. Values are packed in order (one byte per value).
//
// Encoded byte layout, as produced by EncBytes and consumed by DecBytes:
//   bit 7      sign. 0x00 encodes zero; 0x80 (negative zero) is reserved and
//              is never produced by the encoder.
//   bits 6..0  if >= 0x40 ("large"): bits 6..3 are an exponent field in
//              [8, 15] representing 2^-7 .. 2^0, and bits 2..0 are three
//              mantissa bits.
//              Otherwise ("small"): bits 5..2 are an exponent field in
//              [0, 15] representing 2^-23 .. 2^-8, and bits 1..0 are two
//              mantissa bits.
// The largest encodable magnitude is 1.111b * 2^0 = 1.875, matching the input
// limit documented on the Enc overloads.
class SfpCodec {
 public:
  // Returns 8-bit packed representation of `lo` and `hi` bytes of bf16. 31 ops.
  // Implementation detail, public because called by test.
  //
  // Precondition: |x| < 2 (biased exponent below 128). Like every other
  // invariant in this class, it is only checked in debug builds.
  template <class D, HWY_IF_U8_D(D)>
  static HWY_INLINE hn::Vec<D> EncBytes(D d, const hn::Vec<D> lo,
                                        const hn::Vec<D> hi) {
    const hn::Vec<D> k1 = hn::Set(d, 1u);
    const hn::Vec<D> k80 = hn::Set(d, 0x80u);

    // Copy sign for later insertion.
    const hn::Vec<D> sign_in_msb = hi;
    // Biased exponent = lower 7 bits of hi and MSB of lo. Modified below.
    hn::Vec<D> biased_e = hn::Or(hn::Add(hi, hi), hn::ShiftRight<7>(lo));
    // Debug-only, consistent with the other checks here: an always-on
    // HWY_ASSERT would add a reduction and branch to the hot encode path.
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80)));  // <= 2^0

    // Clear MSB to isolate the mantissa and enable signed comparisons, then
    // shift right by *one* (plus 1 to undo the prior add/left-shift) to leave
    // headroom for overflow during rounding.
    const hn::Vec<D> m6 = hn::ShiftRight<2>(hn::Add(lo, lo));

    // The place to round depends on whether the exponent is large (>= -7) - if
    // so, we retain three mantissa bits, otherwise two. However, rounding can
    // also cause the exponent to increase. We first choose a threshold that
    // rounds up to 1.0*2^-7 for both two and three bit mantissas:
    // >= 1.1111 * 2^-8 (0.007568359375). This entails the exponent being
    // greater, or equal and the mantissa > (1111000 >> 1) - 1 = 0x3B.
    const hn::Vec<D> kMinLargeE = hn::Set(d, 127 - 8);
    const hn::Mask<D> is_large_before_round = hn::Or(
        SignedGt(d, biased_e, kMinLargeE),
        hn::And(hn::Eq(biased_e, kMinLargeE),
                SignedGt(d, m6, hn::Set(d, 0x3B))));

    // To retain the most-significant 3 or 2 mantissa bits, we will right-shift
    // by is_large_before_round ? 3 : 4. Variable Shr is expensive for 8-bit
    // elements, so (<< 1) if is_large_before_round, then always (>> 4).
    const hn::Vec<D> m_shl4 =
        hn::MaskedAddOr(m6, is_large_before_round, m6, m6);

    // Before shifting (truncation), round to nearest even to reduce bias. If
    // the lowest remaining mantissa bit is odd, increase the offset. Example
    // with the lowest remaining bit (left) and next lower two bits; the
    // latter, plus two more, will be truncated.
    // 0[00] + 1 = 0[01]
    // 0[01] + 1 = 0[10]
    // 0[10] + 1 = 0[11]  (round down toward even)
    // 0[11] + 1 = 1[00]  (round up)
    // 1[00] + 10 = 1[10]
    // 1[01] + 10 = 1[11]
    // 1[10] + 10 = C0[00]  (round up toward even with C=1 carry out)
    // 1[11] + 10 = C0[01]  (round up toward even with C=1 carry out)
    const hn::Vec<D> odd_bit = hn::And(hn::ShiftRight<4>(m_shl4), k1);
    const hn::Vec<D> rounded =
        hn::Add(m_shl4, hn::Add(odd_bit, hn::Set(d, 7)));
    // Update the exponent if rounding overflowed.
    const hn::Vec<D> carry_bit =
        hn::IfThenElse(is_large_before_round, k80, hn::Set(d, 0x40u));
    const hn::Vec<D> carry_clear = hn::AndNot(carry_bit, rounded);
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(carry_clear, carry_bit)));
    const hn::Mask<D> is_overflow = hn::Ne(carry_clear, rounded);
    biased_e = hn::MaskedAddOr(biased_e, is_overflow, biased_e, k1);
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, hn::Set(d, 128))));

    // Detect if zero or the min exponent.
    const hn::Vec<D> kMinNormal = hn::Set(d, 127 - 23);
    const hn::Mask<D> is_zero = SignedLt(d, biased_e, kMinNormal);
    const hn::Mask<D> is_min = hn::Eq(biased_e, kMinNormal);

    // 1.1110xxx * 2^-8 was considered small above, and thus rounded up to 2^-7,
    // which the decoder will consider large, and expect 3 mantissa bits. If we
    // set the threshold above to 1.111, then it does NOT round up. Thus we
    // check exponent >= -7 *after* rounding.
    const hn::Mask<D> is_large = SignedGt(d, biased_e, hn::Set(d, 127 - 8));

    // To extract and pack the mantissa, only is_large matters. Either it
    // matches is_large_before_round, or the rounding resulted in mantissa=0, so
    // we either extract two or three bits by shifting out the lower 5..6 bits.
    // is_large_before  is_large  rounded     want
    //       0             0      0Cmm????    mm
    //       0             1      0100????    000
    //       1             0      impossible  -
    //       1             1      Cmmm???0    mmm
    hn::Vec<D> m = hn::ShiftRight<4>(carry_clear);
    HWY_DASSERT(hn::AllTrue(
        d, SignedLt(d, m,
                    hn::IfThenElse(is_large, hn::Set(d, 8), hn::Set(d, 4)))));

    // 1.0 * 2^-23 has the same encoding as zero, so round it up to 1.01.
    m = hn::MaskedMaxOr(m, is_min, m, k1);

    const hn::Vec<D> e_bias = hn::IfThenElse(
        is_large,
        hn::Set(d, hwy::BitCastScalar<uint8_t>(static_cast<int8_t>(15 - 127))),
        hn::Set(d, hwy::BitCastScalar<uint8_t>(static_cast<int8_t>(23 - 127))));
    const hn::Vec<D> e = hn::Add(biased_e, e_bias);
    HWY_DASSERT(
        hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, e), hn::Set(d, 16))));

    // Shift exponent left 2 or 3 bits to make space for `m`.
    const hn::Vec<D> em =
        hn::Or(m, hn::ShiftLeft<2>(hn::MaskedAddOr(e, is_large, e, e)));
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, em), k80)));
    const hn::Vec<D> encoded = hn::BitwiseIfThenElse(k80, sign_in_msb, em);
    // Doing this last ensures -0 is replaced with 0.
    return hn::IfThenZeroElse(is_zero, encoded);
  }

  // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 3 ops (AVX-512).
#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE
  template <class D, HWY_IF_U8_D(D), HWY_IF_V_SIZE_D(D, 64)>
  static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
                                  hn::Vec<D>& hi) {
    const hn::Vec<D> k80 = hn::Set(d, 0x80u);
    HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80)));  // -0 is reserved

    // Two 2x64 table lookups for lo/hi, indexed by the 7 magnitude bits:
    // kTbl*0 covers magnitudes 0..63 and kTbl*1 covers 64..127. kTblL* hold
    // the low bf16 byte; kTblH* hold the high bf16 byte without its sign.
    alignas(64) static constexpr uint8_t kTblL0[64] = {
        0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40,
        0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0,
        0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00,
        0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60,
        0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0,
        0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0};
    alignas(64) static constexpr uint8_t kTblL1[64] = {
        0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0,
        0xB0, 0xC0, 0xD0, 0xE0, 0xF0, 0x00, 0x10, 0x20, 0x30, 0x40, 0x50,
        0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0, 0x00,
        0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0,
        0xC0, 0xD0, 0xE0, 0xF0, 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60,
        0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
    alignas(64) static constexpr uint8_t kTblH0[64] = {
        0x00, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x35, 0x35, 0x35,
        0x35, 0x35, 0x35, 0x35, 0x35, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
        0x36, 0x36, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x38,
        0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x39, 0x39, 0x39, 0x39,
        0x39, 0x39, 0x39, 0x39, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A,
        0x3A, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B};
    alignas(64) static constexpr uint8_t kTblH1[64] = {
        0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C,
        0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D,
        0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3E,
        0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E,
        0x3E, 0x3E, 0x3E, 0x3E, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F,
        0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F};
    const hn::Vec<D> tblL0 = hn::LoadU(d, kTblL0);
    const hn::Vec<D> tblL1 = hn::LoadU(d, kTblL1);
    const hn::Vec<D> tblH0 = hn::LoadU(d, kTblH0);
    const hn::Vec<D> tblH1 = hn::LoadU(d, kTblH1);
#if HWY_IDE  // only let the IDE see portable code.
    // IndicesFromVec requires the descriptor; the sign bit is cleared so the
    // portable index stays within the 128 entries of the two tables.
    const auto idx = hn::IndicesFromVec(d, hn::AndNot(k80, encoded));
#else  // AVX-512-specific: index MSB is ignored, no need to clear.
    const hn::Indices512<uint8_t> idx{encoded.raw};
#endif
    hi = hn::TwoTablesLookupLanes(d, tblH0, tblH1, idx);
    lo = hn::TwoTablesLookupLanes(d, tblL0, tblL1, idx);
    hi = hn::OrAnd(hi, encoded, k80);  // Insert sign bit
  }

  // Generic is only required for partial vectors (too small for tables).
#undef SFP_IF_GENERIC_DEC
#define SFP_IF_GENERIC_DEC(D) HWY_IF_V_SIZE_LE_D(D, 32)
#else
  // Always enable the generic decoder.
#undef SFP_IF_GENERIC_DEC
#define SFP_IF_GENERIC_DEC(D) void* yes = nullptr
#endif

  // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 12 ops.
  template <class D, HWY_IF_U8_D(D), SFP_IF_GENERIC_DEC(D)>
  static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
                                  hn::Vec<D>& hi) {
    const hn::Vec<D> k0 = hn::Zero(d);
    const hn::Vec<D> k80 = hn::Set(d, 0x80u);

    HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80)));  // -0 is reserved
    // Copy sign for later insertion via BitwiseIfThenElse.
    const hn::Vec<D> sign_in_msb = encoded;
    encoded = hn::AndNot(k80, encoded);

    // Special-case zero, negated so we can use MaskedAddOr. Signed comparison
    // is fine because we have cleared the sign bit.
    const hn::Mask<D> is_nonzero = SignedGt(d, encoded, k0);
    // If MSB is clear, we have two mantissa bits, otherwise three.
    const hn::Mask<D> is_small_e = SignedLt(d, encoded, hn::Set(d, 64));
    // If is_small_e, add/left-shift 0xxxx.mm to 0xxxx.mm0; else keep 1xxx.mmm.
    const hn::Vec<D> e4m3 =
        hn::MaskedAddOr(encoded, is_small_e, encoded, encoded);
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(e4m3, k80)));
    const hn::Vec<D> e = hn::ShiftRight<3>(e4m3);  // 4-bit exponent only
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(e, hn::Set(d, 16u))));
    // The encoded exponent for 2^0 is 15, so subtract 15. Add 127 for the
    // binary32/bf16 bias. Subtract another 8 if is_small_e because its lowest
    // encoded value (0) should be less than the lowest 'large' exponent 2^-7.
    const hn::Vec<D> e_bias = hn::IfThenElse(
        is_small_e, hn::Set(d, 127u - 15u - 8u), hn::Set(d, 127u - 15u));
    // Special-case zero or add e_bias. If encoded=0, e and e4m3 are zero, but
    // we must zero e_bias to get the desired all-zero bf16.
    const hn::Vec<D> biased_e = hn::MaskedAddOr(k0, is_nonzero, e_bias, e);
    // The decoded binary32 exponent should be at most 2^0.
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80)));

    // Shift the MSB of e4m3's mantissa into the MSB of the bf16 mantissa.
    const hn::Vec<D> m7 = hn::ShiftLeft<4>(e4m3);
    // Lower byte of bf16 = exponent LSB || mantissa.
    lo = hn::BitwiseIfThenElse(k80, hn::ShiftLeft<7>(biased_e), m7);
    // Upper byte of bf16 = sign || lower 7 bits of exponent.
    hi = hn::BitwiseIfThenElse(k80, sign_in_msb, hn::ShiftRight<1>(biased_e));
  }

  // Encodes `num` bf16 values from `in_bf` to `out_packed`. Their magnitude
  // must be at most 1.875. Writes exactly `num` bytes; a partial final vector
  // is zero-padded and stored with StoreN.
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void Enc(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in_bf,
                             size_t num, SfpStream* HWY_RESTRICT out_packed) {
    const hn::Repartition<uint8_t, DBF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    const size_t N16 = hn::Lanes(dbf);

    size_t i = 0;
    if (num >= 2 * N16) {
      HWY_UNROLL(1)
      for (; i <= num - 2 * N16; i += 2 * N16) {
        const V8 packed = Enc2B(dbf, in_bf + i);
        hn::StoreU(packed, d8, &out_packed->byte + i);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 2 * N16);
    if (remaining != 0) {
      HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
      hwy::ZeroBytes(padded, sizeof(padded));
      hwy::CopyBytes(in_bf + i, padded, remaining * sizeof(padded[0]));
      const V8 packed = Enc2B(dbf, padded);
      hn::StoreN(packed, d8, &out_packed->byte + i, remaining);
    }
  }

  // Encodes `num` f32 values from `in_f` to `packed`. Their magnitude
  // must be at most 1.875. Writes exactly `num` bytes.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT in_f, size_t num,
                             SfpStream* HWY_RESTRICT out_packed) {
    const hn::Repartition<uint8_t, DF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    if (num >= 4 * NF) {
      HWY_UNROLL(1)
      for (; i <= num - 4 * NF; i += 4 * NF) {
        const V8 packed = Enc4F(df, in_f + i);
        hn::StoreU(packed, d8, &out_packed->byte + i);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 4 * NF);
    if (remaining != 0) {
      HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
      hwy::ZeroBytes(padded, sizeof(padded));
      hwy::CopyBytes(in_f + i, padded, remaining * sizeof(padded[0]));
      const V8 packed = Enc4F(df, padded);
      hn::StoreN(packed, d8, &out_packed->byte + i, remaining);
    }
  }

  // Decodes `num` values from `in_packed` to `out_bf`. Reads exactly `num`
  // bytes and writes exactly `num` bf16.
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void Dec(DBF dbf, const SfpStream* HWY_RESTRICT in_packed,
                             size_t num, hwy::bfloat16_t* HWY_RESTRICT out_bf) {
    const hn::Repartition<uint8_t, DBF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    using VBF = hn::Vec<decltype(dbf)>;
    const size_t N16 = hn::Lanes(dbf);

    size_t i = 0;
    if (num >= 2 * N16) {
      HWY_UNROLL(1)
      for (; i <= num - 2 * N16; i += 2 * N16) {
        const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
        VBF bf0, bf1;
        Dec2B(dbf, packed, bf0, bf1);
        hn::StoreU(bf0, dbf, out_bf + i);
        hn::StoreU(bf1, dbf, out_bf + i + N16);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 2 * N16);
    if (remaining != 0) {
      const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
      // Fully overwritten by the two stores below, so no zero-init needed.
      HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
      VBF bf0, bf1;
      Dec2B(dbf, packed, bf0, bf1);
      hn::StoreU(bf0, dbf, padded);
      hn::StoreU(bf1, dbf, padded + N16);
      hwy::CopyBytes(padded, out_bf + i, remaining * sizeof(padded[0]));
    }
  }

  // Decodes `num` values from `in_packed` to `out_f`. Reads exactly `num`
  // bytes and writes exactly `num` floats.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Dec(DF df, const SfpStream* HWY_RESTRICT in_packed,
                             size_t num, float* HWY_RESTRICT out_f) {
    const hn::Repartition<uint8_t, DF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    using VF = hn::Vec<decltype(df)>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    if (num >= 4 * NF) {
      HWY_UNROLL(1)
      for (; i <= num - 4 * NF; i += 4 * NF) {
        const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
        VF f0, f1, f2, f3;
        Dec4F(df, packed, f0, f1, f2, f3);
        hn::StoreU(f0, df, out_f + i + NF * 0);
        hn::StoreU(f1, df, out_f + i + NF * 1);
        hn::StoreU(f2, df, out_f + i + NF * 2);
        hn::StoreU(f3, df, out_f + i + NF * 3);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 4 * NF);
    if (remaining != 0) {
      const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
      // Fully overwritten by the four stores below, so no zero-init needed.
      HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
      VF f0, f1, f2, f3;
      Dec4F(df, packed, f0, f1, f2, f3);
      hn::StoreU(f0, df, padded + NF * 0);
      hn::StoreU(f1, df, padded + NF * 1);
      hn::StoreU(f2, df, padded + NF * 2);
      hn::StoreU(f3, df, padded + NF * 3);
      hwy::CopyBytes(padded, out_f + i, remaining * sizeof(padded[0]));
    }
  }

  // Fused decode and dot product with bf16 into four output accumulators.
  // Accumulates into (does not reset) sum0..sum3.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed,
                             size_t num,
                             const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                             hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                             hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
    const hn::Repartition<uint8_t, DF> d8;
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using V8 = hn::Vec<decltype(d8)>;
    using VBF = hn::Vec<decltype(dbf)>;
    const size_t N16 = hn::Lanes(dbf);

    size_t i = 0;
    if (num >= 2 * N16) {
      HWY_UNROLL(1)
      for (; i <= num - 2 * N16; i += 2 * N16) {
        const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
        const VBF v0 = hn::LoadU(dbf, vec_aligned + i);
        const VBF v1 = hn::LoadU(dbf, vec_aligned + i + N16);
        VBF bf0, bf1;
        Dec2B(dbf, packed, bf0, bf1);
        sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1);
        sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 2 * N16);
    if (remaining != 0) {
      // LoadN zero-fills bytes past `remaining`, which decode to 0. The vector
      // is zero-padded too, so those lanes add 0*0 rather than multiplying by
      // uninitialized memory (which could be NaN).
      const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
      HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
      hwy::ZeroBytes(padded, sizeof(padded));
      hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0]));
      const VBF v0 = hn::LoadU(dbf, padded);
      const VBF v1 = hn::LoadU(dbf, padded + N16);
      VBF bf0, bf1;
      Dec2B(dbf, packed, bf0, bf1);
      sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1);
      sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3);
    }
  }

  // Fused decode and dot product with f32 into four output accumulators.
  // Accumulates into (does not reset) sum0..sum3.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed,
                             size_t num, const float* HWY_RESTRICT vec_aligned,
                             hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                             hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
    const hn::Repartition<uint8_t, DF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    using VF = hn::Vec<decltype(df)>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    if (num >= 4 * NF) {
      HWY_UNROLL(1)
      for (; i <= num - 4 * NF; i += 4 * NF) {
        const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
        const VF v0 = hn::LoadU(df, vec_aligned + i + NF * 0);
        const VF v1 = hn::LoadU(df, vec_aligned + i + NF * 1);
        const VF v2 = hn::LoadU(df, vec_aligned + i + NF * 2);
        const VF v3 = hn::LoadU(df, vec_aligned + i + NF * 3);
        VF f0, f1, f2, f3;
        Dec4F(df, packed, f0, f1, f2, f3);
        sum0 = hn::MulAdd(f0, v0, sum0);
        sum1 = hn::MulAdd(f1, v1, sum1);
        sum2 = hn::MulAdd(f2, v2, sum2);
        sum3 = hn::MulAdd(f3, v3, sum3);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 4 * NF);
    if (remaining != 0) {
      // Zero-padding on both operands makes the extra lanes contribute 0*0.
      const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
      HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
      hwy::ZeroBytes(padded, sizeof(padded));
      hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0]));
      const VF v0 = hn::LoadU(df, padded + NF * 0);
      const VF v1 = hn::LoadU(df, padded + NF * 1);
      const VF v2 = hn::LoadU(df, padded + NF * 2);
      const VF v3 = hn::LoadU(df, padded + NF * 3);
      VF f0, f1, f2, f3;
      Dec4F(df, packed, f0, f1, f2, f3);
      sum0 = hn::MulAdd(f0, v0, sum0);
      sum1 = hn::MulAdd(f1, v1, sum1);
      sum2 = hn::MulAdd(f2, v2, sum2);
      sum3 = hn::MulAdd(f3, v3, sum3);
    }
  }

  // Fused decode and dot product with even-odd bf16 into four f32 accumulators.
  // For each chunk of 2 * N16 packed values, `vec_aligned` must hold the N16
  // entries for the even-indexed packed values first, then the N16 entries for
  // the odd-indexed ones. `num` must be a multiple of 2 * N16 (no remainder
  // handling).
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
                               size_t num,
                               const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                               hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                               hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
    const hn::Repartition<uint8_t, DF> d8;
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using V8 = hn::Vec<decltype(d8)>;
    using VBF = hn::Vec<decltype(dbf)>;
    const size_t N16 = hn::Lanes(dbf);
    HWY_DASSERT(num % (2 * N16) == 0);  // whole SFP vector -> 2x bf16

    HWY_UNROLL(1)
    for (size_t i = 0; i < num; i += 2 * N16) {
      const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
      const VBF ve = hn::LoadU(dbf, vec_aligned + i);
      const VBF vo = hn::LoadU(dbf, vec_aligned + i + N16);
      VBF be, bo;
      DecEvenOdd(dbf, packed, be, bo);
      sum0 = hn::ReorderWidenMulAccumulate(df, be, ve, sum0, sum1);
      sum2 = hn::ReorderWidenMulAccumulate(df, bo, vo, sum2, sum3);
    }
  }

  // Fused decode and dot product with even-odd f32 into four f32 accumulators.
  // `num` must be a multiple of 4 * NF (no remainder handling).
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
                               size_t num,
                               const float* HWY_RESTRICT vec_aligned,
                               hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                               hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
    const hn::Repartition<uint8_t, DF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    using VF = hn::Vec<decltype(df)>;
    const size_t NF = hn::Lanes(df);
    HWY_DASSERT(num % (4 * NF) == 0);  // whole SFP vector -> 4x f32

    HWY_UNROLL(1)
    for (size_t i = 0; i < num; i += 4 * NF) {
      const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
      const VF ve0 = hn::LoadU(df, vec_aligned + i + NF * 0);
      const VF vo0 = hn::LoadU(df, vec_aligned + i + NF * 1);
      const VF ve1 = hn::LoadU(df, vec_aligned + i + NF * 2);
      const VF vo1 = hn::LoadU(df, vec_aligned + i + NF * 3);
      VF fe0, fo0, fe1, fo1;
      DecEvenOddF(df, packed, fe0, fo0, fe1, fo1);
      sum0 = hn::MulAdd(fe0, ve0, sum0);
      sum1 = hn::MulAdd(fo0, vo0, sum1);
      sum2 = hn::MulAdd(fe1, ve1, sum2);
      sum3 = hn::MulAdd(fo1, vo1, sum3);
    }
  }

  // Decodes a vector of packed bytes (twice as many lanes as `df`) into two
  // f32 vectors, in lane order.
  template <class DF, HWY_IF_F32_D(DF),
            class V8 = hn::Vec<hn::Twice<hn::Rebind<uint8_t, DF>>>>
  static HWY_INLINE void Dec2F(DF df, V8 packed, hn::Vec<DF>& f0,
                               hn::Vec<DF>& f1) {
    const hn::Rebind<hwy::bfloat16_t, DF> dbf;
    using VBF = hn::Vec<decltype(dbf)>;
    VBF bf0, bf1;
    Dec2B(dbf, packed, bf0, bf1);
    f0 = hn::PromoteTo(df, bf0);
    f1 = hn::PromoteTo(df, bf1);
  }

 private:
  // Wrappers to avoid code duplication across float/bf16 input types and
  // the main loop/remainder.

  // Returns vector of packed bytes for callers to StoreU or StoreN.
  template <class D16, HWY_IF_U16_D(D16),
            class V8 = hn::Vec<hn::Repartition<uint8_t, D16>>>
  static HWY_INLINE V8 Enc2U(D16 d16, const hn::Vec<D16> w0,
                             const hn::Vec<D16> w1) {
    const hn::Repartition<uint8_t, D16> d8;

    // Although more expensive on AVX3, in-order packing enables streaming
    // decompression without fixed-size packets.
    const V8 lo = hn::ConcatEven(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0));
    const V8 hi = hn::ConcatOdd(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0));
    return EncBytes(d8, lo, hi);
  }

  // Loads two bf16 vectors starting at `in` and returns their packed bytes.
  template <class DBF, HWY_IF_BF16_D(DBF),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
  static HWY_INLINE V8 Enc2B(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in) {
    const hn::Repartition<uint16_t, DBF> d16;
    const size_t N16 = hn::Lanes(d16);
    using V16 = hn::Vec<decltype(d16)>;

    const V16 w0 = hn::BitCast(d16, hn::LoadU(dbf, in));
    const V16 w1 = hn::BitCast(d16, hn::LoadU(dbf, in + N16));
    return Enc2U(d16, w0, w1);
  }

  // Truncates two f32 to bf16, in lane order, without rounding (see Enc4F).
  template <class DBF, class DF = hn::RepartitionToWide<DBF>>
  static HWY_INLINE hn::Vec<DBF> Truncate2To(DBF dbf, hn::Vec<DF> f0,
                                             hn::Vec<DF> f1) {
    const hn::RebindToUnsigned<DBF> d16;
    using V16 = hn::Vec<decltype(d16)>;
    const V16 u0 = hn::BitCast(d16, f0);
    const V16 u1 = hn::BitCast(d16, f1);
    // Keep the u16 half of each f32 that holds its upper 16 bits: the odd
    // halves on little-endian targets, the even ones otherwise.
    const V16 upper_halves = HWY_IS_LITTLE_ENDIAN
                                 ? hn::ConcatOdd(d16, u1, u0)
                                 : hn::ConcatEven(d16, u1, u0);
    return hn::BitCast(dbf, upper_halves);
  }

  // Loads four f32 vectors starting at `in` and returns their packed bytes.
  template <class DF, HWY_IF_F32_D(DF),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
  static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) {
    const hn::Repartition<uint16_t, DF> d16;
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using VF = hn::Vec<decltype(df)>;
    using V16 = hn::Vec<decltype(d16)>;
    const size_t NF = hn::Lanes(df);

    const VF f0 = hn::LoadU(df, in + NF * 0);
    const VF f1 = hn::LoadU(df, in + NF * 1);
    const VF f2 = hn::LoadU(df, in + NF * 2);
    const VF f3 = hn::LoadU(df, in + NF * 3);
    // Chop off the lower 16 bits instead of OrderedDemote2To, which rounds to
    // the nearest bf16, because EncBytes will round again.
    const V16 w0 = hn::BitCast(d16, Truncate2To(dbf, f0, f1));
    const V16 w1 = hn::BitCast(d16, Truncate2To(dbf, f2, f3));
    return Enc2U(d16, w0, w1);
  }

  // Decodes packed bytes into two u16 vectors holding bf16 bits, in lane order.
  template <class D16, HWY_IF_U16_D(D16),
            class V8 = hn::Vec<hn::Repartition<uint8_t, D16>>>
  static HWY_INLINE void Dec2U(D16 d16, V8 packed, hn::Vec<D16>& w0,
                               hn::Vec<D16>& w1) {
    const hn::Repartition<uint8_t, D16> d8;
    V8 lo, hi;
    DecBytes(d8, packed, lo, hi);
    w0 = hn::BitCast(d16, hn::InterleaveWholeLower(d8, lo, hi));
    w1 = hn::BitCast(d16, hn::InterleaveWholeUpper(d8, lo, hi));
  }

  // Decodes packed bytes into two bf16 vectors, in lane order.
  template <class DBF, HWY_IF_BF16_D(DBF),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
  static HWY_INLINE void Dec2B(DBF dbf, V8 packed, hn::Vec<DBF>& bf0,
                               hn::Vec<DBF>& bf1) {
    const hn::Repartition<uint16_t, DBF> d16;
    using V16 = hn::Vec<decltype(d16)>;
    V16 w0, w1;
    Dec2U(d16, packed, w0, w1);
    bf0 = hn::BitCast(dbf, w0);
    bf1 = hn::BitCast(dbf, w1);
  }

  // Decodes packed bytes into four f32 vectors, in lane order.
  template <class DF, HWY_IF_F32_D(DF),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
  static HWY_INLINE void Dec4F(DF df, V8 packed, hn::Vec<DF>& f0,
                               hn::Vec<DF>& f1, hn::Vec<DF>& f2,
                               hn::Vec<DF>& f3) {
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using VBF = hn::Vec<decltype(dbf)>;
    VBF bf0, bf1;
    Dec2B(dbf, packed, bf0, bf1);
    f0 = hn::PromoteLowerTo(df, bf0);
    f1 = hn::PromoteUpperTo(df, bf0);
    f2 = hn::PromoteLowerTo(df, bf1);
    f3 = hn::PromoteUpperTo(df, bf1);
  }

  // Decodes packed bytes into bf16 vectors of the even-indexed and the
  // odd-indexed values.
  template <class DBF, HWY_IF_BF16_D(DBF),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
  static HWY_INLINE void DecEvenOdd(DBF dbf, V8 packed, hn::Vec<DBF>& even,
                                    hn::Vec<DBF>& odd) {
    const hn::Repartition<uint8_t, DBF> d8;
    V8 lo, hi;
    DecBytes(d8, packed, lo, hi);
    // (Supported since Highway 1.2)
    even = hn::BitCast(dbf, hn::InterleaveEven(d8, lo, hi));
    odd = hn::BitCast(dbf, hn::InterleaveOdd(d8, lo, hi));
  }

  // As DecEvenOdd, but promotes the lower/upper halves of each result to f32.
  template <class DF, HWY_IF_F32_D(DF),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
  static HWY_INLINE void DecEvenOddF(DF df, V8 packed, hn::Vec<DF>& even0,
                                     hn::Vec<DF>& odd0, hn::Vec<DF>& even1,
                                     hn::Vec<DF>& odd1) {
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using VBF = hn::Vec<decltype(dbf)>;
    VBF even_bf, odd_bf;
    DecEvenOdd(dbf, packed, even_bf, odd_bf);
    even0 = hn::PromoteLowerTo(df, even_bf);
    odd0 = hn::PromoteLowerTo(df, odd_bf);
    even1 = hn::PromoteUpperTo(df, even_bf);
    odd1 = hn::PromoteUpperTo(df, odd_bf);
  }
};  // SfpCodec
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
#endif  // per-target include guard (THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE)