diff --git a/BUILD.bazel b/BUILD.bazel
index 8c8fc27..3749297 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -44,9 +44,10 @@ cc_library(
         "ops/matmul.h",
     ],
     textual_hdrs = [
-        "ops/ops-inl.h",
+        "ops/dot-inl.h",
         "ops/matmul-inl.h",
         "ops/matvec-inl.h",
+        "ops/ops-inl.h",
     ],
     deps = [
         ":allocator",
@@ -63,6 +64,30 @@ cc_library(
     ],
 )
 
+cc_test(
+    name = "dot_test",
+    size = "small",
+    timeout = "long",
+    srcs = ["ops/dot_test.cc"],
+    local_defines = ["HWY_IS_TEST"],
+    # for test_suite.
+    tags = ["hwy_ops_test"],
+    deps = [
+        ":allocator",
+        ":common",
+        ":gemma_lib",
+        ":ops",
+        ":threading",
+        "@googletest//:gtest_main",  # buildcleaner: keep
+        "//compression:compress",
+        "@hwy//:hwy",
+        "@hwy//:hwy_test_util",
+        "@hwy//:nanobenchmark",  # buildcleaner: keep
+        "@hwy//:profiler",
+        "@hwy//:stats",
+    ],
+)
+
 cc_test(
     name = "ops_test",
     size = "small",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 72a9a9d..3bc4097 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -100,6 +100,7 @@ set(SOURCES
   gemma/tokenizer.h
   gemma/weights.cc
   gemma/weights.h
+  ops/dot-inl.h
   ops/matmul-inl.h
   ops/matvec-inl.h
   ops/ops-inl.h
@@ -154,6 +155,7 @@ set(GEMMA_TEST_FILES
   backprop/backward_test.cc
   backprop/backward_scalar_test.cc
   backprop/optimize_test.cc
+  ops/dot_test.cc
   ops/ops_test.cc
   ops/matmul_test.cc
   ops/gemma_matvec_test.cc
diff --git a/compression/BUILD b/compression/BUILD
index 3334da8..ca1dd2e 100644
--- a/compression/BUILD
+++ b/compression/BUILD
@@ -105,6 +105,7 @@ cc_test(
         ":sfp",
         ":test_util",
         "@googletest//:gtest_main",
+        "//:ops",
         "@hwy//:hwy",
         "@hwy//:hwy_test_util",
         "@hwy//:nanobenchmark",
@@ -158,7 +159,6 @@ cc_library(
         ":io",
         ":nuq",
         ":sfp",
-        "@hwy//:dot",
         "@hwy//:hwy",
         "@hwy//:stats",
         "@hwy//:thread_pool",
diff --git a/compression/compress-inl.h b/compression/compress-inl.h
index bfa9f1d..a246dfa 100644
--- a/compression/compress-inl.h
+++ b/compression/compress-inl.h
@@ -44,7 +44,6 @@
 
 #include "compression/nuq-inl.h"
 #include "compression/sfp-inl.h"
-#include "hwy/contrib/dot/dot-inl.h"
 #include "hwy/highway.h"
 
 HWY_BEFORE_NAMESPACE();
@@ -111,30 +110,12 @@ struct CompressTraits<float> {
                                     float* HWY_RESTRICT out, size_t num) {
     using VF = hn::Vec<decltype(df)>;
     const size_t N = hn::Lanes(df);
-    HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);
-    for (size_t i = 0; i < num; i += 2 * N) {
-      VF in0, in1;
-      Decompress2(df, in, in_ofs + i, in0, in1);
-      hn::StoreU(in0, df, out + i);
-      hn::StoreU(in1, df, out + i + N);
+    for (size_t i = 0; i < num; i += N) {
+      const VF v = hn::LoadU(df, in + in_ofs + i);
+      hn::StoreU(v, df, out + i);
     }
   }
-
-  // VecT can be float or hwy::bfloat16_t.
-  template <class DF, typename VecT>
-  static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
-                              const MatT* HWY_RESTRICT in, size_t in_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              size_t num) {
-    HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
-    HWY_DASSERT(hn::IsAligned(df, vec_aligned));
-    constexpr int kAssumptions =
-        hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
-    // vec_aligned must be the second argument because hn::Dot supports f32*bf16
-    // and f32*f32.
-    return hn::Dot::Compute<kAssumptions>(df, in + in_ofs, vec_aligned, num);
-  }
 };
 
 template <>
@@ -251,24 +232,6 @@ struct CompressTraits<hwy::bfloat16_t> {
     }
   }
 
-  // VecT can be float or hwy::bfloat16_t.
-  template <class DF, typename VecT>
-  static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
-                              const MatT* HWY_RESTRICT in, size_t in_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              size_t num) {
-    HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
-    HWY_DASSERT(hn::IsAligned(df, vec_aligned));
-
-    const hn::Repartition<VecT, DF> d_vec;
-
-    constexpr int kAssumptions =
-        hn::Dot::kAtLeastOneVector | hn::Dot::kMultipleOfVector;
-    // vec_aligned must be first argument because hn::Dot supports f32*bf16 and
-    // bf16*bf16.
-    return hn::Dot::Compute<kAssumptions>(d_vec, vec_aligned, in + in_ofs,
-                                          num);
-  }
-
   // Computes the dot product of an even-odd deinterleaved, f32 `vec_aligned`
   // and a column-major matrix `in`. `vec_aligned` should be aligned and
   // alternate even-indexed `hn::Lanes(df32)` elements followed by odd-indexed
@@ -359,30 +322,6 @@ struct CompressTraits<SfpStream> {
     SfpCodec::Dec(d, in + in_ofs, num, out);
   }
 
-  template <class DF, typename VecT>
-  static HWY_INLINE float Dot(DF df, size_t /*in_capacity*/,
-                              const MatT* HWY_RESTRICT in, size_t in_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              size_t num) {
-    HWY_DASSERT(num >= hn::Lanes(df) && (num % hn::Lanes(df)) == 0);
-    HWY_DASSERT((in_ofs % hn::Lanes(df)) == 0);
-    HWY_DASSERT(hn::IsAligned(df, vec_aligned));
-
-    using VF = hn::Vec<decltype(df)>;
-    VF sum0 = hn::Zero(df);
-    VF sum1 = hn::Zero(df);
-    VF sum2 = hn::Zero(df);
-    VF sum3 = hn::Zero(df);
-
-    SfpCodec::Dot(df, in + in_ofs, num, vec_aligned, sum0, sum1, sum2, sum3);
-
-    // Reduction tree: sum of all accumulators, then their lanes
-    sum0 = hn::Add(sum0, sum1);
-    sum2 = hn::Add(sum2, sum3);
-    sum0 = hn::Add(sum0, sum2);
-    return hn::ReduceSum(df, sum0);
-  }
-
   // Computes the dot product of an even-odd deinterleaved, f32 or bf16
   // `vec_aligned` and a column-major matrix `in`. `vec_aligned` should be
   // aligned and alternate even-indexed `hn::Lanes(df)` elements followed by
@@ -446,25 +385,6 @@ struct CompressTraits<NuqStream> {
                                     size_t in_ofs, OutT* out, size_t num) {
     NuqCodec::Dec(d, in_capacity, in, in_ofs, out, num);
   }
-
-  template <class DF, typename VecT>
-  static HWY_INLINE float Dot(DF df, size_t in_capacity, const MatT* in,
-                              size_t in_ofs,
-                              const VecT* HWY_RESTRICT vec_aligned,
-                              size_t num) {
-    using VF = hn::Vec<decltype(df)>;
-    VF sum0 = hn::Zero(df);
-    VF sum1 = hn::Zero(df);
-    VF sum2 = hn::Zero(df);
-    VF sum3 = hn::Zero(df);
-
-    NuqCodec::Dot(df, in_capacity, in, in_ofs, vec_aligned, num, sum0, sum1,
-                  sum2, sum3);
-
-    // Reduction tree: sum of all accumulators, then their lanes
-    sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
-    return hn::ReduceSum(df, sum0);
-  }
 };
 
 // Compresses `num` inputs to `out` starting at `out_ofs`. This can be used for
@@ -559,35 +479,6 @@ HWY_INLINE void Decompress(const CompressedArray<MatT, kCapacity>& compressed,
   fprintf(stderr, "Decompress %.1f MB/s\n", mbps);
 }
 
-// Returns dot product with `vec_aligned` of length `num`.
-template <class DF, typename MatT, size_t kCapacity, typename VecT>
-HWY_INLINE float Dot(DF df, const std::array<MatT, kCapacity>& w, size_t ofs,
-                     const VecT* x, size_t num) {
-  HWY_DASSERT(ofs + num <= kCapacity);
-  HWY_DASSERT(hn::IsAligned(df, x));
-  using Traits = CompressTraits<MatT>;
-  return Traits::Dot(df, w.size(), w.data(), ofs, x, num);
-}
-
-// Returns dot product with `vec_aligned` of length `num`.
-template <bool kVecEO, class DF, typename MatT, size_t kCapacity,
-          typename VecT>
-HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
-                     size_t compressed_ofs, const VecT* vec_aligned,
-                     size_t num) {
-  HWY_DASSERT(compressed_ofs + num <= compressed.size());
-  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
-  using Traits = CompressTraits<MatT>;
-  float dot_result;
-  if constexpr (kVecEO) {
-    dot_result = Traits::DotEO(df, compressed.data(), compressed_ofs,
-                               vec_aligned, num);
-  } else {
-    dot_result = Traits::Dot(df, compressed.size(), compressed.data(),
-                             compressed_ofs, vec_aligned, num);
-  }
-  return compressed.scale() * dot_result;
-}
-
 // Functor called for each tensor, which compresses and stores them along with
 // their scaling factors to BlobStore.
 class Compressor {
diff --git a/compression/distortion.h b/compression/distortion.h
index cbcc35c..c259ed4 100644
--- a/compression/distortion.h
+++ b/compression/distortion.h
@@ -33,7 +33,7 @@ namespace gcpp {
 // despite floating-point rounding. `sum` is already the best estimate, so do
 // not actually add `err` to it. Knuth98/Moller65. Unlike Fast2Sum [Dekker71],
 // this does not require any relative ordering of the exponents of a and b.
-template <class T>
+template <typename T>
 static inline T TwoSum(T a, T b, T& err) {
   const T sum = a + b;
   const T a2 = sum - b;
diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h
index ea5fb4c..3a40527 100644
--- a/compression/sfp-inl.h
+++ b/compression/sfp-inl.h
@@ -52,6 +52,9 @@ HWY_INLINE hn::Mask<DU> SignedLt(DU du, hn::Vec<DU> a, hn::Vec<DU> b) {
   return SignedGt(du, b, a);
 }
 
+// Saturated subtraction; returns 0 if the result would be negative.
+static inline size_t SubOr0(size_t a, size_t b) { return a > b ? a - b : 0; }
+
 // Encode/decode functions.
 class SfpCodec {
  public:
@@ -339,12 +342,10 @@ class SfpCodec {
     HWY_DASSERT(remaining < 2 * N16);
     if (remaining != 0) {
       const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
-      HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
       VBF bf0, bf1;
       Dec2B(dbf, packed, bf0, bf1);
-      hn::StoreU(bf0, dbf, padded);
-      hn::StoreU(bf1, dbf, padded + N16);
-      hwy::CopyBytes(padded, out_bf + i, remaining * sizeof(padded[0]));
+      hn::StoreN(bf0, dbf, out_bf + i, remaining);
+      hn::StoreN(bf1, dbf, out_bf + i + N16, SubOr0(remaining, N16));
     }
   }
 
@@ -375,104 +376,12 @@ class SfpCodec {
     HWY_DASSERT(remaining < 4 * NF);
     if (remaining != 0) {
       const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
-      HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
       VF f0, f1, f2, f3;
       Dec4F(df, packed, f0, f1, f2, f3);
-      hn::StoreU(f0, df, padded + NF * 0);
-      hn::StoreU(f1, df, padded + NF * 1);
-      hn::StoreU(f2, df, padded + NF * 2);
-      hn::StoreU(f3, df, padded + NF * 3);
-      hwy::CopyBytes(padded, out_f + i, remaining * sizeof(padded[0]));
-    }
-  }
-
-  // Fused decode and dot product with bf16 into four output accumulators.
-  template <class DF>
-  static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed,
-                             size_t num,
-                             const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
-                             hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
-                             hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
-    const hn::Repartition<uint8_t, DF> d8;
-    const hn::Repartition<hwy::bfloat16_t, DF> dbf;
-    using V8 = hn::Vec<decltype(d8)>;
-    using VBF = hn::Vec<decltype(dbf)>;
-    const size_t N16 = hn::Lanes(dbf);
-
-    size_t i = 0;
-    if (num >= 2 * N16) {
-      HWY_UNROLL(1)
-      for (; i <= num - 2 * N16; i += 2 * N16) {
-        const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
-        const VBF v0 = hn::LoadU(dbf, vec_aligned + i);
-        const VBF v1 = hn::LoadU(dbf, vec_aligned + i + N16);
-        VBF bf0, bf1;
-        Dec2B(dbf, packed, bf0, bf1);
-        sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1);
-        sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3);
-      }
-    }
-
-    const size_t remaining = num - i;
-    if (remaining != 0) {
-      const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
-      HWY_ALIGN hwy::bfloat16_t padded[2 * hn::MaxLanes(dbf)];
-      hwy::ZeroBytes(padded, sizeof(padded));
-      hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0]));
-      const VBF v0 = hn::LoadU(dbf, padded);
-      const VBF v1 = hn::LoadU(dbf, padded + N16);
-      VBF bf0, bf1;
-      Dec2B(dbf, packed, bf0, bf1);
-      sum0 = hn::ReorderWidenMulAccumulate(df, bf0, v0, sum0, sum1);
-      sum2 = hn::ReorderWidenMulAccumulate(df, bf1, v1, sum2, sum3);
-    }
-  }
-
-  // Fused decode and dot product with f32 into four output accumulators.
-  template <class DF>
-  static HWY_INLINE void Dot(DF df, const SfpStream* HWY_RESTRICT in_packed,
-                             size_t num, const float* HWY_RESTRICT vec_aligned,
-                             hn::Vec<DF>& sum0, hn::Vec<DF>& sum1,
-                             hn::Vec<DF>& sum2, hn::Vec<DF>& sum3) {
-    const hn::Repartition<uint8_t, DF> d8;
-    using V8 = hn::Vec<decltype(d8)>;
-    using VF = hn::Vec<DF>;
-    const size_t NF = hn::Lanes(df);
-
-    size_t i = 0;
-    if (num >= 4 * NF) {
-      HWY_UNROLL(1)
-      for (; i <= num - 4 * NF; i += 4 * NF) {
-        const V8 packed = hn::LoadU(d8, &in_packed->byte + i);
-        const VF v0 = hn::LoadU(df, vec_aligned + i + NF * 0);
-        const VF v1 = hn::LoadU(df, vec_aligned + i + NF * 1);
-        const VF v2 = hn::LoadU(df, vec_aligned + i + NF * 2);
-        const VF v3 = hn::LoadU(df, vec_aligned + i + NF * 3);
-        VF f0, f1, f2, f3;
-        Dec4F(df, packed, f0, f1, f2, f3);
-        sum0 = hn::MulAdd(f0, v0, sum0);
-        sum1 = hn::MulAdd(f1, v1, sum1);
-        sum2 = hn::MulAdd(f2, v2, sum2);
-        sum3 = hn::MulAdd(f3, v3, sum3);
-      }
-    }
-
-    const size_t remaining = num - i;
-    if (remaining != 0) {
-      const V8 packed = hn::LoadN(d8, &in_packed->byte + i, remaining);
-      HWY_ALIGN float padded[4 * hn::MaxLanes(df)];
-      hwy::ZeroBytes(padded, sizeof(padded));
-      hwy::CopyBytes(vec_aligned + i, padded, remaining * sizeof(padded[0]));
-      const VF v0 = hn::LoadU(df, padded + NF * 0);
-      const VF v1 = hn::LoadU(df, padded + NF * 1);
-      const VF v2 = hn::LoadU(df, padded + NF * 2);
-      const VF v3 = hn::LoadU(df, padded + NF * 3);
-      VF f0, f1, f2, f3;
-      Dec4F(df, packed, f0, f1, f2, f3);
-      sum0 = hn::MulAdd(f0, v0, sum0);
-      sum1 = hn::MulAdd(f1, v1, sum1);
-      sum2 = hn::MulAdd(f2, v2, sum2);
-      sum3 = hn::MulAdd(f3, v3, sum3);
+      hn::StoreN(f0, df, out_f + i + 0 * NF, remaining);
+      hn::StoreN(f1, df, out_f + i + 1 * NF, SubOr0(remaining, 1 * NF));
+      hn::StoreN(f2, df, out_f + i + 2 * NF, SubOr0(remaining, 2 * NF));
+      hn::StoreN(f3, df, out_f + i + 3 * NF, SubOr0(remaining, 3 * NF));
     }
   }
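
The StoreN-based remainder handling above replaces decoding into an aligned `padded` buffer followed by a byte copy. A minimal standalone sketch of the same pattern — not part of this patch, with a hypothetical `StoreTwo` helper and a duplicate of `SubOr0` — assuming Highway's `hn::StoreN`, which writes at most the given number of lanes:

    #include "hwy/highway.h"

    namespace hn = hwy::HWY_NAMESPACE;

    // Clamped subtraction so the second store receives 0 lanes (a no-op)
    // whenever `remaining` fits entirely within the first vector.
    static inline size_t SubOr0(size_t a, size_t b) {
      return a > b ? a - b : 0;
    }

    // Stores the first `remaining` of 2 * Lanes(d) decoded elements, without
    // any scratch buffer or out-of-bounds write.
    template <class D, class V = hn::Vec<D>>
    void StoreTwo(D d, V v0, V v1, float* out, size_t remaining) {
      const size_t N = hn::Lanes(d);
      hn::StoreN(v0, d, out, remaining);                 // min(remaining, N)
      hn::StoreN(v1, d, out + N, SubOr0(remaining, N));  // leftover, maybe 0
    }
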
diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc
index 68f3e14..41983d3 100644
--- a/compression/sfp_test.cc
+++ b/compression/sfp_test.cc
@@ -38,6 +38,7 @@
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 // Any highway.h must come after foreach_target.h
 #include "compression/sfp-inl.h"
+#include "ops/dot-inl.h"
 #include "hwy/highway.h"
 #include "hwy/tests/hwy_gtest.h"
 #include "hwy/tests/test_util-inl.h"
@@ -464,16 +465,10 @@ struct TestDot {
     double elapsed_eo = hwy::HighestValue<double>();
     for (size_t rep = 0; rep < 200; ++rep) {
       {
-        hn::Vec<decltype(df)> sum0 = hn::Zero(df);
-        hn::Vec<decltype(df)> sum1 = hn::Zero(df);
-        hn::Vec<decltype(df)> sum2 = hn::Zero(df);
-        hn::Vec<decltype(df)> sum3 = hn::Zero(df);
         const double t0 = hwy::platform::Now();
-        SfpCodec::Dot(df, sfp.get(), num, vec.get(), sum0, sum1, sum2, sum3);
+        actual = SimpleDot(df, sfp.get(), 0, vec.get(), num);
         const double t1 = hwy::platform::Now();
         elapsed = HWY_MIN(elapsed, t1 - t0);
-        sum0 = hn::Add(hn::Add(sum0, sum1), hn::Add(sum2, sum3));
-        actual = hn::ReduceSum(df, sum0);
       }
       {
         hn::Vec<decltype(df)> sum0 = hn::Zero(df);
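
The new header below replaces the per-codec `Dot` member functions with free functions: `SimpleDot` (decompress + FMA), `CompensatedDot` (error-free transformations, roughly twice the working precision) and `ExactDot` (slow f64 reference). A minimal sketch of a call site, assuming static dispatch and f32 inputs; `CallerDot` is hypothetical and not part of this patch:

    #include "hwy/highway.h"
    // After highway.h, as in the files below.
    #include "ops/dot-inl.h"

    namespace hn = hwy::HWY_NAMESPACE;

    float CallerDot(const float* HWY_RESTRICT w,
                    const float* HWY_RESTRICT x, size_t num) {
      const hn::ScalableTag<float> df;
      // x must be vector-aligned; CompensatedDot would additionally require
      // num to be a multiple of 2 * hn::Lanes(df).
      return gcpp::HWY_NAMESPACE::SimpleDot(df, w, /*w_ofs=*/0, x, num);
    }
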
diff --git a/ops/dot-inl.h b/ops/dot-inl.h
new file mode 100644
index 0000000..7addc1b
--- /dev/null
+++ b/ops/dot-inl.h
@@ -0,0 +1,377 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+
+#include <algorithm>  // std::sort
+#include <array>
+#include <cmath>      // std::abs
+
+#include "compression/compress.h"
+#include "compression/distortion.h"  // TwoSum
+#include "hwy/base.h"
+
+// Include guard for (potentially) SIMD code.
+#if defined(THIRD_PARTY_GEMMA_CPP_DOT_TOGGLE) == defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_DOT_TOGGLE
+#undef THIRD_PARTY_GEMMA_CPP_DOT_TOGGLE
+#else
+#define THIRD_PARTY_GEMMA_CPP_DOT_TOGGLE
+#endif
+
+#include "hwy/highway.h"
+// After highway.h
+#include "compression/compress-inl.h"
+#include "hwy/contrib/math/math-inl.h"
+#include "hwy/profiler.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+
+// Returns the dot product of `x` and `w`, both of length `num`. Uses
+// Decompress2 to convert WeightT and VecT to float, then FMA.
+// TODO: improve precision?
+// TODO: use bf16 products?
+template <class DF, typename WeightT, typename VecT>
+HWY_INLINE float SimpleDot(DF df, const WeightT* HWY_RESTRICT w, size_t w_ofs,
+                           const VecT* HWY_RESTRICT x, size_t num) {
+  PROFILER_FUNC;
+  const size_t N = hn::Lanes(df);
+  HWY_DASSERT(hn::IsAligned(df, x));
+  using VF = hn::Vec<decltype(df)>;
+  using TraitsW = CompressTraits<WeightT>;
+  using TraitsV = CompressTraits<VecT>;
+
+  VF sum0 = hn::Zero(df);
+  VF sum1 = hn::Zero(df);
+  VF sum2 = hn::Zero(df);
+  VF sum3 = hn::Zero(df);
+
+  VF w0, w1, w2, w3, v0, v1, v2, v3;  // decompressed inputs
+
+  size_t i = 0;
+  if (num >= 4 * N) {
+    for (; i <= num - 4 * N; i += 4 * N) {
+      TraitsW::Decompress2(df, w, w_ofs + i, w0, w1);
+      TraitsW::Decompress2(df, w, w_ofs + i + 2 * N, w2, w3);
+      TraitsV::Decompress2(df, x, i, v0, v1);
+      TraitsV::Decompress2(df, x, i + 2 * N, v2, v3);
+
+      sum0 = hn::MulAdd(w0, v0, sum0);
+      sum1 = hn::MulAdd(w1, v1, sum1);
+      sum2 = hn::MulAdd(w2, v2, sum2);
+      sum3 = hn::MulAdd(w3, v3, sum3);
+    }
+  }
+
+  const size_t remaining = num - i;
+  if (HWY_UNLIKELY(remaining != 0)) {
+    HWY_ALIGN float padded_w[4 * hn::MaxLanes(df)];
+    HWY_ALIGN float padded_x[4 * hn::MaxLanes(df)];
+    // The actual capacity of w[] is unknown, so pass a lower bound.
+    const size_t w_capacity = w_ofs + num;
+    TraitsW::Decompress(df, w_capacity, w, w_ofs + i, padded_w, remaining);
+    TraitsV::Decompress(df, num, x, i, padded_x, remaining);
+    const size_t padding = 4 * N - remaining;
+    hwy::ZeroBytes(padded_w + remaining, padding * sizeof(padded_w[0]));
+    hwy::ZeroBytes(padded_x + remaining, padding * sizeof(padded_x[0]));
+    // The local buffers start at index 0, not i.
+    for (size_t j = 0; j < remaining; j += N) {
+      const VF w0 = hn::Load(df, padded_w + j);
+      const VF v0 = hn::Load(df, padded_x + j);
+      sum0 = hn::MulAdd(w0, v0, sum0);
+    }
+  }
+
+  // Reduction tree: sum of all accumulators by pairs, then across lanes.
+  sum0 = hn::Add(sum0, sum1);
+  sum2 = hn::Add(sum2, sum3);
+  sum0 = hn::Add(sum0, sum2);
+  return hn::ReduceSum(df, sum0);
+}
+
+// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
+template <class DF, typename MatT, size_t kCapacity, typename VecT>
+HWY_INLINE float Dot(DF df, const std::array<MatT, kCapacity>& w, size_t ofs,
+                     const VecT* vec_aligned, size_t num) {
+  PROFILER_ZONE("Dot array");
+  HWY_DASSERT(ofs + num <= kCapacity);
+  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+  return SimpleDot(df, w.data(), ofs, vec_aligned, num);
+}
+
+// Adapter for use by matvec-inl.h. TODO: remove when that is no longer used.
+template <bool kVecEO, class DF, typename MatT, size_t kCapacity,
+          typename VecT>
+HWY_INLINE float Dot(DF df, const CompressedArray<MatT, kCapacity>& compressed,
+                     size_t compressed_ofs, const VecT* vec_aligned,
+                     size_t num) {
+  PROFILER_ZONE("Dot CompressedArray");
+  HWY_DASSERT(compressed_ofs + num <= compressed.size());
+  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
+  using Traits = CompressTraits<MatT>;
+  float dot_result;
+  if constexpr (kVecEO) {
+    dot_result =
+        Traits::DotEO(df, compressed.data(), compressed_ofs, vec_aligned, num);
+  } else {
+    dot_result =
+        SimpleDot(df, compressed.data(), compressed_ofs, vec_aligned, num);
+  }
+  return compressed.scale() * dot_result;
+}
+
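+// (Background for ExactDot below: the product of two f32 values is exactly
+// representable in f64, and summing the products in decreasing-magnitude
+// order keeps the rounding error of the f64 sum small, which is why it
+// serves as the accuracy reference for the ULP measurements in dot_test.)
+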
+// Returns result accurate to 1.5 ulp, assuming `num` < 2^(52-23), no overflow,
+// and round to nearest. See "Accurate and efficient floating point summation".
+HWY_INLINE float ExactDot(const float* HWY_RESTRICT a,
+                          const float* HWY_RESTRICT b, size_t num,
+                          double* HWY_RESTRICT buf) {
+  PROFILER_FUNC;
+  for (size_t i = 0; i < num; ++i) {
+    buf[i] = static_cast<double>(a[i]) * static_cast<double>(b[i]);
+  }
+  // Sort by decreasing magnitude (not supported by VQSort).
+  std::sort(buf, buf + num,
+            [](double a, double b) { return std::abs(a) > std::abs(b); });
+  double sum = 0.0;
+  for (size_t i = 0; i < num; ++i) {
+    sum += buf[i];
+  }
+  return static_cast<float>(sum);
+}
+
+//------------------------------------------------------------------------------
+// Cascaded summation (twice working precision)
+
+// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
+// despite floating-point rounding. `sum` is already the best estimate for the
+// addition, so do not actually add `err` to it. `UpdateCascadedSums` instead
+// accumulates multiple `err`, which are then later added to `sum`.
+//
+// Knuth98/Moller65. Unlike Fast2Sum [Dekker71], this does not require any
+// relative ordering of the exponents of a and b.
+template <class DF, class VF = hn::Vec<DF>>
+static HWY_INLINE VF TwoSums(DF /*df*/, VF a, VF b, VF& err) {
+  const VF sum = hn::Add(a, b);
+  const VF a2 = hn::Sub(sum, b);
+  const VF b2 = hn::Sub(sum, a2);
+  const VF err_a = hn::Sub(a, a2);
+  const VF err_b = hn::Sub(b, b2);
+  err = hn::Add(err_a, err_b);
+  return sum;
+}
+
+// Adds vectors with about twice the precision of VF, using 7 FLOPS.
+// Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point Arithmetic.
+// `sum` and `sum_err` must initially be zero.
+//
+// Each lane is an independent cascaded sum. To obtain a single result, use
+// `ReduceCascadedSums`. Vectors generally cannot be wrapped in a class, hence
+// we use free functions.
+template <class DF, class VF = hn::Vec<DF>>
+void UpdateCascadedSums(DF df, VF v, VF& sum, VF& sum_err) {
+  VF err;
+  sum = TwoSums(df, sum, v, err);
+  sum_err = hn::Add(sum_err, err);
+}
+
+// Combines two cascaded sum vectors, typically from unrolling/parallelization.
+template <class DF, class VF = hn::Vec<DF>>
+void AssimilateCascadedSums(DF df, const VF& other_sum, const VF& other_sum_err,
+                            VF& sum, VF& sum_err) {
+  UpdateCascadedSums(df, other_sum, sum, sum_err);
+  sum_err = hn::Add(sum_err, other_sum_err);
+}
+
+// Reduces cascaded sums to a single value. Slow; call outside of loops.
+template <class DF, class VF = hn::Vec<DF>>
+hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
+  const size_t N = hn::Lanes(df);
+  using TF = hn::TFromD<DF>;
+  TF total = TF{0.0};
+  TF total_err = TF{0.0};
+  for (size_t i = 0; i < N; ++i) {
+    TF err;
+    total = TwoSum(total, hn::ExtractLane(sum, i), err);
+    total_err += hn::ExtractLane(sum_err, i);
+    total_err += err;
+  }
+  return total + total_err;
+}
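+
+// (Usage sketch for the three functions above, with hypothetical call sites:
+//   VF sum = hn::Zero(df), sum_err = hn::Zero(df);
+//   for each input vector v: UpdateCascadedSums(df, v, sum, sum_err);
+//   const float total = ReduceCascadedSums(df, sum, sum_err);
+// The (sum, sum_err) pair acts as one accumulator with about twice the
+// working precision; AssimilateCascadedSums merges the pairs produced by
+// unrolled loops, as CompensatedDot below does.)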
+
+//------------------------------------------------------------------------------
+
+// Returns 2 * sum(|f|) / |sum(f)|. This is large when there are many
+// similar-magnitude and opposite-sign elements in `f`. See
+// https://en.wikipedia.org/wiki/Condition_number.
+template <class DF, class VF = hn::Vec<DF>>
+static inline double ConditionNumber(DF df, const float* HWY_RESTRICT f,
+                                     size_t num) {
+  PROFILER_FUNC;
+  const size_t N = hn::Lanes(df);
+
+  VF sum = hn::Zero(df);
+  VF sum_err = hn::Zero(df);
+  VF sum_abs = hn::Zero(df);
+  VF sum_err_abs = hn::Zero(df);
+
+  size_t i = 0;
+  if (num >= N) {
+    for (; i <= num - N; i += N) {
+      const VF v = hn::Load(df, f + i);
+      UpdateCascadedSums(df, v, sum, sum_err);
+      UpdateCascadedSums(df, hn::Abs(v), sum_abs, sum_err_abs);
+    }
+  }
+  const size_t remaining = num - i;
+  if (remaining != 0) {
+    const VF v = hn::LoadN(df, f + i, remaining);
+    UpdateCascadedSums(df, v, sum, sum_err);
+    UpdateCascadedSums(df, hn::Abs(v), sum_abs, sum_err_abs);
+  }
+
+  const float div = std::abs(ReduceCascadedSums(df, sum, sum_err));
+  if (div == 0.0f) return hwy::HighestValue<double>();
+  const double cond = 2.0 * ReduceCascadedSums(df, sum_abs, sum_err_abs) /
+                      static_cast<double>(div);
+  HWY_ASSERT(cond >= 0.0);
+  return cond;
+}
+
+// Same, but for the dot product of two arrays.
+// TODO: move into dot_test.
+template <class DF, class VF = hn::Vec<DF>>
+static inline double ConditionNumber(DF df, const float* HWY_RESTRICT a,
+                                     const float* HWY_RESTRICT b, size_t num) {
+  PROFILER_FUNC;
+  const size_t N = hn::Lanes(df);
+
+  VF sum = hn::Zero(df);
+  VF sum_err = hn::Zero(df);
+  VF sum_abs = hn::Zero(df);
+  VF sum_err_abs = hn::Zero(df);
+
+  size_t i = 0;
+  if (num >= N) {
+    for (; i <= num - N; i += N) {
+      const VF va = hn::Load(df, a + i);
+      const VF vb = hn::Load(df, b + i);
+      const VF mul = hn::Mul(va, vb);
+      UpdateCascadedSums(df, mul, sum, sum_err);
+      UpdateCascadedSums(df, hn::Abs(mul), sum_abs, sum_err_abs);
+    }
+  }
+  const size_t remaining = num - i;
+  if (remaining != 0) {
+    const VF va = hn::LoadN(df, a + i, remaining);
+    const VF vb = hn::LoadN(df, b + i, remaining);
+    const VF mul = hn::Mul(va, vb);
+    UpdateCascadedSums(df, mul, sum, sum_err);
+    UpdateCascadedSums(df, hn::Abs(mul), sum_abs, sum_err_abs);
+  }
+
+  const float div = std::abs(ReduceCascadedSums(df, sum, sum_err));
+  if (div == 0.0f) return hn::GetLane(hn::Inf(df));
+  const double cond = 2.0 * ReduceCascadedSums(df, sum_abs, sum_err_abs) /
+                      static_cast<double>(div);
+  HWY_ASSERT(cond >= 0.0);
+  return cond;
+}
+
+//------------------------------------------------------------------------------
+// Compensated dot product
+
+#if !HWY_NATIVE_FMA
+
+// Returns non-overlapping `x` and `y` such that `x + y` = `a` and |x| >= |y|.
+// Notation from Algorithm 3.1 in Handbook of Floating-Point Arithmetic. 4 ops.
+template <class DF, class VF = hn::Vec<DF>>
+static HWY_INLINE void VeltkampSplit(DF df, VF a, VF& x, VF& y) {
+  using TF = hn::TFromD<DF>;
+  constexpr int t = hwy::MantissaBits<TF>() + 1;  // = -log2(epsilon)
+  constexpr int s = hwy::DivCeil(t, 2);
+  const VF factor = hn::Set(df, hwy::ConvertScalarTo<TF>((1ULL << s) + 1));
+  const VF c = hn::Mul(factor, a);
+  x = hn::Sub(c, hn::Sub(c, a));
+  y = hn::Sub(a, x);
+}
+
+#endif  // !HWY_NATIVE_FMA
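+
+// (On targets with native FMA, the rounding error of a product is recovered
+// in a single op: MulSub(a, b, prod) computes a*b - prod exactly, because
+// that difference is itself representable barring underflow. VeltkampSplit
+// above emulates the same error-free transformation where FMA is
+// unavailable, by splitting each operand so the partial products are exact.)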
+
+// Returns `prod` and `err` such that `prod + err` is exactly equal to `a * b`,
+// despite floating-point rounding, assuming that `err` is not subnormal, i.e.
+// the sum of exponents >= min exponent + mantissa bits. 2..17 ops.
+template <class DF, class VF = hn::Vec<DF>>
+static HWY_INLINE VF TwoProducts(DF df, VF a, VF b, VF& err) {
+  const VF prod = hn::Mul(a, b);
+#if HWY_NATIVE_FMA
+  err = hn::MulSub(a, b, prod);
+#else
+  VF a1, a2, b1, b2;
+  VeltkampSplit(df, a, a1, a2);
+  VeltkampSplit(df, b, b1, b2);
+  const VF m = hn::Sub(prod, hn::Mul(a1, b1));
+  const VF n = hn::Sub(m, hn::Mul(a2, b1));
+  const VF o = hn::Sub(n, hn::Mul(a1, b2));
+  err = hn::Sub(hn::Mul(a2, b2), o);
+#endif
+  return prod;
+}
+
+// Algorithm 6.15 from Handbook of Floating-Point Arithmetic.
+template <class DF, typename WeightT, typename VecT>
+HWY_INLINE float CompensatedDot(DF df, const WeightT* HWY_RESTRICT w,
+                                size_t w_ofs, const VecT* HWY_RESTRICT x,
+                                size_t num) {
+  PROFILER_FUNC;
+  const size_t N = hn::Lanes(df);
+  HWY_ASSERT((num % (2 * N)) == 0);
+  HWY_DASSERT(hn::IsAligned(df, x));
+  using VF = hn::Vec<decltype(df)>;
+  using TraitsW = CompressTraits<WeightT>;
+  using TraitsV = CompressTraits<VecT>;
+
+  VF sum0 = hn::Zero(df);
+  VF sum1 = hn::Zero(df);
+  VF sum_err0 = hn::Zero(df);
+  VF sum_err1 = hn::Zero(df);
+
+  VF w0, w1, v0, v1;              // decompressed inputs
+  VF perr0, perr1, serr0, serr1;  // output arg of TwoProducts/TwoSums
+
+  for (size_t i = 0; i < num; i += 2 * N) {
+    TraitsW::Decompress2(df, w, w_ofs + i, w0, w1);
+    TraitsV::Decompress2(df, x, i, v0, v1);
+
+    const VF prod0 = TwoProducts(df, w0, v0, perr0);
+    const VF prod1 = TwoProducts(df, w1, v1, perr1);
+
+    sum0 = TwoSums(df, prod0, sum0, serr0);
+    sum1 = TwoSums(df, prod1, sum1, serr1);
+
+    sum_err0 = hn::Add(sum_err0, hn::Add(perr0, serr0));
+    sum_err1 = hn::Add(sum_err1, hn::Add(perr1, serr1));
+  }
+
+  AssimilateCascadedSums(df, sum1, sum_err1, sum0, sum_err0);
+  return ReduceCascadedSums(df, sum0, sum_err0);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#endif  // NOLINT
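
The test below drives these kernels with inputs constructed to reach a condition number near 1e12, where a plain f32 dot product loses most of its significant bits; accuracy is reported as the ULP distance from ExactDot. A sketch of that metric, assuming `hwy::detail::ComputeUlpDelta` from Highway's test utilities (as used in the test); `UlpsFromExact` is a hypothetical name:

    #include <stdint.h>

    #include "hwy/tests/test_util-inl.h"  // hwy::detail::ComputeUlpDelta

    // ULP distance between a computed f32 result and the correctly rounded
    // reference; 0 or 1 means essentially no accuracy was lost.
    uint64_t UlpsFromExact(float computed, float exact) {
      return hwy::detail::ComputeUlpDelta(computed, exact);
    }
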
diff --git a/ops/dot_test.cc b/ops/dot_test.cc
new file mode 100644
index 0000000..107b67b
--- /dev/null
+++ b/ops/dot_test.cc
@@ -0,0 +1,224 @@
+// Copyright 2023 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef HWY_DISABLED_TARGETS
+// Exclude HWY_SCALAR due to 2x bf16 -> f32.
+#define HWY_DISABLED_TARGETS HWY_SCALAR
+#endif
+
+#include <stddef.h>
+#include <stdio.h>
+
+#include <algorithm>  // std::swap
+#include <array>
+#include <cmath>
+#include <random>
+
+#include "util/allocator.h"
+#include "util/threading.h"
+#include "hwy/aligned_allocator.h"
+#include "hwy/base.h"
+#include "hwy/stats.h"
+#include "hwy/timer.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "ops/dot_test.cc"
+// clang-format on
+#include "hwy/foreach_target.h"  // IWYU pragma: keep
+#include "hwy/highway.h"
+// After highway.h
+#include "compression/compress-inl.h"
+#include "ops/dot-inl.h"
+#include "hwy/profiler.h"
+#include "hwy/tests/hwy_gtest.h"
+#include "hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+
+using Array = hwy::AlignedFreeUniquePtr<float[]>;
+
+// Returns normalized value in [-1, 1).
+float RandomFloat(std::mt19937& rng) {
+  const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
+  const uint32_t mantissa_mask = hwy::MantissaMask<float>();
+  const uint32_t representation = exp | (rng() & mantissa_mask);
+  const float f12 = hwy::BitCastScalar<float>(representation);
+  HWY_DASSERT(1.0f <= f12 && f12 < 2.0f);  // exponent is 2^0, only mantissa
+  const float f = (2.0f * (f12 - 1.0f)) - 1.0f;
+  HWY_DASSERT(-1.0f <= f && f < 1.0f);
+  return f;
+}
+
+// Based on Algorithm 6.1 from "Accurate Sum and Dot Product". `num` is the
+// size of a, b[, and buf] and must be even and greater than 2.
+void GenerateIllConditionedInputs(double target_cond, size_t num,
+                                  float* HWY_RESTRICT a, float* HWY_RESTRICT b,
+                                  double* HWY_RESTRICT buf,
+                                  std::mt19937& rng) {
+  PROFILER_FUNC;
+  HWY_ASSERT(target_cond >= 1.0);
+  HWY_ASSERT(num % 2 == 0);
+  const size_t half = num / 2;
+  const hn::ScalableTag<float> df;
+
+  const int max_exp = static_cast<int>(std::log2(target_cond) / 2.0);
+  std::uniform_int_distribution<int> e_dist(0, max_exp);
+
+  // First half: random exponents and mantissas.
+  for (size_t i = 0; i < half; ++i) {
+    // Ensure the min and max exponents are used.
+    const int e = i == 0 ? 0 : i == 1 ? max_exp : e_dist(rng);
+    a[i] = RandomFloat(rng) * (1 << e);
+    b[i] = RandomFloat(rng) * (1 << e);
+  }
+
+  // Zero-init second half for ExactDot.
+  for (size_t i = half; i < num; ++i) {
+    a[i] = 0.0f;
+    b[i] = 0.0f;
+  }
+
+  // Second half: exponents ramp down while b[i] cancels the dot product so
+  // far. Note: float division, otherwise the step would truncate to zero.
+  const float a_exp_step = static_cast<float>(max_exp) / (half - 1);
+  float a_exp = max_exp;  // max_exp downto 0
+  for (size_t i = half; i < num; ++i, a_exp -= a_exp_step) {
+    const int e = static_cast<int>(a_exp);
+    HWY_DASSERT(e >= 0);
+    a[i] = RandomFloat(rng) * (1 << e);
+    const float r = RandomFloat(rng) * (1 << e);
+    if (a[i] == 0.0f) {
+      b[i] = 0.0f;
+    } else {
+      // This is called >100K times. CompensatedDot is much faster than
+      // ExactDot and just about as accurate, but requires multiples of two
+      // vectors.
+      // const float exact = ExactDot(a, b, i, buf);
+      (void)buf;
+      const size_t padded = hwy::RoundUpTo(i, 2 * hn::Lanes(df));
+      const float exact = CompensatedDot(df, a, /*w_ofs=*/0, b, padded);
+      b[i] = r - exact / a[i];
+    }
+  }
+
+  // Fisher-Yates shuffle of both a and b simultaneously - std::shuffle only
+  // shuffles one array, and we want the same permutation for both.
+  for (size_t i = num - 1; i != 0; --i) {
+    std::uniform_int_distribution<size_t> dist(0, i);
+    const size_t j = dist(rng);
+    std::swap(a[i], a[j]);
+    std::swap(b[i], b[j]);
+  }
+}
+
+template <typename T, size_t N>
+void PrintStats(const char* caption, const std::array<T, N>& values) {
+  hwy::Stats stats;
+  for (T t : values) {
+    stats.Notify(static_cast<float>(t));
+  }
+  fprintf(stderr, "%s %s\n", caption, stats.ToString().c_str());
+}
+
+void TestAllDot() {
+  // Skip EMU128 and old x86; include SSE4 because it tests the non-FMA path.
+  if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSSE3 ||
+      HWY_TARGET == HWY_SSE2) {
+    return;
+  }
+
+  hn::ScalableTag<float> df;
+
+  constexpr size_t kMaxThreads = 8;
+  std::mt19937 rngs[kMaxThreads];
+  for (size_t i = 0; i < kMaxThreads; ++i) {
+    rngs[i].seed(12345 + 65537 * i);
+  }
+
+  constexpr size_t kReps = hn::AdjustedReps(200);
+  const size_t num = 24 * 1024;
+  PerClusterPools pools(/*max_clusters=*/1, kMaxThreads, /*pin=*/1);
+  RowVectorBatch<float> a(kMaxThreads, num);
+  RowVectorBatch<float> b(kMaxThreads, num);
+  RowVectorBatch<double> bufs(kMaxThreads, num);
+
+  const double target_cond = 1e12;
+  std::array<double, kReps> conds;
+  std::array<uint64_t, kReps> ulps_fast;
+  std::array<uint64_t, kReps> ulps_comp;
+  std::array<double, kReps> t_fast;
+  std::array<double, kReps> t_comp;
+
+  constexpr size_t kTimeReps = 3;
+
+  pools.Inner(0).Run(0, kReps, [&](const uint32_t rep, size_t thread) {
+    float* HWY_RESTRICT pa = a.Batch(thread);
+    float* HWY_RESTRICT pb = b.Batch(thread);
+    double* HWY_RESTRICT buf = bufs.Batch(thread);
+    GenerateIllConditionedInputs(target_cond, num, pa, pb, buf, rngs[thread]);
+    conds[rep] = ConditionNumber(df, pa, pb, num);
+
+    const float dot_exact = ExactDot(pa, pb, num, buf);
+
+    float dot_fast = 0.0f;
+    float dot_comp = 0.0f;
+
+    double elapsed = hwy::HighestValue<double>();
+    for (size_t r = 0; r < kTimeReps; ++r) {
+      const double start = hwy::platform::Now();
+      dot_fast += SimpleDot(df, pa, 0, pb, num);
+      elapsed = HWY_MIN(elapsed, hwy::platform::Now() - start);
+    }
+    dot_fast /= kTimeReps;
+    t_fast[rep] = elapsed;
+
+    elapsed = hwy::HighestValue<double>();
+    for (size_t r = 0; r < kTimeReps; ++r) {
+      const double start = hwy::platform::Now();
+      dot_comp += CompensatedDot(df, pa, /*w_ofs=*/0, pb, num);
+      elapsed = HWY_MIN(elapsed, hwy::platform::Now() - start);
+    }
+    dot_comp /= kTimeReps;
+    t_comp[rep] = elapsed;
+
+    ulps_fast[rep] = hwy::detail::ComputeUlpDelta(dot_fast, dot_exact);
+    ulps_comp[rep] = hwy::detail::ComputeUlpDelta(dot_comp, dot_exact);
+    fprintf(stderr, "cond %.1E: %15.7E %15.7E %15.7E ulp %5u %1u\n",
+            conds[rep], dot_exact, dot_fast, dot_comp,
+            static_cast<unsigned>(ulps_fast[rep]),
+            static_cast<unsigned>(ulps_comp[rep]));
+  });
+
+  PROFILER_PRINT_RESULTS();
+  PrintStats("cond", conds);
+  PrintStats("ulp fast", ulps_fast);
+  PrintStats("ulp comp", ulps_comp);
+  PrintStats("t fast", t_fast);
+  PrintStats("t comp", t_comp);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+
+namespace gcpp {
+HWY_BEFORE_TEST(DotTest);
+HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot);
+HWY_AFTER_TEST();
+}  // namespace gcpp
+
+#endif
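
For intuition about the `cond` column printed above: the condition number of a dot product is 2 * sum(|a[i]*b[i]|) / |sum(a[i]*b[i])|, so a value near 1e12 means roughly twelve decimal digits cancel. A scalar f64 model of the vectorized `ConditionNumber` in ops/dot-inl.h; `CondScalar` is a hypothetical name, not part of this patch:

    #include <stddef.h>

    #include <cmath>  // std::abs

    // Scalar model of ConditionNumber(df, a, b, num): large values indicate
    // heavy cancellation, i.e. an f32 dot product will lose precision.
    double CondScalar(const float* a, const float* b, size_t num) {
      double sum = 0.0, sum_abs = 0.0;
      for (size_t i = 0; i < num; ++i) {
        const double p =
            static_cast<double>(a[i]) * static_cast<double>(b[i]);
        sum += p;
        sum_abs += std::abs(p);
      }
      return 2.0 * sum_abs / std::abs(sum);
    }
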
diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h
index 70d6b87..3d6bbcd 100644
--- a/ops/matvec-inl.h
+++ b/ops/matvec-inl.h
@@ -38,6 +38,7 @@
 #endif
 
 #include "compression/compress-inl.h"
+#include "ops/dot-inl.h"
 #include "hwy/contrib/dot/dot-inl.h"
 #include "hwy/contrib/math/math-inl.h"
 #include "hwy/contrib/matvec/matvec-inl.h"