// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Normal include guard to placate lint.
#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_

#include <stddef.h>
#include <stdint.h>

#include "compression/shared.h"
#include "hwy/base.h"

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_

// Actual per-target include guard.
#if defined(THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_SFP_INL_TOGGLE
#endif

#include "hwy/detect_targets.h"
#include "hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// For unsigned numbers with MSB zero, signed comparison is faster on x86.
template <class DU>
HWY_INLINE hn::Mask<DU> SignedGt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
  const hn::RebindToSigned<DU> di;
  return hn::RebindMask(du, hn::Gt(hn::BitCast(di, a), hn::BitCast(di, b)));
}

template <class DU>
HWY_INLINE hn::Mask<DU> SignedLt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
  return SignedGt(du, b, a);
}

// Saturated subtraction; returns 0 if the result would be negative.
static inline size_t SubOr0(size_t a, size_t b) { return a > b ? a - b : 0; }

// Encode/decode functions.
class SfpCodec {
 public:
  // Returns 8-bit packed representation of `lo` and `hi` bytes of bf16.
  // 31 ops. Implementation detail, public because called by test.
  template <class D, HWY_IF_U8_D(D)>
  static HWY_INLINE hn::Vec<D> EncBytes(D d, const hn::Vec<D> lo,
                                        const hn::Vec<D> hi) {
    const hn::Vec<D> k1 = hn::Set(d, 1u);
    const hn::Vec<D> k80 = hn::Set(d, 0x80u);

    // Copy sign for later insertion.
    const hn::Vec<D> sign_in_msb = hi;

    // Biased exponent = lower 7 bits of hi and MSB of lo. Modified below.
    hn::Vec<D> biased_e = hn::Or(hn::Add(hi, hi), hn::ShiftRight<7>(lo));
    HWY_ASSERT(hn::AllTrue(d, hn::Lt(biased_e, k80)));  // <= 2^0

    // Clear MSB to isolate the mantissa and enable signed comparisons, then
    // shift right by *one* (plus 1 to undo the prior add/left-shift) to
    // leave headroom for overflow during rounding.
    const hn::Vec<D> m6 = hn::ShiftRight<2>(hn::Add(lo, lo));

    // The place to round depends on whether the exponent is large (>= -7) -
    // if so, we retain three mantissa bits, otherwise two. However, rounding
    // can also cause the exponent to increase. We first choose a threshold
    // that rounds up to 1.0*2^-7 for both two and three bit mantissas:
    // >= 1.1111 * 2^-8 (0.007568359375). This entails the exponent being
    // greater, or equal and the mantissa > (1111000 >> 1) - 1 = 0x3B.
    const hn::Vec<D> kMinLargeE = hn::Set(d, 127 - 8);
    const hn::Mask<D> is_large_before_round =
        hn::Or(SignedGt(d, biased_e, kMinLargeE),
               hn::And(hn::Eq(biased_e, kMinLargeE),
                       SignedGt(d, m6, hn::Set(d, 0x3B))));

    // To retain the most-significant 3 or 2 mantissa bits, we will
    // right-shift by is_large_before_round ? 3 : 4. Variable Shr is
    // expensive for 8-bit elements, so (<< 1) if is_large_before_round, then
    // always (>> 4).
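    // For example (ignoring rounding): m6 holds six mantissa bits, so with
    // is_large_before_round the net effect is (>> 3), keeping their top
    // three bits; otherwise only the (>> 4) runs and the top two remain.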
    const hn::Vec<D> m_shl4 =
        hn::MaskedAddOr(m6, is_large_before_round, m6, m6);

    // Before shifting (truncation), round to nearest even to reduce bias. If
    // the lowest remaining mantissa bit is odd, increase the offset. Example
    // with the lowest remaining bit (left) and next lower two bits; the
    // latter, plus two more, will be truncated.
    // 0[00] +  1 =  0[01]
    // 0[01] +  1 =  0[10]
    // 0[10] +  1 =  0[11] (round down toward even)
    // 0[11] +  1 =  1[00] (round up)
    // 1[00] + 10 =  1[10]
    // 1[01] + 10 =  1[11]
    // 1[10] + 10 = C0[00] (round up toward even with C=1 carry out)
    // 1[11] + 10 = C0[01] (round up toward even with C=1 carry out)
    const hn::Vec<D> odd_bit = hn::And(hn::ShiftRight<4>(m_shl4), k1);
    const hn::Vec<D> rounded =
        hn::Add(m_shl4, hn::Add(odd_bit, hn::Set(d, 7)));

    // Update the exponent if rounding overflowed.
    const hn::Vec<D> carry_bit =
        hn::IfThenElse(is_large_before_round, k80, hn::Set(d, 0x40u));
    const hn::Vec<D> carry_clear = hn::AndNot(carry_bit, rounded);
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(carry_clear, carry_bit)));
    const hn::Mask<D> is_overflow = hn::Ne(carry_clear, rounded);
    biased_e = hn::MaskedAddOr(biased_e, is_overflow, biased_e, k1);
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(biased_e, hn::Set(d, 128))));

    // Detect if zero or the min exponent.
    const hn::Vec<D> kMinNormal = hn::Set(d, 127 - 23);
    const hn::Mask<D> is_zero = SignedLt(d, biased_e, kMinNormal);
    const hn::Mask<D> is_min = hn::Eq(biased_e, kMinNormal);

    // 1.1110xxx * 2^-8 was considered small above, and thus rounded up to
    // 2^-7, which the decoder will consider large, and expect 3 mantissa
    // bits. If we set the threshold above to 1.111, then it does NOT round
    // up. Thus we check exponent >= -7 *after* rounding.
    const hn::Mask<D> is_large = SignedGt(d, biased_e, hn::Set(d, 127 - 8));

    // To extract and pack the mantissa, only is_large matters. Either it
    // matches is_large_before_round, or the rounding resulted in mantissa=0,
    // so we either extract two or three bits by shifting out the lower 5..6
    // bits.
    // is_large_before  is_large  rounded     want
    //        0            0      0Cmm????    mm
    //        0            1      0100????    000
    //        1            0      impossible  -
    //        1            1      Cmmm???0    mmm
    hn::Vec<D> m = hn::ShiftRight<4>(carry_clear);
    HWY_DASSERT(hn::AllTrue(
        d, SignedLt(d, m,
                    hn::IfThenElse(is_large, hn::Set(d, 8), hn::Set(d, 4)))));
    // 1.0 * 2^-23 has the same encoding as zero, so round it up to 1.01.
    m = hn::MaskedMaxOr(m, is_min, m, k1);

    const hn::Vec<D> e_bias = hn::IfThenElse(
        is_large,
        hn::Set(d, hwy::BitCastScalar<uint8_t>(static_cast<int8_t>(15 - 127))),
        hn::Set(d,
                hwy::BitCastScalar<uint8_t>(static_cast<int8_t>(23 - 127))));
    const hn::Vec<D> e = hn::Add(biased_e, e_bias);
    HWY_DASSERT(hn::AllTrue(
        d, hn::Lt(hn::IfThenZeroElse(is_zero, e), hn::Set(d, 16))));

    // Shift exponent left 2 or 3 bits to make space for `m`.
    const hn::Vec<D> em =
        hn::Or(m, hn::ShiftLeft<2>(hn::MaskedAddOr(e, is_large, e, e)));
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(hn::IfThenZeroElse(is_zero, em), k80)));
    const hn::Vec<D> encoded = hn::BitwiseIfThenElse(k80, sign_in_msb, em);
    // Doing this last ensures -0 is replaced with 0.
    return hn::IfThenZeroElse(is_zero, encoded);
  }

  // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 3 ops (AVX-512).
#if HWY_TARGET <= HWY_AVX3_DL || HWY_IDE
  template <class D, HWY_IF_U8_D(D), HWY_IF_V_SIZE_D(D, 64)>
  static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
                                  hn::Vec<D>& hi) {
    const hn::Vec<D> k80 = hn::Set(d, 0x80u);
    HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80)));  // -0 is reserved

    // Two 2x64 table lookups for lo/hi.
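    // The tables are indexed by the low 7 bits of `encoded`: entries 0..63
    // (kTblL0/kTblH0) cover codes with two mantissa bits, entries 64..127
    // (kTblL1/kTblH1) those with three. The L tables hold the bf16 low byte
    // and the H tables the bf16 exponent byte, mirroring the shifts and
    // e_bias additions of the generic DecBytes below.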
    alignas(64) static constexpr uint8_t kTblL0[64] = {
        0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40,
        0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0,
        0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00,
        0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60,
        0x80, 0xA0, 0xC0, 0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0,
        0xE0, 0x00, 0x20, 0x40, 0x60, 0x80, 0xA0, 0xC0, 0xE0};
    alignas(64) static constexpr uint8_t kTblL1[64] = {
        0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0,
        0xB0, 0xC0, 0xD0, 0xE0, 0xF0, 0x00, 0x10, 0x20, 0x30, 0x40, 0x50,
        0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0, 0x00,
        0x10, 0x20, 0x30, 0x40, 0x50, 0x60, 0x70, 0x80, 0x90, 0xA0, 0xB0,
        0xC0, 0xD0, 0xE0, 0xF0, 0x00, 0x10, 0x20, 0x30, 0x40, 0x50, 0x60,
        0x70, 0x80, 0x90, 0xA0, 0xB0, 0xC0, 0xD0, 0xE0, 0xF0};
    alignas(64) static constexpr uint8_t kTblH0[64] = {
        0x00, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x34, 0x35, 0x35, 0x35,
        0x35, 0x35, 0x35, 0x35, 0x35, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
        0x36, 0x36, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x37, 0x38,
        0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x38, 0x39, 0x39, 0x39, 0x39,
        0x39, 0x39, 0x39, 0x39, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A, 0x3A,
        0x3A, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B, 0x3B};
    alignas(64) static constexpr uint8_t kTblH1[64] = {
        0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3C,
        0x3C, 0x3C, 0x3C, 0x3C, 0x3C, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D,
        0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3D, 0x3E,
        0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E, 0x3E,
        0x3E, 0x3E, 0x3E, 0x3E, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F,
        0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F, 0x3F};
    const hn::Vec<D> tblL0 = hn::LoadU(d, kTblL0);
    const hn::Vec<D> tblL1 = hn::LoadU(d, kTblL1);
    const hn::Vec<D> tblH0 = hn::LoadU(d, kTblH0);
    const hn::Vec<D> tblH1 = hn::LoadU(d, kTblH1);

#if HWY_IDE  // only let the IDE see portable code.
    const auto idx = hn::IndicesFromVec(d, hn::AndNot(k80, encoded));
#else
    // AVX-512-specific: index MSB is ignored, no need to clear.
    const hn::Indices512<uint8_t> idx{encoded.raw};
#endif
    hi = hn::TwoTablesLookupLanes(d, tblH0, tblH1, idx);
    lo = hn::TwoTablesLookupLanes(d, tblL0, tblL1, idx);
    hi = hn::OrAnd(hi, encoded, k80);  // Insert sign bit.
  }

  // Generic is only required for partial vectors (too small for tables).
#undef SFP_IF_GENERIC_DEC
#define SFP_IF_GENERIC_DEC(D) HWY_IF_V_SIZE_LE_D(D, 32)
#else
  // Always enable the generic decoder.
#undef SFP_IF_GENERIC_DEC
#define SFP_IF_GENERIC_DEC(D) void* yes = nullptr
#endif

  // Decodes u8 `encoded` into `lo` and `hi` bytes of bf16. 9 ops.
  template <class D, HWY_IF_U8_D(D), SFP_IF_GENERIC_DEC(D)>
  static HWY_INLINE void DecBytes(D d, hn::Vec<D> encoded, hn::Vec<D>& lo,
                                  hn::Vec<D>& hi) {
    const hn::Vec<D> k0 = hn::Zero(d);
    const hn::Vec<D> k80 = hn::Set(d, 0x80u);
    HWY_DASSERT(hn::AllTrue(d, hn::Ne(encoded, k80)));  // -0 is reserved

    // Copy sign for later insertion via BitwiseIfThenElse.
    const hn::Vec<D> sign_in_msb = encoded;
    encoded = hn::AndNot(k80, encoded);

    // Special-case zero, negated so we can use MaskedAddOr. Signed
    // comparison is fine because we have cleared the sign bit.
    const hn::Mask<D> is_nonzero = SignedGt(d, encoded, k0);

    // If bit 6 is clear, we have two mantissa bits, otherwise three.
    const hn::Mask<D> is_small_e = SignedLt(d, encoded, hn::Set(d, 64));

    // For encoded in [1, 8), hi = 0x34; encoded = 0x40 => hi = 0x3C, which
    // includes (encoded >> 4) == 4, so add 0x38.
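    // Worked example: encoded = 0x78 is not small, so e_bias = 0x38,
    // lo = 0x78 << 4 = 0x80 and hi = 0x38 + (0x78 >> 4) = 0x3F, i.e. bf16
    // 0x3F80 = +1.0. encoded = 0x01 is small: lo = 0x20, hi = 0x34, i.e.
    // 1.01b * 2^-23, matching the minimum that EncBytes rounds up to.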
    const hn::Vec<D> e_bias =
        hn::IfThenElse(is_small_e, hn::Set(d, 0x34), hn::Set(d, 0x38));

    // The low byte of bf16 is encoded << (is_small_e ? 5 : 4).
    const hn::Vec<D> shl1_if_small =
        hn::MaskedAddOr(encoded, is_small_e, encoded, encoded);
    lo = hn::ShiftLeft<4>(shl1_if_small);
    // Lower 4 bits always zero.
    HWY_DASSERT(
        hn::AllTrue(d, hn::Eq(hn::And(lo, hn::Set(d, 15u)), hn::Zero(d))));

    // The upper byte of bf16 is e_bias + (encoded >> (is_small_e ? 3 : 4)).
    const hn::Vec<D> shr_3_or_4 = hn::ShiftRight<4>(shl1_if_small);
    // .. except when encoded=0: hi = 0, and lo is already 0.
    const hn::Vec<D> e7 = hn::MaskedAddOr(k0, is_nonzero, e_bias, shr_3_or_4);
    HWY_DASSERT(hn::AllTrue(d, hn::Lt(e7, hn::Set(d, 64u))));  // <= 0x3F
    // .. also insert the sign bit.
    hi = hn::BitwiseIfThenElse(k80, sign_in_msb, e7);
  }

  // Encodes `num` bf16 values from `in_bf` to `out_packed`. Their magnitude
  // must be at most 1.875.
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void Enc(DBF dbf,
                             const hwy::bfloat16_t* HWY_RESTRICT in_bf,
                             size_t num, SfpStream* HWY_RESTRICT out_packed) {
    const hn::Repartition<uint8_t, DBF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    const size_t N16 = hn::Lanes(dbf);

    size_t i = 0;
    if (num >= 2 * N16) {
      HWY_UNROLL(1)
      for (; i <= num - 2 * N16; i += 2 * N16) {
        const V8 packed = Enc2B(dbf, in_bf + i);
        hn::StoreU(packed, d8, &out_packed->byte + i);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 2 * N16);
    if (remaining != 0) {
      HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
      hwy::ZeroBytes(padded, sizeof(padded));
      hwy::CopyBytes(in_bf + i, padded, remaining * sizeof(padded[0]));
      const V8 packed = Enc2B(dbf, padded);
      hn::StoreN(packed, d8, &out_packed->byte + i, remaining);
    }
  }

  // Encodes `num` f32 values from `in_f` to `out_packed`. Their magnitude
  // must be at most 1.875.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Enc(DF df, const float* HWY_RESTRICT in_f, size_t num,
                             SfpStream* HWY_RESTRICT out_packed) {
    const hn::Repartition<uint8_t, DF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    if (num >= 4 * NF) {
      HWY_UNROLL(1)
      for (; i <= num - 4 * NF; i += 4 * NF) {
        const V8 packed = Enc4F(df, in_f + i);
        hn::StoreU(packed, d8, &out_packed->byte + i);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 4 * NF);
    if (remaining != 0) {
      HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
      hwy::ZeroBytes(padded, sizeof(padded));
      hwy::CopyBytes(in_f + i, padded, remaining * sizeof(padded[0]));
      const V8 packed = Enc4F(df, padded);
      hn::StoreN(packed, d8, &out_packed->byte + i, remaining);
    }
  }

  // Decodes `num` values from `in_packed` to `out_bf`.
  template <class DBF, HWY_IF_BF16_D(DBF)>
  static HWY_INLINE void Dec(DBF dbf, const SfpStream* HWY_RESTRICT in_packed,
                             size_t num,
                             hwy::bfloat16_t* HWY_RESTRICT out_bf) {
    const hn::Repartition<uint8_t, DBF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    using VBF = hn::Vec<DBF>;
    const size_t N16 = hn::Lanes(dbf);

    size_t i = 0;
    if (num >= 2 * N16) {
      HWY_UNROLL(1)
      for (; i <= num - 2 * N16; i += 2 * N16) {
        const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
        VBF bf0, bf1;
        Dec2B(dbf, packed, bf0, bf1);
        hn::StoreU(bf0, dbf, out_bf + i);
        hn::StoreU(bf1, dbf, out_bf + i + N16);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 2 * N16);
    if (remaining != 0) {
      const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
      VBF bf0, bf1;
      Dec2B(dbf, packed, bf0, bf1);
      hn::StoreN(bf0, dbf, out_bf + i, remaining);
      hn::StoreN(bf1, dbf, out_bf + i + N16, SubOr0(remaining, N16));
    }
  }
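
  // Note on the remainder path above: hn::StoreN writes at most Lanes(dbf)
  // values, so with e.g. remaining = N16 + 3 the first StoreN writes N16
  // values from bf0 and the second writes the last three from bf1, because
  // SubOr0 clamps the second count to zero whenever remaining <= N16.
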
  // Decodes `num` values from `in_packed` to `out_f`.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void Dec(DF df, const SfpStream* HWY_RESTRICT in_packed,
                             size_t num, float* HWY_RESTRICT out_f) {
    const hn::Repartition<uint8_t, DF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    using VF = hn::Vec<DF>;
    const size_t NF = hn::Lanes(df);

    size_t i = 0;
    if (num >= 4 * NF) {
      HWY_UNROLL(1)
      for (; i <= num - 4 * NF; i += 4 * NF) {
        const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
        VF f0, f1, f2, f3;
        Dec4F(df, packed, f0, f1, f2, f3);
        hn::StoreU(f0, df, out_f + i + NF * 0);
        hn::StoreU(f1, df, out_f + i + NF * 1);
        hn::StoreU(f2, df, out_f + i + NF * 2);
        hn::StoreU(f3, df, out_f + i + NF * 3);
      }
    }

    const size_t remaining = num - i;
    HWY_DASSERT(remaining < 4 * NF);
    if (remaining != 0) {
      const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
      VF f0, f1, f2, f3;
      Dec4F(df, packed, f0, f1, f2, f3);
      hn::StoreN(f0, df, out_f + i + 0 * NF, remaining);
      hn::StoreN(f1, df, out_f + i + 1 * NF, SubOr0(remaining, 1 * NF));
      hn::StoreN(f2, df, out_f + i + 2 * NF, SubOr0(remaining, 2 * NF));
      hn::StoreN(f3, df, out_f + i + 3 * NF, SubOr0(remaining, 3 * NF));
    }
  }

  // Fused decode and dot product with even-odd bf16 into four f32
  // accumulators.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
                               size_t num,
                               const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                               hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                               hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
    const hn::Repartition<uint8_t, DF> d8;
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using V8 = hn::Vec<decltype(d8)>;
    using VBF = hn::Vec<decltype(dbf)>;
    const size_t N16 = hn::Lanes(dbf);
    HWY_DASSERT(num % (2 * N16) == 0);  // whole SFP vector -> 2x bf16

    HWY_UNROLL(1)
    for (size_t i = 0; i < num; i += 2 * N16) {
      const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
      const VBF ve = hn::LoadU(dbf, vec_aligned + i);
      const VBF vo = hn::LoadU(dbf, vec_aligned + i + N16);
      VBF be, bo;
      DecEvenOdd(dbf, packed, be, bo);
      sum0 = hn::ReorderWidenMulAccumulate(df, be, ve, sum0, sum1);
      sum2 = hn::ReorderWidenMulAccumulate(df, bo, vo, sum2, sum3);
    }
  }

  // Fused decode and dot product with even-odd f32 into four f32
  // accumulators.
  template <class DF, HWY_IF_F32_D(DF)>
  static HWY_INLINE void DotEO(DF df, const SfpStream* HWY_RESTRICT in_packed,
                               size_t num,
                               const float* HWY_RESTRICT vec_aligned,
                               hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
                               hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
    const hn::Repartition<uint8_t, DF> d8;
    using V8 = hn::Vec<decltype(d8)>;
    using VF = hn::Vec<DF>;
    const size_t NF = hn::Lanes(df);
    HWY_DASSERT(num % (4 * NF) == 0);  // whole SFP vector -> 4x f32

    HWY_UNROLL(1)
    for (size_t i = 0; i < num; i += 4 * NF) {
      const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
      const VF ve0 = hn::LoadU(df, vec_aligned + i + NF * 0);
      const VF vo0 = hn::LoadU(df, vec_aligned + i + NF * 1);
      const VF ve1 = hn::LoadU(df, vec_aligned + i + NF * 2);
      const VF vo1 = hn::LoadU(df, vec_aligned + i + NF * 3);
      VF fe0, fo0, fe1, fo1;
      DecEvenOddF(df, packed, fe0, fo0, fe1, fo1);
      sum0 = hn::MulAdd(fe0, ve0, sum0);
      sum1 = hn::MulAdd(fo0, vo0, sum1);
      sum2 = hn::MulAdd(fe1, ve1, sum2);
      sum3 = hn::MulAdd(fo1, vo1, sum3);
    }
  }

  template <class DF, HWY_IF_F32_D(DF),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
  static HWY_INLINE void Dec2(DF df, V8 packed, hn::Vec<DF>& f0,
                              hn::Vec<DF>& f1) {
    const hn::Rebind<hwy::bfloat16_t, DF> dbf;
    using VBF = hn::Vec<decltype(dbf)>;
    VBF bf0, bf1;
    Dec2B(dbf, packed, bf0, bf1);
    f0 = hn::PromoteTo(df, bf0);
    f1 = hn::PromoteTo(df, bf1);
  }

  template <class DBF16, HWY_IF_BF16_D(DBF16),
            class V8 = hn::Vec<hn::Repartition<uint8_t, DBF16>>>
  static HWY_INLINE void Dec2(DBF16 dbf16, V8 packed, hn::Vec<DBF16>& bf0,
                              hn::Vec<DBF16>& bf1) {
    Dec2B(dbf16, packed, bf0, bf1);
  }
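
  // Illustrative round-trip sketch (not part of this header; `values` and
  // `num` are hypothetical caller-side names):
  //   const hn::ScalableTag<float> df;
  //   std::vector<SfpStream> packed(num);              // one byte per value
  //   SfpCodec::Enc(df, values, num, packed.data());   // |values[i]| <= 1.875
  //   SfpCodec::Dec(df, packed.data(), num, values);   // decode back in place
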
 private:
  // Wrappers to avoid code duplication across float/bf16 input types and
  // the main loop/remainder.

  // Returns vector of packed bytes for callers to StoreU or StoreN.
  template <class D16, class V8 = hn::Vec<hn::Repartition<uint8_t, D16>>>
  static HWY_INLINE V8 Enc2U(D16 d16, const hn::Vec<D16> w0,
                             const hn::Vec<D16> w1) {
    const hn::Repartition<uint8_t, D16> d8;
    // Although more expensive on AVX3, in-order packing enables streaming
    // decompression without fixed-size packets.
    const V8 lo = hn::ConcatEven(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0));
    const V8 hi = hn::ConcatOdd(d8, hn::BitCast(d8, w1), hn::BitCast(d8, w0));
    return EncBytes(d8, lo, hi);
  }

  template <class DBF, class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
  static HWY_INLINE V8 Enc2B(DBF dbf, const hwy::bfloat16_t* HWY_RESTRICT in) {
    const hn::Repartition<uint16_t, DBF> d16;
    const size_t N16 = hn::Lanes(d16);
    using V16 = hn::Vec<decltype(d16)>;
    const V16 w0 = hn::BitCast(d16, hn::LoadU(dbf, in));
    const V16 w1 = hn::BitCast(d16, hn::LoadU(dbf, in + N16));
    return Enc2U(d16, w0, w1);
  }

  // Truncates two f32 to bf16, in lane order, without rounding (see Enc4F).
  template <class DBF, class DF = hn::Repartition<float, DBF>>
  static HWY_INLINE hn::Vec<DBF> Truncate2To(DBF dbf, hn::Vec<DF> f0,
                                             hn::Vec<DF> f1) {
    const hn::RebindToUnsigned<DBF> d16;
    using V16 = hn::Vec<decltype(d16)>;
    const V16 u0 = BitCast(d16, f0);
    const V16 u1 = BitCast(d16, f1);
    return BitCast(DBF(), HWY_IS_LITTLE_ENDIAN ? ConcatOdd(d16, u1, u0)
                                               : ConcatEven(d16, u1, u0));
  }

  template <class DF, class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
  static HWY_INLINE V8 Enc4F(DF df, const float* HWY_RESTRICT in) {
    const hn::Repartition<uint16_t, DF> d16;
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using VF = hn::Vec<DF>;
    using V16 = hn::Vec<decltype(d16)>;
    const size_t NF = hn::Lanes(df);
    const VF f0 = hn::LoadU(df, in + NF * 0);
    const VF f1 = hn::LoadU(df, in + NF * 1);
    const VF f2 = hn::LoadU(df, in + NF * 2);
    const VF f3 = hn::LoadU(df, in + NF * 3);
    // Chop off the lower 16 bits instead of OrderedDemote2To, which rounds
    // to the nearest bf16, because EncBytes will round again.
    const V16 w0 = hn::BitCast(d16, Truncate2To(dbf, f0, f1));
    const V16 w1 = hn::BitCast(d16, Truncate2To(dbf, f2, f3));
    return Enc2U(d16, w0, w1);
  }

  template <class D16, class V8 = hn::Vec<hn::Repartition<uint8_t, D16>>>
  static HWY_INLINE void Dec2U(D16 d16, V8 packed, hn::Vec<D16>& w0,
                               hn::Vec<D16>& w1) {
    const hn::Repartition<uint8_t, D16> d8;
    V8 lo, hi;
    DecBytes(d8, packed, lo, hi);
    w0 = hn::BitCast(d16, hn::InterleaveWholeLower(d8, lo, hi));
    w1 = hn::BitCast(d16, hn::InterleaveWholeUpper(d8, lo, hi));
  }

  template <class DBF, class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
  static HWY_INLINE void Dec2B(DBF dbf, V8 packed, hn::Vec<DBF>& bf0,
                               hn::Vec<DBF>& bf1) {
    const hn::Repartition<uint16_t, DBF> d16;
    using V16 = hn::Vec<decltype(d16)>;
    V16 w0, w1;
    Dec2U(d16, packed, w0, w1);
    bf0 = hn::BitCast(dbf, w0);
    bf1 = hn::BitCast(dbf, w1);
  }

  template <class DF, class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
  static HWY_INLINE void Dec4F(DF df, V8 packed, hn::Vec<DF>& f0,
                               hn::Vec<DF>& f1, hn::Vec<DF>& f2,
                               hn::Vec<DF>& f3) {
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using VBF = hn::Vec<decltype(dbf)>;
    VBF bf0, bf1;
    Dec2B(dbf, packed, bf0, bf1);
    f0 = hn::PromoteLowerTo(df, bf0);
    f1 = hn::PromoteUpperTo(df, bf0);
    f2 = hn::PromoteLowerTo(df, bf1);
    f3 = hn::PromoteUpperTo(df, bf1);
  }

  template <class DBF, class V8 = hn::Vec<hn::Repartition<uint8_t, DBF>>>
  static HWY_INLINE void DecEvenOdd(DBF dbf, V8 packed, hn::Vec<DBF>& even,
                                    hn::Vec<DBF>& odd) {
    const hn::Repartition<uint8_t, DBF> d8;
    V8 lo, hi;
    DecBytes(d8, packed, lo, hi);
    // (Supported since Highway 1.2)
    even = hn::BitCast(dbf, hn::InterleaveEven(d8, lo, hi));
    odd = hn::BitCast(dbf, hn::InterleaveOdd(d8, lo, hi));
  }

  template <class DF, class V8 = hn::Vec<hn::Repartition<uint8_t, DF>>>
  static HWY_INLINE void DecEvenOddF(DF df, V8 packed, hn::Vec<DF>& even0,
                                     hn::Vec<DF>& odd0, hn::Vec<DF>& even1,
                                     hn::Vec<DF>& odd1) {
    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
    using VBF = hn::Vec<decltype(dbf)>;
    VBF even_bf, odd_bf;
    DecEvenOdd(dbf, packed, even_bf, odd_bf);
    even0 = hn::PromoteLowerTo(df, even_bf);
    odd0 = hn::PromoteLowerTo(df, odd_bf);
    even1 = hn::PromoteUpperTo(df, even_bf);
    odd1 = hn::PromoteUpperTo(df, odd_bf);
  }
};  // SfpCodec

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_