Cascaded summation for Softmax

This can affect generation results after a few hundred tokens.

Also remove the profiler call from DecompressAndCall, use Add instead of +=,
use PackedSpan for the arguments, and remove the alignment requirement.
Changing the accumulation order in AssimilateCascadedSums requires updating the
dot_test thresholds.

PiperOrigin-RevId: 676891797
Jan Wassenberg 2024-09-20 10:30:42 -07:00 committed by Copybara-Service
parent 09bc8d62cc
commit 35fdf848c7
5 changed files with 130 additions and 80 deletions
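
For orientation: cascaded summation accumulates the rounding error of every addition in a separate compensation term via an error-free transformation (TwoSum) and folds it back in only at the end. Below is a minimal scalar sketch of the idea, with illustrative names only; the commit itself adds a vectorized kernel (KernelCascadedSum) in ops-inl.h.

#include <cstddef>

// Knuth's TwoSum: s = a + b, with the exact rounding error returned in err.
static inline float ScalarTwoSum(float a, float b, float& err) {
  const float s = a + b;
  const float b_virtual = s - a;
  const float a_virtual = s - b_virtual;
  err = (a - a_virtual) + (b - b_virtual);
  return s;
}

// Cascaded sum: errors go to a separate accumulator, added back at the end.
static inline float ScalarCascadedSum(const float* x, size_t num) {
  float sum = 0.0f, comp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    float err;
    sum = ScalarTwoSum(sum, x[i], err);
    comp += err;
  }
  return sum + comp;
}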

View File

@@ -45,7 +45,6 @@
// After highway.h
#include "compression/nuq-inl.h"
#include "compression/sfp-inl.h"
#include "hwy/profiler.h" // also uses SIMD
HWY_BEFORE_NAMESPACE();
namespace gcpp {
@@ -507,13 +506,8 @@ HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
template <class D, typename WeightT, typename VecT, class Kernel>
HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
const size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned,
const size_t num, const Kernel& kernel) {
PROFILER_FUNC;
HWY_DASSERT(hn::IsAligned(hn::Repartition<VecT, D>(), vec_aligned));
const auto v_span = MakeSpan(vec_aligned, num);
const PackedSpan<const VecT> vec,
const Kernel& kernel) {
// Decompressed inputs
using V = hn::Vec<decltype(d)>;
V w0, w1, w2, w3, v0, v1, v2, v3;
@@ -532,26 +526,26 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
const size_t N = hn::Lanes(d);
size_t i = 0;
if (num >= 4 * N) {
for (; i <= num - 4 * N; i += 4 * N) {
if (vec.num >= 4 * N) {
for (; i <= vec.num - 4 * N; i += 4 * N) {
Decompress2(d, w, w_ofs + i + 0 * N, w0, w1);
Decompress2(d, w, w_ofs + i + 2 * N, w2, w3);
Decompress2(d, v_span, i + 0 * N, v0, v1);
Decompress2(d, v_span, i + 2 * N, v2, v3);
Decompress2(d, vec, i + 0 * N, v0, v1);
Decompress2(d, vec, i + 2 * N, v2, v3);
kernel.Update4(d, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
comp0, comp1, comp2, comp3);
}
}
size_t remaining = num - i;
size_t remaining = vec.num - i;
HWY_DASSERT(remaining < 4 * N);
if (HWY_UNLIKELY(remaining != 0)) {
using T = hn::TFromD<D>;
HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)];
HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining);
DecompressAndZeroPad(d, v_span, i, padded_v, remaining);
DecompressAndZeroPad(d, vec, i, padded_v, remaining);
// 1..4 whole vectors, possibly zero-padded.
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
@@ -566,13 +560,8 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
// Same as above, but single input array. Used by RMSNorm.
template <class D, typename VecT, class Kernel>
HWY_INLINE float DecompressAndCall(D d, const VecT* HWY_RESTRICT vec_aligned,
const size_t num, const Kernel& kernel) {
PROFILER_FUNC;
HWY_DASSERT(hn::IsAligned(hn::Repartition<VecT, D>(), vec_aligned));
const auto v_span = MakeSpan(vec_aligned, num);
HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
const Kernel& kernel) {
// Decompressed inputs
using V = hn::Vec<decltype(d)>;
V v0, v1, v2, v3;
@@ -591,21 +580,21 @@ HWY_INLINE float DecompressAndCall(D d, const VecT* HWY_RESTRICT vec_aligned,
const size_t N = hn::Lanes(d);
size_t i = 0;
if (num >= 4 * N) {
for (; i <= num - 4 * N; i += 4 * N) {
Decompress2(d, v_span, i + 0 * N, v0, v1);
Decompress2(d, v_span, i + 2 * N, v2, v3);
if (vec.num >= 4 * N) {
for (; i <= vec.num - 4 * N; i += 4 * N) {
Decompress2(d, vec, i + 0 * N, v0, v1);
Decompress2(d, vec, i + 2 * N, v2, v3);
kernel.Update4(d, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
comp0, comp1, comp2, comp3);
}
}
size_t remaining = num - i;
size_t remaining = vec.num - i;
HWY_DASSERT(remaining < 4 * N);
if (HWY_UNLIKELY(remaining != 0)) {
HWY_ALIGN float padded_v[4 * hn::MaxLanes(d)];
DecompressAndZeroPad(d, v_span, i, padded_v, remaining);
DecompressAndZeroPad(d, vec, i, padded_v, remaining);
// 1..4 whole vectors, possibly zero-padded.
for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
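
For callers, the visible change is that the activation vector is now passed as a bounds-carrying PackedSpan rather than an aligned pointer plus a separate length, so the alignment HWY_DASSERT is gone. A hedged sketch of a call site (buffer names are illustrative):

const hn::ScalableTag<float> df;
// `weights` and `vec` are float buffers of length `num`; `vec` no longer has
// to be vector-aligned.
const float dot = DecompressAndCall(df, MakeConstSpan(weights, num), /*w_ofs=*/0,
                                    MakeSpan(vec, num), DotKernelCompensated());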

View File

@@ -244,34 +244,34 @@ struct DotKernelCompensated {
// Default kernel
template <class D, typename WeightT, typename VecT>
HWY_INLINE float Dot(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned, size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
const VecT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelCompensated());
}
// Adapter for a single pointer, no bounds checking.
template <typename WeightT, typename VecT>
HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec_aligned,
HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec,
size_t num) {
const hn::ScalableTag<VecT> d;
return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec_aligned, num);
return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num);
}
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
template <size_t kCapacity, typename VecT>
HWY_INLINE float Dot(const std::array<float, kCapacity>& w, size_t w_ofs,
const VecT* vec_aligned, size_t num) {
const VecT* vec, size_t num) {
const hn::ScalableTag<VecT> d;
return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num);
return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
}
// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
template <typename MatT, size_t kCapacity, typename VecT>
HWY_INLINE float Dot(const CompressedArray<MatT, kCapacity>& w, size_t w_ofs,
const VecT* vec_aligned, size_t num) {
const VecT* vec, size_t num) {
const hn::ScalableTag<VecT> d;
return w.scale() *
Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num);
Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
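
The adapters keep their pointer-based signatures; only the parameter name changes, since the vector argument no longer has to be aligned. A small usage sketch with hypothetical buffers:

float w[256], v[256];  // e.g. weights and activations, any alignment
// ... fill w and v ...
const float result = Dot(w, v, /*num=*/256);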

View File

@@ -149,8 +149,8 @@ struct DotKernelNaive {
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotNaive(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned, size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelNaive());
const VecT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelNaive());
}
// https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.
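
As a reminder, FastTwoSum (the building block of the Kahan kernel) recovers the rounding error of an addition exactly, but only when |a| >= |b|; a scalar sketch for reference, names illustrative:

// FastTwoSum (Dekker): err is exact only if |a| >= |b|.
static inline float ScalarFastTwoSum(float a, float b, float& err) {
  const float s = a + b;
  err = b - (s - a);
  return s;
}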
@@ -196,16 +196,15 @@ struct DotKernelKahan {
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotKahan(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned, size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelKahan());
const VecT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelKahan());
}
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelCompensated());
}
@@ -259,10 +258,9 @@ struct DotKernelTwoProdFast {
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelTwoProdFast());
}
@@ -315,10 +313,10 @@ struct DotKernelMulTwoSum {
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelMulTwoSum());
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelMulTwoSum());
}
// Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10%
@@ -370,10 +368,9 @@ struct DotKernelTwoProdAdd {
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelTwoProdAdd());
}
@@ -437,9 +434,10 @@ struct DotKernelPairwise {
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotPairwise(D d, const PackedSpan<const WeightT>& w,
size_t w_ofs, const VecT* HWY_RESTRICT vec_aligned,
size_t w_ofs, const VecT* HWY_RESTRICT vec,
size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelPairwise());
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
DotKernelPairwise());
}
// Hybrid of Pairwise and Compensated. 1.14x time vs. Kahan, but geomean mul
@@ -531,8 +529,8 @@ struct DotKernelComp2 {
template <class D, typename WeightT, typename VecT>
HWY_INLINE float DotComp2(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
const VecT* HWY_RESTRICT vec_aligned, size_t num) {
return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelComp2());
const VecT* HWY_RESTRICT vec, size_t num) {
return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelComp2());
}
template <class D, typename WeightT, typename VecT>
@@ -703,10 +701,10 @@ class DotStats {
// Naive and OnlyTwoProd are considerably worse. >10x is for narrower
// vectors, compared to AVX-512. GeometricMean overflows, must use Mean.
ASSERT_INSIDE(kNaive, 1.01, s_muls[kNaive].Mean(), 16.0);
ASSERT_INSIDE(kOnlyTwoProd, 1.01, s_muls[kOnlyTwoProd].Mean(), 13.0);
ASSERT_INSIDE(kOnlyTwoProd, 1.01, s_muls[kOnlyTwoProd].Mean(), 73.0);
// Kahan (FastTwoSum) is decent:
ASSERT_INSIDE(kKahan, 1.001, s_muls[kKahan].Mean(), 4.1);
ASSERT_INSIDE(kKahan, 1.0005, s_muls[kKahan].Mean(), 4.1);
ASSERT_INSIDE(kKahan, 1.001f, s_muls[kKahan].Max(), 14.1f);
ASSERT_INSIDE(kKahan, 1.0, s_muls[kKahan].GeometricMean(), 1.6);
@@ -718,7 +716,7 @@ class DotStats {
ASSERT_INSIDE(kAddTwoSum, 1.0005, s_muls[kAddTwoSum].Mean(), 2.2);
ASSERT_INSIDE(kAddTwoSum, 1.0, s_muls[kAddTwoSum].GeometricMean(), 1.3);
ASSERT_INSIDE(kPairwise, 1.0, s_muls[kPairwise].GeometricMean(), 1.5);
ASSERT_INSIDE(kPairwise, 1.0, s_muls[kPairwise].GeometricMean(), 1.6);
}
// Absolute error; larger is worse.
@@ -736,12 +734,12 @@ class DotStats {
ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_l1s[kOnlyTwoProd].Mean(), 2E-2);
// Kahan (FastTwoSum) is decent:
ASSERT_INSIDE(kKahan, 3.9E-4, s_l1s[kKahan].Mean(), 1E-3);
ASSERT_INSIDE(kKahan, 1.1E-3f, s_l1s[kKahan].Max(), 3.2E-3f);
ASSERT_INSIDE(kKahan, 3E-4, s_l1s[kKahan].Mean(), 1E-3);
ASSERT_INSIDE(kKahan, 6E-4f, s_l1s[kKahan].Max(), 3.2E-3f);
// But can be nearly halved via TwoProducts:
ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_l1s[kAddTwoProd].Mean(), 8E-4);
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.0E-3f);
ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.1E-3f);
// Updating Kahan's FastTwoSums to TwoSums does help a bit.
ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.2E-4);

View File

@@ -118,15 +118,15 @@ template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
void UpdateCascadedSums(DF df, VF v, VF& sum, VF& sum_err) {
VF err;
sum = TwoSums(df, sum, v, err);
sum_err += err;
sum_err = hn::Add(sum_err, err);
}
// Combines two cascaded sum vectors, typically from unrolling/parallelization.
template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
void AssimilateCascadedSums(DF df, const VF& other_sum, const VF& other_sum_err,
VF& sum, VF& sum_err) {
sum_err = hn::Add(sum_err, other_sum_err);
UpdateCascadedSums(df, other_sum, sum, sum_err);
sum_err += other_sum_err;
}
// Reduces cascaded sums, to a single value. Slow, call outside of loops.
@@ -134,14 +134,31 @@ template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
const size_t N = hn::Lanes(df);
using TF = hn::TFromD<DF>;
// For non-scalable wide vectors, reduce loop iterations below by recursing
// once or twice for halves of 256-bit or 512-bit vectors.
if constexpr (!HWY_HAVE_SCALABLE) {
if constexpr (hn::Lanes(df) > 16 / sizeof(TF)) {
const hn::Half<DF> dfh;
using VFH = hn::Vec<decltype(dfh)>;
VFH sum0 = hn::LowerHalf(dfh, sum);
VFH sum_err0 = hn::LowerHalf(dfh, sum_err);
const VFH sum1 = hn::UpperHalf(dfh, sum);
const VFH sum_err1 = hn::UpperHalf(dfh, sum_err);
AssimilateCascadedSums(dfh, sum1, sum_err1, sum0, sum_err0);
return ReduceCascadedSums(dfh, sum0, sum_err0);
}
}
TF total = TF{0.0};
TF total_err = TF{0.0};
for (size_t i = 0; i < N; ++i) {
TF err;
total_err += hn::ExtractLane(sum_err, i);
total = TwoSum(total, hn::ExtractLane(sum, i), err);
total_err += err;
}
return total + total_err + hn::ReduceSum(df, sum_err);
return total + total_err;
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
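
Taken together, these helpers are used as a pair of running accumulators per lane group; a hedged usage sketch for a single accumulator pair (remainder handling omitted, buffer name illustrative):

namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
using VF = hn::Vec<decltype(df)>;
VF sum = hn::Zero(df), sum_err = hn::Zero(df);
const size_t N = hn::Lanes(df);
for (size_t i = 0; i + N <= num; i += N) {
  UpdateCascadedSums(df, hn::LoadU(df, x + i), sum, sum_err);
}
const float total = ReduceCascadedSums(df, sum, sum_err);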

View File

@@ -145,7 +145,8 @@ namespace detail {
template <typename VecT>
float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) {
const hn::ScalableTag<float> df;
const float l2 = DecompressAndCall(df, x, size, DotKernelCompensated());
const float l2 =
DecompressAndCall(df, MakeSpan(x, size), DotKernelCompensated());
constexpr float kEps = 1e-6f; // avoid divide by zero
return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
}
@@ -503,8 +504,54 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
MulByConstAndAdd(c, x, out, size, size);
}
// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
// instead of FastTwoSums because the magnitude of the initial sum is not
// always greater than the next input, and this does actually change the e2e
// generation results. Note that Kahan summation differs in that it first adds
// comp* to w*, so each operation is serially dependent. By contrast, the sum*
// and comp* here have shorter dependency chains.
struct KernelCascadedSum {
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
VF& comp3) const {
VF serr0, serr1, serr2, serr3;
sum0 = TwoSums(df, sum0, w0, serr0);
sum1 = TwoSums(df, sum1, w1, serr1);
sum2 = TwoSums(df, sum2, w2, serr2);
sum3 = TwoSums(df, sum3, w3, serr3);
comp0 = hn::Add(comp0, serr0);
comp1 = hn::Add(comp1, serr1);
comp2 = hn::Add(comp2, serr2);
comp3 = hn::Add(comp3, serr3);
}
template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
VF& comp0) const {
VF serr0;
sum0 = TwoSums(df, sum0, w0, serr0);
comp0 = hn::Add(comp0, serr0);
}
template <class DF, class VF = hn::Vec<DF>>
HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
// Reduction tree: sum of all accumulators by pairs, then across lanes.
AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
return ReduceCascadedSums(df, sum0, comp0);
}
};
static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
const size_t mask_pos) {
PROFILER_FUNC;
HWY_DASSERT(size != 0);
HWY_DASSERT(mask_pos <= size);
@@ -523,23 +570,22 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
// Subtract max (avoid precision loss for large exponents) and exponentiate.
hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
#if HWY_TARGET & HWY_ALL_SVE
// Temporary workaround for buggy SVE codegen: avoid inlined
// Exp().
return hn::CallExp(d, hn::Sub(value, *pmax));
#else
return hn::Exp(d, hn::Sub(value, *pmax));
#endif
if constexpr (HWY_TARGET & HWY_ALL_SVE) {
// Temporary workaround for buggy SVE codegen: avoid inlined Exp().
return hn::CallExp(d, hn::Sub(value, *pmax));
} else {
return hn::Exp(d, hn::Sub(value, *pmax));
}
});
V sum = hn::Zero(d);
V* psum = &sum;
hn::Foreach(d, x, mask_pos, sum,
[psum](const auto d, const V value)
HWY_ATTR { *psum = hn::Add(*psum, value); });
// Normalize to probability distribution
const float mul = 1.0f / hn::ReduceSum(d, sum);
// Normalize to probability distribution. The exact sum seems like it should
// not make a huge difference. It halves the standard deviation of the sum of
// the normalized probabilities from 1E-7 to 5E-8, but actually also changes
// the generated text after a few hundred tokens.
const float sum_exp =
DecompressAndCall(d, MakeConstSpan(x, mask_pos), KernelCascadedSum());
// Double-precision reciprocal does not appear to affect the results.
const float mul = 1.0f / sum_exp;
MulByConst(mul, x, size, mask_pos);
}
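
To make the dependency-chain remark in the KernelCascadedSum comment concrete: classic Kahan summation folds the running compensation into the next input before each add, so every step depends on the previous one, whereas the cascaded kernel keeps the sum* and comp* updates independent within an iteration. A scalar sketch of the Kahan loop for comparison (illustrative only):

static inline float ScalarKahanSum(const float* x, size_t num) {
  float sum = 0.0f, comp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    const float y = x[i] + comp;  // compensation applied to the input first
    const float t = sum + y;      // serially dependent on the previous sum
    comp = y - (t - sum);         // FastTwoSum-style error recovery
    sum = t;
  }
  return sum;
}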