mirror of https://github.com/google/gemma.cpp.git

Cascaded summation for Softmax

This can affect generation results after a few hundred tokens. Also remove
the profiler from DecompressAndCall, use Add instead of +=, use PackedSpan
for arguments, and remove the alignment requirement. Changing the
accumulation order in AssimilateCascadedSums updates dot_test thresholds.

PiperOrigin-RevId: 676891797

parent 09bc8d62cc
commit 35fdf848c7
@@ -45,7 +45,6 @@
 // After highway.h
 #include "compression/nuq-inl.h"
 #include "compression/sfp-inl.h"
-#include "hwy/profiler.h"  // also uses SIMD

 HWY_BEFORE_NAMESPACE();
 namespace gcpp {
@@ -507,13 +506,8 @@ HWY_NOINLINE void DecompressAndZeroPad(DRaw d, const PackedSpan<Packed>& packed,
 template <class D, typename WeightT, typename VecT, class Kernel>
 HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,
                                    const size_t w_ofs,
-                                   const VecT* HWY_RESTRICT vec_aligned,
-                                   const size_t num, const Kernel& kernel) {
-  PROFILER_FUNC;
-
-  HWY_DASSERT(hn::IsAligned(hn::Repartition<VecT, D>(), vec_aligned));
-  const auto v_span = MakeSpan(vec_aligned, num);
-
+                                   const PackedSpan<const VecT> vec,
+                                   const Kernel& kernel) {
   // Decompressed inputs
   using V = hn::Vec<decltype(d)>;
   V w0, w1, w2, w3, v0, v1, v2, v3;

@@ -532,26 +526,26 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,

   const size_t N = hn::Lanes(d);
   size_t i = 0;
-  if (num >= 4 * N) {
-    for (; i <= num - 4 * N; i += 4 * N) {
+  if (vec.num >= 4 * N) {
+    for (; i <= vec.num - 4 * N; i += 4 * N) {
       Decompress2(d, w, w_ofs + i + 0 * N, w0, w1);
       Decompress2(d, w, w_ofs + i + 2 * N, w2, w3);
-      Decompress2(d, v_span, i + 0 * N, v0, v1);
-      Decompress2(d, v_span, i + 2 * N, v2, v3);
+      Decompress2(d, vec, i + 0 * N, v0, v1);
+      Decompress2(d, vec, i + 2 * N, v2, v3);

       kernel.Update4(d, w0, w1, w2, w3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
                      comp0, comp1, comp2, comp3);
     }
   }

-  size_t remaining = num - i;
+  size_t remaining = vec.num - i;
   HWY_DASSERT(remaining < 4 * N);
   if (HWY_UNLIKELY(remaining != 0)) {
     using T = hn::TFromD<D>;
     HWY_ALIGN T padded_w[4 * hn::MaxLanes(d)];
     HWY_ALIGN T padded_v[4 * hn::MaxLanes(d)];
     DecompressAndZeroPad(d, w, w_ofs + i, padded_w, remaining);
-    DecompressAndZeroPad(d, v_span, i, padded_v, remaining);
+    DecompressAndZeroPad(d, vec, i, padded_v, remaining);

     // 1..4 whole vectors, possibly zero-padded.
     for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
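The two hunks above change DecompressAndCall to accept a PackedSpan instead of a raw pointer plus a separate length, and drop both the alignment assertion and the profiler call. As a rough illustration of the pointer+count bundling involved — a minimal sketch, not the actual PackedSpan definition in gemma.cpp, which also carries bounds-checking helpers:

// Hedged sketch: a simplified stand-in for PackedSpan. All names here are
// hypothetical; only the pointer+count idea matches the real type.
#include <cstddef>

template <typename T>
struct SpanSketch {
  T* ptr;
  size_t num;  // element count travels with the pointer it describes
};

template <typename T>
SpanSketch<T> MakeSpanSketch(T* p, size_t num) {
  return SpanSketch<T>{p, num};
}

// A callee now takes one argument instead of (pointer, count), so the count
// can no longer go out of sync with the pointer.
template <typename T>
double SumSketch(SpanSketch<const T> v) {
  double total = 0.0;
  for (size_t i = 0; i < v.num; ++i) total += v.ptr[i];
  return total;
}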
@@ -566,13 +560,8 @@ HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const WeightT>& w,

 // Same as above, but single input array. Used by RMSNorm.
 template <class D, typename VecT, class Kernel>
-HWY_INLINE float DecompressAndCall(D d, const VecT* HWY_RESTRICT vec_aligned,
-                                   const size_t num, const Kernel& kernel) {
-  PROFILER_FUNC;
-
-  HWY_DASSERT(hn::IsAligned(hn::Repartition<VecT, D>(), vec_aligned));
-  const auto v_span = MakeSpan(vec_aligned, num);
-
+HWY_INLINE float DecompressAndCall(D d, const PackedSpan<const VecT> vec,
+                                   const Kernel& kernel) {
   // Decompressed inputs
   using V = hn::Vec<decltype(d)>;
   V v0, v1, v2, v3;

@@ -591,21 +580,21 @@ HWY_INLINE float DecompressAndCall(D d, const VecT* HWY_RESTRICT vec_aligned,

   const size_t N = hn::Lanes(d);
   size_t i = 0;
-  if (num >= 4 * N) {
-    for (; i <= num - 4 * N; i += 4 * N) {
-      Decompress2(d, v_span, i + 0 * N, v0, v1);
-      Decompress2(d, v_span, i + 2 * N, v2, v3);
+  if (vec.num >= 4 * N) {
+    for (; i <= vec.num - 4 * N; i += 4 * N) {
+      Decompress2(d, vec, i + 0 * N, v0, v1);
+      Decompress2(d, vec, i + 2 * N, v2, v3);

       kernel.Update4(d, v0, v1, v2, v3, v0, v1, v2, v3, sum0, sum1, sum2, sum3,
                      comp0, comp1, comp2, comp3);
     }
   }

-  size_t remaining = num - i;
+  size_t remaining = vec.num - i;
   HWY_DASSERT(remaining < 4 * N);
   if (HWY_UNLIKELY(remaining != 0)) {
     HWY_ALIGN float padded_v[4 * hn::MaxLanes(d)];
-    DecompressAndZeroPad(d, v_span, i, padded_v, remaining);
+    DecompressAndZeroPad(d, vec, i, padded_v, remaining);

     // 1..4 whole vectors, possibly zero-padded.
     for (size_t padded_pos = 0; padded_pos < remaining; padded_pos += N) {
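Both DecompressAndCall overloads share the same loop shape: a 4x-unrolled main loop, then a remainder that is decompressed into zero-padded stack buffers so the same whole-vector kernel can finish the tail. A scalar sketch of that structure, with hypothetical names and plain float addition standing in for the kernel:

// Sketch of the DecompressAndCall loop structure: process 4 "vectors" (here:
// blocks of N floats) per iteration, then zero-pad the remainder so the tail
// can reuse the same whole-block update. Zeros do not change the sum.
#include <cstddef>
#include <cstring>

float SumBlocks(const float* v, size_t num) {
  constexpr size_t N = 4;  // stand-in for hn::Lanes(d)
  float sum = 0.0f;
  size_t i = 0;
  if (num >= 4 * N) {
    for (; i <= num - 4 * N; i += 4 * N) {
      for (size_t j = 0; j < 4 * N; ++j) sum += v[i + j];  // "Update4"
    }
  }
  const size_t remaining = num - i;  // always < 4 * N
  if (remaining != 0) {
    float padded[4 * N] = {0.0f};  // zero padding, like DecompressAndZeroPad
    std::memcpy(padded, v + i, remaining * sizeof(float));
    // 1..4 whole blocks, possibly zero-padded.
    for (size_t pos = 0; pos < remaining; pos += N) {
      for (size_t j = 0; j < N; ++j) sum += padded[pos + j];  // "Update1"
    }
  }
  return sum;
}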
@@ -244,34 +244,34 @@ struct DotKernelCompensated {
 // Default kernel
 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float Dot(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
-                     const VecT* HWY_RESTRICT vec_aligned, size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
+                     const VecT* HWY_RESTRICT vec, size_t num) {
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelCompensated());
 }

 // Adapter for a single pointer, no bounds checking.
 template <typename WeightT, typename VecT>
-HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec_aligned,
+HWY_INLINE float Dot(const WeightT* HWY_RESTRICT w, const VecT* vec,
                      size_t num) {
   const hn::ScalableTag<VecT> d;
-  return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec_aligned, num);
+  return Dot(d, MakeConstSpan(w, num), /*w_ofs=*/0, vec, num);
 }

 // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
 template <size_t kCapacity, typename VecT>
 HWY_INLINE float Dot(const std::array<float, kCapacity>& w, size_t w_ofs,
-                     const VecT* vec_aligned, size_t num) {
+                     const VecT* vec, size_t num) {
   const hn::ScalableTag<VecT> d;
-  return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num);
+  return Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
 }

 // Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
 template <typename MatT, size_t kCapacity, typename VecT>
 HWY_INLINE float Dot(const CompressedArray<MatT, kCapacity>& w, size_t w_ofs,
-                     const VecT* vec_aligned, size_t num) {
+                     const VecT* vec, size_t num) {
   const hn::ScalableTag<VecT> d;
   return w.scale() *
-         Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec_aligned, num);
+         Dot(d, MakeConstSpan(w.data(), kCapacity), w_ofs, vec, num);
 }

 // NOLINTNEXTLINE(google-readability-namespace-comments)
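The Dot adapters above only repackage arguments (wrapping pointers into spans, applying the scale); the numerics live in the kernels. When testing such kernels, a simple double-precision oracle is a common baseline. A sketch of one, not part of this commit:

// Hedged sketch: a scalar reference for what Dot() computes, using double
// accumulation as a cheap correctness oracle in tests.
#include <cstddef>

float DotScalarRef(const float* w, const float* v, size_t num) {
  double sum = 0.0;
  for (size_t i = 0; i < num; ++i) sum += static_cast<double>(w[i]) * v[i];
  return static_cast<float>(sum);
}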
@@ -149,8 +149,8 @@ struct DotKernelNaive {

 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float DotNaive(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
-                          const VecT* HWY_RESTRICT vec_aligned, size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelNaive());
+                          const VecT* HWY_RESTRICT vec, size_t num) {
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelNaive());
 }

 // https://en.wikipedia.org/wiki/Kahan_summation_algorithm: FastTwoSum.

@@ -196,16 +196,15 @@ struct DotKernelKahan {

 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float DotKahan(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
-                          const VecT* HWY_RESTRICT vec_aligned, size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelKahan());
+                          const VecT* HWY_RESTRICT vec, size_t num) {
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelKahan());
 }

 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float DotCompensated(D d, const PackedSpan<const WeightT>& w,
-                                size_t w_ofs,
-                                const VecT* HWY_RESTRICT vec_aligned,
+                                size_t w_ofs, const VecT* HWY_RESTRICT vec,
                                 size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelCompensated());
 }

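The Kahan and Compensated kernels here are built on error-free additions. Scalar, float-only sketches of the two primitives named in the comments — Dekker's FastTwoSum, which requires |a| >= |b|, and Knuth's TwoSum, which does not (these are simplified stand-ins for the vector helpers in the code, and require strict IEEE semantics, i.e. no -ffast-math):

// FastTwoSum: valid only when |a| >= |b|. Returns round(a+b); err is the
// exact rounding error, so a + b == s + err exactly.
float FastTwoSumSketch(float a, float b, float& err) {
  const float s = a + b;
  err = b - (s - a);
  return s;
}

// TwoSum (Knuth): no magnitude precondition, three extra operations.
float TwoSumSketch(float a, float b, float& err) {
  const float s = a + b;
  const float bv = s - a;         // the part of s contributed by b
  const float av = s - bv;        // the part of s contributed by a
  err = (a - av) + (b - bv);      // exact: (a + b) - s
  return s;
}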
@@ -259,10 +258,9 @@ struct DotKernelTwoProdFast {

 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float DotTwoProdFast(D d, const PackedSpan<const WeightT>& w,
-                                size_t w_ofs,
-                                const VecT* HWY_RESTRICT vec_aligned,
+                                size_t w_ofs, const VecT* HWY_RESTRICT vec,
                                 size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelTwoProdFast());
 }

@@ -315,10 +313,10 @@ struct DotKernelMulTwoSum {

 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float DotMulTwoSum(D d, const PackedSpan<const WeightT>& w,
-                              size_t w_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
+                              size_t w_ofs, const VecT* HWY_RESTRICT vec,
                               size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelMulTwoSum());
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
+                           DotKernelMulTwoSum());
 }

 // Like Compensated, but only TwoProducts, no [Fast]TwoSums. This is only 10%

@@ -370,10 +368,9 @@ struct DotKernelTwoProdAdd {

 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float DotTwoProdAdd(D d, const PackedSpan<const WeightT>& w,
-                               size_t w_ofs,
-                               const VecT* HWY_RESTRICT vec_aligned,
+                               size_t w_ofs, const VecT* HWY_RESTRICT vec,
                                size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num,
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
                            DotKernelTwoProdAdd());
 }

@@ -437,9 +434,10 @@ struct DotKernelPairwise {

 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float DotPairwise(D d, const PackedSpan<const WeightT>& w,
-                             size_t w_ofs, const VecT* HWY_RESTRICT vec_aligned,
+                             size_t w_ofs, const VecT* HWY_RESTRICT vec,
                              size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelPairwise());
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num),
+                           DotKernelPairwise());
 }

 // Hybrid of Pairwise and Compensated. 1.14x time vs. Kahan, but geomean mul
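The *TwoProd* kernels additionally capture the rounding error of each multiplication. With a fused multiply-add this is exact; a scalar sketch of the building block (again a simplified stand-in, not the vector helper in the code):

// TwoProd via FMA: p is round(a*b), err is the exact rounding error,
// so a * b == p + err exactly (barring overflow/underflow).
#include <cmath>

float TwoProdSketch(float a, float b, float& err) {
  const float p = a * b;
  err = std::fma(a, b, -p);  // computes a*b - p with a single rounding
  return p;
}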
@@ -531,8 +529,8 @@ struct DotKernelComp2 {

 template <class D, typename WeightT, typename VecT>
 HWY_INLINE float DotComp2(D d, const PackedSpan<const WeightT>& w, size_t w_ofs,
-                          const VecT* HWY_RESTRICT vec_aligned, size_t num) {
-  return DecompressAndCall(d, w, w_ofs, vec_aligned, num, DotKernelComp2());
+                          const VecT* HWY_RESTRICT vec, size_t num) {
+  return DecompressAndCall(d, w, w_ofs, MakeSpan(vec, num), DotKernelComp2());
 }

 template <class D, typename WeightT, typename VecT>
@@ -703,10 +701,10 @@ class DotStats {
     // Naive and OnlyTwoProd are considerably worse. >10x is for narrower
     // vectors, compared to AVX-512. GeometricMean overflows, must use Mean.
     ASSERT_INSIDE(kNaive, 1.01, s_muls[kNaive].Mean(), 16.0);
-    ASSERT_INSIDE(kOnlyTwoProd, 1.01, s_muls[kOnlyTwoProd].Mean(), 13.0);
+    ASSERT_INSIDE(kOnlyTwoProd, 1.01, s_muls[kOnlyTwoProd].Mean(), 73.0);

     // Kahan (FastTwoSum) is decent:
-    ASSERT_INSIDE(kKahan, 1.001, s_muls[kKahan].Mean(), 4.1);
+    ASSERT_INSIDE(kKahan, 1.0005, s_muls[kKahan].Mean(), 4.1);
     ASSERT_INSIDE(kKahan, 1.001f, s_muls[kKahan].Max(), 14.1f);
     ASSERT_INSIDE(kKahan, 1.0, s_muls[kKahan].GeometricMean(), 1.6);

@@ -718,7 +716,7 @@ class DotStats {
     ASSERT_INSIDE(kAddTwoSum, 1.0005, s_muls[kAddTwoSum].Mean(), 2.2);
     ASSERT_INSIDE(kAddTwoSum, 1.0, s_muls[kAddTwoSum].GeometricMean(), 1.3);

-    ASSERT_INSIDE(kPairwise, 1.0, s_muls[kPairwise].GeometricMean(), 1.5);
+    ASSERT_INSIDE(kPairwise, 1.0, s_muls[kPairwise].GeometricMean(), 1.6);
   }

   // Absolute error; larger is worse.

@@ -736,12 +734,12 @@ class DotStats {
     ASSERT_INSIDE(kOnlyTwoProd, 1E-3, s_l1s[kOnlyTwoProd].Mean(), 2E-2);

     // Kahan (FastTwoSum) is decent:
-    ASSERT_INSIDE(kKahan, 3.9E-4, s_l1s[kKahan].Mean(), 1E-3);
-    ASSERT_INSIDE(kKahan, 1.1E-3f, s_l1s[kKahan].Max(), 3.2E-3f);
+    ASSERT_INSIDE(kKahan, 3E-4, s_l1s[kKahan].Mean(), 1E-3);
+    ASSERT_INSIDE(kKahan, 6E-4f, s_l1s[kKahan].Max(), 3.2E-3f);

     // But can be nearly halved via TwoProducts:
     ASSERT_INSIDE(kAddTwoProd, 2.2E-4, s_l1s[kAddTwoProd].Mean(), 8E-4);
-    ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.0E-3f);
+    ASSERT_INSIDE(kAddTwoProd, 4E-4f, s_l1s[kAddTwoProd].Max(), 2.1E-3f);
     // Updating Kahan's FastTwoSums to TwoSums does help a bit.
     ASSERT_INSIDE(kAddTwoSum, 1.5E-4, s_l1s[kAddTwoSum].Mean(), 5.2E-4);

@@ -118,15 +118,15 @@ template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
 void UpdateCascadedSums(DF df, VF v, VF& sum, VF& sum_err) {
   VF err;
   sum = TwoSums(df, sum, v, err);
-  sum_err += err;
+  sum_err = hn::Add(sum_err, err);
 }

 // Combines two cascaded sum vectors, typically from unrolling/parallelization.
 template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
 void AssimilateCascadedSums(DF df, const VF& other_sum, const VF& other_sum_err,
                             VF& sum, VF& sum_err) {
+  sum_err = hn::Add(sum_err, other_sum_err);
   UpdateCascadedSums(df, other_sum, sum, sum_err);
-  sum_err += other_sum_err;
 }

 // Reduces cascaded sums, to a single value. Slow, call outside of loops.
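The reordering in AssimilateCascadedSums is the change the commit message says shifted the dot_test thresholds: other_sum_err is now folded into sum_err before TwoSums updates sum, rather than after. In scalar form, using the TwoSum sketch above (hypothetical names):

// Scalar illustration of the new accumulation order in
// AssimilateCascadedSums: fold the other error term in first, so the rounding
// error of (sum + other_sum) lands in an already up-to-date sum_err.
float TwoSumSketch(float a, float b, float& err);  // as sketched earlier

void AssimilateScalar(float other_sum, float other_sum_err, float& sum,
                      float& sum_err) {
  sum_err += other_sum_err;                   // new order: error first
  float err;
  sum = TwoSumSketch(sum, other_sum, err);    // error-free sum of the sums
  sum_err += err;
}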
@@ -134,14 +134,31 @@ template <class DF, HWY_IF_FLOAT3264_D(DF), class VF = hn::Vec<DF>>
 hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
   const size_t N = hn::Lanes(df);
   using TF = hn::TFromD<DF>;
+  // For non-scalable wide vectors, reduce loop iterations below by recursing
+  // once or twice for halves of 256-bit or 512-bit vectors.
+  if constexpr (!HWY_HAVE_SCALABLE) {
+    if constexpr (hn::Lanes(df) > 16 / sizeof(TF)) {
+      const hn::Half<DF> dfh;
+      using VFH = hn::Vec<decltype(dfh)>;
+
+      VFH sum0 = hn::LowerHalf(dfh, sum);
+      VFH sum_err0 = hn::LowerHalf(dfh, sum_err);
+      const VFH sum1 = hn::UpperHalf(dfh, sum);
+      const VFH sum_err1 = hn::UpperHalf(dfh, sum_err);
+      AssimilateCascadedSums(dfh, sum1, sum_err1, sum0, sum_err0);
+      return ReduceCascadedSums(dfh, sum0, sum_err0);
+    }
+  }
+
   TF total = TF{0.0};
   TF total_err = TF{0.0};
   for (size_t i = 0; i < N; ++i) {
     TF err;
+    total_err += hn::ExtractLane(sum_err, i);
     total = TwoSum(total, hn::ExtractLane(sum, i), err);
     total_err += err;
   }
-  return total + total_err + hn::ReduceSum(df, sum_err);
+  return total + total_err;
 }

 // NOLINTNEXTLINE(google-readability-namespace-comments)
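The new recursion shortens the O(N) per-lane loop for wide fixed-size vectors by first assimilating the upper half into the lower half, once for 256-bit and twice for 512-bit registers. A scalar stand-in using arrays in place of vector registers, reusing the sketches above (hypothetical names; lanes assumed a power of two, at least 4):

// Sketch of halving reduction: assimilate upper half into lower half until
// a small base case, then do the per-lane cascaded reduce.
#include <cstddef>

float TwoSumSketch(float a, float b, float& err);                // see above
void AssimilateScalar(float, float, float&, float&);             // see above

float ReduceHalvingSketch(float* sum, float* err, size_t lanes) {
  while (lanes > 4) {  // stop at one "128-bit" quarter's worth of lanes
    lanes /= 2;
    for (size_t i = 0; i < lanes; ++i) {
      AssimilateScalar(sum[lanes + i], err[lanes + i], sum[i], err[i]);
    }
  }
  float total = 0.0f, total_err = 0.0f;
  for (size_t i = 0; i < lanes; ++i) {  // mirrors the new per-lane loop
    total_err += err[i];
    float e;
    total = TwoSumSketch(total, sum[i], e);
    total_err += e;
  }
  return total + total_err;
}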
@@ -145,7 +145,8 @@ namespace detail {
 template <typename VecT>
 float RMSNormMul(const VecT* HWY_RESTRICT x, size_t size) {
   const hn::ScalableTag<float> df;
-  const float l2 = DecompressAndCall(df, x, size, DotKernelCompensated());
+  const float l2 =
+      DecompressAndCall(df, MakeSpan(x, size), DotKernelCompensated());
   constexpr float kEps = 1e-6f;  // avoid divide by zero
   return 1.0f / sqrtf(l2 / StaticCast<float>(size) + kEps);
 }
@@ -503,8 +504,54 @@ static HWY_INLINE HWY_MAYBE_UNUSED void MulByConstAndAdd(
   MulByConstAndAdd(c, x, out, size, size);
 }

+// ORO Cascaded Summation, algorithm 6.11 from Handbook of Floating-Point
+// Arithmetic. Note that Algorithm 6.7 (KBN) appears erroneous. We use TwoSums
+// instead of FastTwoSums because the magnitude of the initial sum is not
+// always greater than the next input, and this does actually change the e2e
+// generation results. Note that Kahan summation differs in that it first adds
+// comp* to w*, so each operation is serially dependent. By contrast, the sum*
+// and comp* here have shorter dependency chains.
+struct KernelCascadedSum {
+  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
+  HWY_INLINE void Update4(DF df, const VF w0, const VF w1, const VF w2,
+                          const VF w3, VF, VF, VF, VF, VF& sum0, VF& sum1,
+                          VF& sum2, VF& sum3, VF& comp0, VF& comp1, VF& comp2,
+                          VF& comp3) const {
+    VF serr0, serr1, serr2, serr3;
+    sum0 = TwoSums(df, sum0, w0, serr0);
+    sum1 = TwoSums(df, sum1, w1, serr1);
+    sum2 = TwoSums(df, sum2, w2, serr2);
+    sum3 = TwoSums(df, sum3, w3, serr3);
+
+    comp0 = hn::Add(comp0, serr0);
+    comp1 = hn::Add(comp1, serr1);
+    comp2 = hn::Add(comp2, serr2);
+    comp3 = hn::Add(comp3, serr3);
+  }
+
+  template <class DF, class VF = hn::Vec<DF>, HWY_IF_F32_D(DF)>
+  HWY_INLINE void Update1(DF df, const VF w0, const VF v0, VF& sum0,
+                          VF& comp0) const {
+    VF serr0;
+    sum0 = TwoSums(df, sum0, w0, serr0);
+
+    comp0 = hn::Add(comp0, serr0);
+  }
+
+  template <class DF, class VF = hn::Vec<DF>>
+  HWY_INLINE float Reduce(DF df, VF& sum0, VF& sum1, VF& sum2, VF& sum3,
+                          VF& comp0, VF& comp1, VF& comp2, VF& comp3) const {
+    // Reduction tree: sum of all accumulators by pairs, then across lanes.
+    AssimilateCascadedSums(df, sum1, comp1, sum0, comp0);
+    AssimilateCascadedSums(df, sum3, comp3, sum2, comp2);
+    AssimilateCascadedSums(df, sum2, comp2, sum0, comp0);
+    return ReduceCascadedSums(df, sum0, comp0);
+  }
+};
+
 static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,
                                  const size_t mask_pos) {
+  PROFILER_FUNC;
   HWY_DASSERT(size != 0);
   HWY_DASSERT(mask_pos <= size);

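KernelCascadedSum is the vectorized form of ORO cascaded summation: each accumulator keeps a running sum plus a compensation term that collects the exact rounding errors, added back only at the end. A scalar sketch of the same algorithm, using the TwoSum sketch from earlier:

// Scalar ORO cascaded summation: accumulate via error-free additions and
// keep the rounding errors in a separate compensation term.
#include <cstddef>

float TwoSumSketch(float a, float b, float& err);  // as sketched earlier

float CascadedSumSketch(const float* x, size_t num) {
  float sum = 0.0f, comp = 0.0f;
  for (size_t i = 0; i < num; ++i) {
    float err;
    sum = TwoSumSketch(sum, x[i], err);  // sum_new + err == sum_old + x[i]
    comp += err;                         // errors accumulate independently
  }
  return sum + comp;
}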
@@ -523,23 +570,22 @@ static HWY_NOINLINE void Softmax(float* HWY_RESTRICT x, const size_t size,

   // Subtract max (avoid precision loss for large exponents) and exponentiate.
   hn::Transform(d, x, mask_pos, [pmax](const auto d, const V value) HWY_ATTR {
-#if HWY_TARGET & HWY_ALL_SVE
-    // Temporary workaround for buggy SVE codegen: avoid inlined
-    // Exp().
-    return hn::CallExp(d, hn::Sub(value, *pmax));
-#else
-    return hn::Exp(d, hn::Sub(value, *pmax));
-#endif
+    if constexpr (HWY_TARGET & HWY_ALL_SVE) {
+      // Temporary workaround for buggy SVE codegen: avoid inlined Exp().
+      return hn::CallExp(d, hn::Sub(value, *pmax));
+    } else {
+      return hn::Exp(d, hn::Sub(value, *pmax));
+    }
   });

-  V sum = hn::Zero(d);
-  V* psum = &sum;
-  hn::Foreach(d, x, mask_pos, sum,
-              [psum](const auto d, const V value)
-                  HWY_ATTR { *psum = hn::Add(*psum, value); });
-
-  // Normalize to probability distribution
-  const float mul = 1.0f / hn::ReduceSum(d, sum);
+  // Normalize to probability distribution. The exact sum seems like it should
+  // not make a huge difference. It halves the standard deviation of the sum of
+  // the normalized probabilities from 1E-7 to 5E-8, but actually also changes
+  // the generated text after a few hundred tokens.
+  const float sum_exp =
+      DecompressAndCall(d, MakeConstSpan(x, mask_pos), KernelCascadedSum());
+  // Double-precision reciprocal does not appear to affect the results.
+  const float mul = 1.0f / sum_exp;
   MulByConst(mul, x, size, mask_pos);
 }

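Putting the Softmax hunk together: subtract the maximum for range safety, exponentiate, then divide by the cascaded sum of the exponentials instead of a plain vector sum. A simplified scalar sketch without the masking and SVE special cases, reusing the cascaded-sum sketch above:

// Scalar softmax with a compensated denominator, mirroring the new flow.
#include <algorithm>
#include <cmath>
#include <cstddef>

float CascadedSumSketch(const float* x, size_t num);  // previous sketch

void SoftmaxScalarSketch(float* x, size_t size) {
  const float max = *std::max_element(x, x + size);
  for (size_t i = 0; i < size; ++i) x[i] = std::exp(x[i] - max);
  const float mul = 1.0f / CascadedSumSketch(x, size);  // stabler denominator
  for (size_t i = 0; i < size; ++i) x[i] *= mul;
}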