From c29e9752c7c57a51ad289ba70e6b90f1a124c977 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Wed, 4 Sep 2024 09:24:39 -0700 Subject: [PATCH] Refactor/cleanup, remove even_odd * New compression/shared.h, remove sfp.h * Remove unused DistortionStats b_l1_ * Move exact arithmetic functions into fp_arith * Remove even_odd optimization for MatVec (mostly unused) * use BF16 typedef more widely * Add kMaxSFP constant PiperOrigin-RevId: 670996386 --- BUILD.bazel | 10 ++ CMakeLists.txt | 6 +- backprop/backward-inl.h | 1 + backprop/forward-inl.h | 15 +- compression/BUILD | 44 +++++- compression/compress.h | 41 ++--- compression/compress_test.cc | 14 ++ compression/compress_weights.cc | 11 +- compression/distortion.h | 14 +- compression/distortion_test.cc | 5 +- compression/nuq-inl.h | 2 +- compression/nuq.h | 16 +- compression/python/compression_clif_aux.cc | 6 +- compression/sfp-inl.h | 2 +- compression/sfp.h | 51 ------- compression/sfp_test.cc | 9 +- compression/shared.h | 94 ++++++++++++ gemma/activations.h | 8 - gemma/gemma-inl.h | 7 +- ops/fp_arith-inl.h | 152 +++++++++++++++++++ ops/gemma_matvec_test.cc | 8 +- ops/matvec-inl.h | 168 ++++++--------------- 22 files changed, 423 insertions(+), 261 deletions(-) create mode 100644 compression/compress_test.cc delete mode 100644 compression/sfp.h create mode 100644 compression/shared.h create mode 100644 ops/fp_arith-inl.h diff --git a/BUILD.bazel b/BUILD.bazel index d672a38..2fc9e60 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -48,6 +48,15 @@ cc_library( ], ) +# Avoids circular dependency: fp_arith-inl -> compress-inl -> ops-inl +cc_library( + name = "fp_arith", + textual_hdrs = ["ops/fp_arith-inl.h"], + deps = [ + "@hwy//:hwy", + ], +) + cc_library( name = "ops", hdrs = [ @@ -61,6 +70,7 @@ cc_library( ], deps = [ ":allocator", + ":fp_arith", ":threading", "//compression:compress", "//compression:sfp", diff --git a/CMakeLists.txt b/CMakeLists.txt index 3ed8397..84832ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -46,8 +46,8 @@ set(SOURCES compression/io.h compression/nuq.h compression/nuq-inl.h - compression/sfp.h compression/sfp-inl.h + compression/shared.h compression/weights_raw.h backprop/activations.h backprop/backward.cc @@ -155,6 +155,10 @@ set(GEMMA_TEST_FILES backprop/backward_test.cc backprop/backward_scalar_test.cc backprop/optimize_test.cc + compression/compress_test.cc + compression/distortion_test.cc + compression/sfp_test.cc + compression/nuq_test.cc ops/dot_test.cc ops/ops_test.cc ops/matmul_test.cc diff --git a/backprop/backward-inl.h b/backprop/backward-inl.h index 39b03f7..76bd87e 100644 --- a/backprop/backward-inl.h +++ b/backprop/backward-inl.h @@ -28,6 +28,7 @@ #include "backprop/activations.h" #include "backprop/prompt.h" #include "gemma/common.h" +#include "util/allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" diff --git a/backprop/forward-inl.h b/backprop/forward-inl.h index 7dec634..838b042 100644 --- a/backprop/forward-inl.h +++ b/backprop/forward-inl.h @@ -24,9 +24,9 @@ #include #include "backprop/activations.h" -#include "gemma/activations.h" #include "gemma/common.h" #include "gemma/configs.h" +#include "util/allocator.h" #include "hwy/base.h" #include "hwy/contrib/thread_pool/thread_pool.h" @@ -40,9 +40,10 @@ #define THIRD_PARTY_GEMMA_CPP_FORWARD_TOGGLE #endif +#include "hwy/highway.h" +// After highway.h #include "ops/matvec-inl.h" #include "ops/ops-inl.h" -#include "hwy/highway.h" HWY_BEFORE_NAMESPACE(); namespace gcpp { @@ -110,7 +111,7 @@ void 
ApplyForwardLayer(const LayerT& weights, for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec<(kHeads + 2) * kQKVDim, kModelDim>( weights.qkv_einsum_w, 0, - activations.pre_att_rms_out.data() + pos * kModelDim, nullptr, + activations.pre_att_rms_out.data() + pos * kModelDim, activations.qkv.data() + pos * (kHeads + 2) * kQKVDim, pool); } const size_t num_tasks = kHeads * num_tokens; @@ -174,7 +175,7 @@ void ApplyForwardLayer(const LayerT& weights, MatVec( weights.attn_vec_einsum_w, head * kModelDim * kQKVDim, activations.att_out.data() + pos * kHeads * kQKVDim + head * kQKVDim, - nullptr, activations.att_post1.data() + pos * kModelDim, pool); + activations.att_post1.data() + pos * kModelDim, pool); AddFrom(activations.att_post1.data() + pos * kModelDim, activations.attention_out.data() + pos * kModelDim, kModelDim); } @@ -192,7 +193,7 @@ void ApplyForwardLayer(const LayerT& weights, for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec( weights.gating_einsum_w, 0, - activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, nullptr, + activations.bf_pre_ffw_rms_out.data() + pos * kModelDim, activations.ffw_hidden.data() + pos * kFFHiddenDim * 2, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { @@ -215,7 +216,7 @@ void ApplyForwardLayer(const LayerT& weights, MatVec( weights.linear_w, 0, activations.ffw_hidden_gated.data() + pos * kFFHiddenDim, - nullptr, output + pos * kModelDim, pool); + output + pos * kModelDim, pool); } for (size_t pos = 0; pos < num_tokens; ++pos) { AddFrom(activations.attention_out.data() + pos * kModelDim, @@ -265,7 +266,7 @@ float CrossEntropyLossForwardPass(const std::vector& prompt, for (size_t pos = 0; pos < num_tokens; ++pos) { MatVec( weights.embedder_input_embedding, 0, - forward.final_norm_output.data() + pos * kModelDim, nullptr, + forward.final_norm_output.data() + pos * kModelDim, forward.logits.data() + pos * kVocabSize, pool); } diff --git a/compression/BUILD b/compression/BUILD index 188a8b2..a437a17 100644 --- a/compression/BUILD +++ b/compression/BUILD @@ -64,6 +64,7 @@ cc_test( srcs = ["distortion_test.cc"], deps = [ ":distortion", + ":shared", "@googletest//:gtest_main", "//:test_util", "@hwy//:hwy", @@ -72,11 +73,19 @@ cc_test( ], ) +cc_library( + name = "shared", + hdrs = ["shared.h"], + deps = [ + "@hwy//:hwy", + ], +) + cc_library( name = "sfp", - hdrs = ["sfp.h"], textual_hdrs = ["sfp-inl.h"], deps = [ + ":shared", "@hwy//:hwy", ], ) @@ -93,6 +102,7 @@ cc_test( deps = [ ":distortion", ":sfp", + ":shared", "@googletest//:gtest_main", "//:ops", "//:test_util", @@ -108,6 +118,7 @@ cc_library( textual_hdrs = ["nuq-inl.h"], deps = [ ":sfp", + ":shared", "@hwy//:hwy", "@hwy//hwy/contrib/sort:vqsort", ], @@ -127,6 +138,7 @@ cc_test( ":distortion", ":nuq", ":sfp", + ":shared", "@googletest//:gtest_main", "//:test_util", "@hwy//:hwy", @@ -137,11 +149,7 @@ cc_test( cc_library( name = "compress", - hdrs = [ - "compress.h", - "nuq.h", - "sfp.h", - ], + hdrs = ["compress.h"], textual_hdrs = [ "compress-inl.h", ], @@ -151,12 +159,35 @@ cc_library( ":io", ":nuq", ":sfp", + ":shared", + "//:fp_arith", "@hwy//:hwy", "@hwy//:stats", "@hwy//:thread_pool", ], ) +cc_test( + name = "compress_test", + size = "small", + srcs = ["compress_test.cc"], + features = ["fully_static_link"], + linkstatic = True, + local_defines = ["HWY_IS_TEST"], + # for test_suite. 
+    tags = ["hwy_ops_test"],
+    deps = [
+        ":compress",
+        ":distortion",
+        "@googletest//:gtest_main",
+        "//:test_util",
+        "@hwy//:hwy",
+        "@hwy//:hwy_test_util",
+        "@hwy//:nanobenchmark",
+        "@hwy//:thread_pool",
+    ],
+)
+
 # For internal experimentation
 cc_library(
     name = "analyze",
@@ -190,6 +221,7 @@ cc_binary(
     srcs = ["compress_weights.cc"],
     deps = [
         ":compress",
+        ":shared",
         ":weights_raw",
         # Placeholder for internal dep, do not remove.,
         "//:args",
diff --git a/compression/compress.h b/compression/compress.h
index 23340dc..18f1725 100644
--- a/compression/compress.h
+++ b/compression/compress.h
@@ -30,10 +30,10 @@
 #include "compression/blob_store.h"
 #include "compression/io.h"
 #include "compression/nuq.h"
-#include "compression/sfp.h"
+#include "compression/shared.h"
 // IWYU pragma: end_exports
 #include "compression/distortion.h"
-#include "hwy/base.h"  // hwy::bfloat16_t
+#include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #if COMPRESS_STATS
 #include "hwy/stats.h"
@@ -41,29 +41,20 @@
 
 namespace gcpp {
 
-using BF16 = hwy::bfloat16_t;
-
 static inline const char* TypeName(float) { return "f32"; }
 static inline const char* TypeName(BF16) { return "b16"; }
+static inline const char* TypeName(SfpStream) { return "SFP"; }
+static inline const char* TypeName(NuqStream) { return "NUQ"; }
 
-namespace detail {
-// How many MatT are required to store `capacity` weights. For all but
-// NuqStream, this is the same as `capacity`. For use by CompressedArray.
+// Returns the number of `MatT` elements required to store `capacity` values,
+// which must not be zero.
 template <typename MatT>
-constexpr size_t CompressedArrayLen(size_t capacity) {
-  return capacity;
-}
-template <>
-constexpr size_t CompressedArrayLen<NuqStream>(size_t capacity) {
-  return NuqStream::PackedEnd(capacity);
-}
-}  // namespace detail
-
-// Returns the number of bytes required to store a compressed array with the
-// given type and capacity.
-template <typename MatT>
-constexpr size_t CompressedArraySize(size_t capacity) {
-  return detail::CompressedArrayLen<MatT>(capacity) * sizeof(MatT);
+constexpr size_t CompressedArrayElements(size_t capacity) {
+  if constexpr (hwy::IsSame<hwy::RemoveCvRef<MatT>, NuqStream>()) {
+    return NuqStream::PackedEnd(capacity);
+  } else {
+    return capacity;
+  }
 }
 
 // Compressed representation of floating-point elements. The array length may
@@ -71,10 +62,6 @@ constexpr size_t CompressedArraySize(size_t capacity) {
 // implemented in SIMD code and are thus non-member functions.
 template <typename MatT, size_t kCapacity>
 class CompressedArray {
-  static constexpr size_t NumCompressed() {
-    return detail::CompressedArrayLen<MatT>(kCapacity);
-  }
-
  public:
   using value_type = MatT;
 
@@ -100,11 +87,11 @@ class CompressedArray {
   constexpr size_t size() const { return kCapacity; }
 
   constexpr size_t CompressedSize() const {
-    return NumCompressed() * sizeof(MatT);
+    return data_.size() * sizeof(MatT);
   }
 
  private:
-  std::array<MatT, NumCompressed()> data_;
+  std::array<MatT, CompressedArrayElements<MatT>(kCapacity)> data_;
   // Blobs are at least kBlobAlign bytes anyway.
   float scale_[kBlobAlign / sizeof(float)];
 };
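
For illustration, how the renamed helper behaves for the two packed types (a
sketch; 4096 is an arbitrary capacity, and the NuqStream count is smaller
because two 4-bit indices share a byte, plus per-group tables):

    static_assert(CompressedArrayElements<SfpStream>(4096) == 4096);
    static_assert(CompressedArrayElements<NuqStream>(4096) ==
                  NuqStream::PackedEnd(4096));  // < 4096
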
diff --git a/compression/compress_test.cc b/compression/compress_test.cc
new file mode 100644
index 0000000..00e678c
--- /dev/null
+++ b/compression/compress_test.cc
@@ -0,0 +1,14 @@
+// Copyright 2023 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
diff --git a/compression/compress_weights.cc b/compression/compress_weights.cc
index 2ca7400..3106db9 100644
--- a/compression/compress_weights.cc
+++ b/compression/compress_weights.cc
@@ -21,9 +21,9 @@
 #define HWY_TARGET_INCLUDE \
     "compression/compress_weights.cc"  // NOLINT
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
-// Must come after foreach_target.h to avoid redefinition errors.
-#include "compression/compress-inl.h"
 #include "hwy/highway.h"
+// After highway.h
+#include "compression/compress-inl.h"
 
 #ifndef GEMMA_COMPRESS_WEIGHTS_ONCE
 #define GEMMA_COMPRESS_WEIGHTS_ONCE
@@ -38,9 +38,11 @@
 #include   // NOLINT
 
 #include "compression/io.h"  // Path
+#include "compression/shared.h"
 #include "compression/weights_raw.h"
 #include "gemma/common.h"  // Model
 #include "gemma/weights.h"
+#include "util/allocator.h"
 #include "util/args.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -57,11 +59,10 @@ float ScaleWeights(float* data, size_t len) {
   for (size_t i = 0; i < len; ++i) {
     maxabs = std::max(maxabs, std::abs(data[i]));
   }
-  const float kMaxRange = 1.875f;
-  if (maxabs <= kMaxRange) {
+  if (maxabs <= kMaxSFP) {
     return 1.0f;
   }
-  const float scale = maxabs / kMaxRange;
+  const float scale = maxabs / kMaxSFP;
   const float inv_scale = 1.0f / scale;
   for (size_t i = 0; i < len; ++i) {
     data[i] *= inv_scale;
diff --git a/compression/distortion.h b/compression/distortion.h
index c259ed4..114a534 100644
--- a/compression/distortion.h
+++ b/compression/distortion.h
@@ -30,9 +30,12 @@
 namespace gcpp {
 
 // Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
-// despite floating-point rounding. `sum` is already the best estimate, so do
-// not actually add `err` to it. Knuth98/Moller65. Unlike Fast2Sum [Dekker71],
-// this does not require any relative ordering of the exponents of a and b.
+// despite floating-point rounding. `sum` is already the best estimate for the
+// addition, so do not directly add `err` to it.
+//
+// Knuth98/Moller65. Unlike FastTwoSum, this does not require any relative
+// ordering of the exponents of a and b. 6 ops.
+// TODO: move to and use in Highway stats.h?
 template <typename T>
 static inline T TwoSum(T a, T b, T& err) {
   const T sum = a + b;
@@ -88,7 +91,6 @@ class DistortionStats {
     const float l1f = hwy::ScalarAbs(original - distorted);
     const double l1 = static_cast<double>(l1f);
     s_l1_.Notify(l1f);
-    b_l1_.Notify(HWY_MIN(99, static_cast(l1f * 1E4)));
     if (l1f != 0.0f) {
       l1_.push_back(l1f);
     }
@@ -102,7 +104,7 @@
     // as much as an actual sign flip, so do not count them.
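     // (For example, an original value of -1E-8 that rounds to 0.0f flips the
     // sign comparison, yet its effect on any dot product is negligible.)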
     n_sign_flip_ +=
         ((original < 0.0f) != (distorted < 0.0f)) && !rounded_to_zero;
-    n_exact_ += (l1f == 0.0f);
+    n_exact_ += (original == distorted);
     n_rounded_to_zero += rounded_to_zero;
   }
 
@@ -122,7 +124,6 @@
   void Assimilate(const DistortionStats& other) {
     s_original_.Assimilate(other.s_original_);
     s_l1_.Assimilate(other.s_l1_);
-    b_l1_.Assimilate(other.b_l1_);
     sum_l1_.Assimilate(other.sum_l1_);
     sum_l1_rounded_.Assimilate(other.sum_l1_rounded_);
     l1_.insert(l1_.end(), other.l1_.begin(), other.l1_.end());
@@ -204,7 +205,6 @@
  private:
  hwy::Stats s_original_;
  hwy::Stats s_l1_;
-  hwy::Bins<100> b_l1_;
  CascadedSummation<double> sum_l1_;          // all
  CascadedSummation<double> sum_l1_rounded_;  // only if rounded_to_zero
  std::vector<float> l1_;
diff --git a/compression/distortion_test.cc b/compression/distortion_test.cc
index 7ee9b0a..00e026a 100644
--- a/compression/distortion_test.cc
+++ b/compression/distortion_test.cc
@@ -17,6 +17,7 @@
 
 #include 
 
+#include "compression/shared.h"
 #include "util/test_util.h"
 #include "hwy/nanobenchmark.h"
 #include "hwy/tests/hwy_gtest.h"
@@ -74,13 +75,13 @@ TEST(DistortionTest, TestDilution) {
   HWY_ASSERT(IsNear(0.001, stats.WeightedAverageL1()));
 
   // Now add a large difference:
-  stats.Notify(1.875f - 0.0625f, 1.875f);  // max magnitude, 3-bit mantissa
+  stats.Notify(kMaxSFP - 0.0625f, kMaxSFP);  // max magnitude, 3-bit mantissa
   // .. WeightedAverageL1 is closer to it.
   HWY_ASSERT(IsInside(0.020, 0.025, stats.WeightedAverageL1()));
 
   // Add a small and large difference:
   stats.Notify((1.75f - 0.125f) / 1024, 1.75f / 1024);  // small, 2-bit mantissa
-  stats.Notify(-1.875f + 0.0625f, -1.875f);  // larger negative
+  stats.Notify(-kMaxSFP + 0.0625f, -kMaxSFP);  // larger negative
   // .. SNR is still barely affected.
   HWY_ASSERT(IsInside(890.0, 900.0, stats.GeomeanValueDivL1()));
   // .. WeightedAverageL1 is higher after another large error.
diff --git a/compression/nuq-inl.h b/compression/nuq-inl.h
index 5091946..11e9204 100644
--- a/compression/nuq-inl.h
+++ b/compression/nuq-inl.h
@@ -21,7 +21,7 @@
 #include 
 
 #include "compression/nuq.h"
-#include "compression/sfp.h"
+#include "compression/shared.h"
 #include "hwy/base.h"
 
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_NUQ_INL_H_
diff --git a/compression/nuq.h b/compression/nuq.h
index 08d162f..d7ae814 100644
--- a/compression/nuq.h
+++ b/compression/nuq.h
@@ -40,9 +40,8 @@ static constexpr size_t kGroupSize = 256;
 // Points to the *start* of a NUQ stream. Aligning the allocation (see
 // aligned_allocator.h) may speed up decoding but is not required.
 //
-// See go/streaming-weight-decode for background and design. Layout: first one
-// table of kClusters entries per group, in ascending order of group index,
-// then two packed indices per byte.
+// Layout: first one table of kClusters entries per group, in ascending order
+// of group index, then two packed indices per byte.
 //
 // Indices are stored in-order to enable vector-length agnostic decode, because
 // streams may be persisted to disk and used by other CPUs.
@@ -54,26 +53,23 @@
 #pragma pack(push, 1)
 struct NuqStream {
   // Returns offset of packed indices from the start of the stream. This matches
-  // the (padded) total table size because table entries are bytes. `capacity`
-  // is already a multiple of `kGroupSize`.
+  // the (padded) total table size because table entries are bytes.
   static constexpr size_t PackedStart(size_t capacity) {
     // Round up to avoid cache-line splits when loading indices. No effect on
     // size as long as capacity / kGroupSize is a multiple of 4.
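    // (Worked example with kClusters = 16, as implied by two 4-bit indices
    // per byte: capacity 4096 -> 16 groups -> 256 table bytes, already a
    // multiple of 64, so no padding is added.)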
-    return hwy::RoundUpTo((capacity / kGroupSize) * kClusters, 64);
+    return hwy::RoundUpTo(hwy::DivCeil(capacity, kGroupSize) * kClusters, 64);
   }
 
-  // Returns number of NuqStream to allocate for the stream, which matches its
-  // size in bytes. `capacity` is already a multiple of `kGroupSize`.
+  // Returns the number of NuqStream elements (= bytes) to allocate for the
+  // stream.
   static constexpr size_t PackedEnd(size_t capacity) {
-    return PackedStart(capacity) + capacity / 2;  // two 4-bit indices per byte.
+    return PackedStart(capacity) + hwy::DivCeil(capacity, 2);  // 2x 4-bit/byte
   }
 
   uint8_t byte;
 };
 #pragma pack(pop)
 
-static inline const char* TypeName(NuqStream) { return "NUQ"; }
-
 // Storage for dynamic programming. There are two matrices; we use separate
 // allocations to avoid type punning.
 template <typename T>
@@ -101,7 +97,7 @@ struct ClusterBuf {
     num = new_num;
     const size_t num_groups = hwy::DivCeil(num, kGroupSize);
     centers = hwy::AllocateAligned<float>(num_groups * kClusters);
-    idx = hwy::AllocateAligned<uint16_t>(num);
+    idx = hwy::AllocateAligned<uint16_t>(hwy::RoundUpTo(num, kGroupSize));
   }
 
   AlignedMatrix<T> d;
diff --git a/compression/python/compression_clif_aux.cc b/compression/python/compression_clif_aux.cc
index 334530a..de36bfb 100644
--- a/compression/python/compression_clif_aux.cc
+++ b/compression/python/compression_clif_aux.cc
@@ -46,7 +46,7 @@ class SbsWriterImpl : public WriterInterface {
   SbsWriterImpl() : pool_(0), compressor_(pool_) {}
 
   void Insert(std::string name, absl::Span<const float> weights) override {
-    const size_t out_size = CompressedArraySize<SfpStream>(weights.size());
+    const size_t out_size = CompressedArrayElements<SfpStream>(weights.size());
     sfp_streams_.push_back(std::vector<SfpStream>(out_size));
     compressor_.Insert(name.data(), weights.data(), weights.size(),
                        working_set_, out_size,
@@ -54,7 +54,7 @@ class SbsWriterImpl : public WriterInterface {
   }
 
   void InsertNUQ(std::string name, absl::Span<const float> weights) override {
-    const size_t out_size = CompressedArraySize<NuqStream>(weights.size());
+    const size_t out_size = CompressedArrayElements<NuqStream>(weights.size());
     nuq_streams_.push_back(std::vector<NuqStream>(out_size));
     compressor_.Insert(name.data(), weights.data(), weights.size(),
                        working_set_, out_size,
@@ -64,7 +64,7 @@ class SbsWriterImpl : public WriterInterface {
   void InsertBfloat16(std::string name,
                       absl::Span<const float> weights) override {
     const size_t out_size =
-        CompressedArraySize<BF16>(weights.size());
+        CompressedArrayElements<BF16>(weights.size());
     bf16_streams_.push_back(std::vector<BF16>(out_size));
     compressor_.Insert(name.data(), weights.data(), weights.size(),
                        working_set_, out_size,
diff --git a/compression/sfp-inl.h b/compression/sfp-inl.h
index 3a40527..78c941f 100644
--- a/compression/sfp-inl.h
+++ b/compression/sfp-inl.h
@@ -20,7 +20,7 @@
 #include 
 #include 
 
-#include "compression/sfp.h"
+#include "compression/shared.h"
 #include "hwy/base.h"
 
 #endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_INL_H_
diff --git a/compression/sfp.h b/compression/sfp.h
deleted file mode 100644
index 332ca43..0000000
--- a/compression/sfp.h
+++ /dev/null
@@ -1,51 +0,0 @@
-// Copyright 2023 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//      http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_ -#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_ - -// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32 -// inputs that combines the advantages of e4m3 and e5m2 into a single format. -// It supports seeking at a granularity of 1, decoding to bf16/f32, and a -// fused decode/dot product with bf16/f32 vectors. - -#include - -namespace gcpp { - -// Points to the *start* of an SFP stream. Values are stored in-order to enable -// vector-length agnostic seeking, because streams may be written to disk for -// loading on other CPUs. -// -// Characteristics: -// - 24-bit dynamic range, with max exponent 2^0. -// - 3 bit mantissa for values >= 2^-7, otherwise 2. -// -// This is faster to decode than a straightforward implementation of eXmY, in -// part because SFP does not require subnormals. Unlike OCP MX, it also does not -// require side information (shared exponents). -// -// Although the representation could probably be shrunk to 6-7 bits, more -// savings can be had by non-uniform clustering - see nuq.h. -#pragma pack(push, 1) -struct SfpStream { - uint8_t byte; -}; -#pragma pack(pop) - -static inline const char* TypeName(SfpStream) { return "SFP"; } - -} // namespace gcpp -#endif // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SFP_H_ diff --git a/compression/sfp_test.cc b/compression/sfp_test.cc index 6e1f5ce..da4f220 100644 --- a/compression/sfp_test.cc +++ b/compression/sfp_test.cc @@ -18,8 +18,6 @@ #define HWY_DISABLED_TARGETS HWY_SCALAR #endif -#include "compression/sfp.h" - #include #include #include @@ -27,20 +25,21 @@ #include #include "compression/distortion.h" +#include "compression/shared.h" #include "util/test_util.h" #include "hwy/aligned_allocator.h" #include "hwy/base.h" +#include "hwy/tests/hwy_gtest.h" #include "hwy/timer.h" // clang-format off #undef HWY_TARGET_INCLUDE #define HWY_TARGET_INCLUDE "compression/sfp_test.cc" // NOLINT // clang-format on #include "hwy/foreach_target.h" // IWYU pragma: keep -// Any highway.h must come after foreach_target.h +#include "hwy/highway.h" +// After highway.h #include "compression/sfp-inl.h" #include "ops/dot-inl.h" -#include "hwy/highway.h" -#include "hwy/tests/hwy_gtest.h" #include "hwy/tests/test_util-inl.h" HWY_BEFORE_NAMESPACE(); diff --git a/compression/shared.h b/compression/shared.h new file mode 100644 index 0000000..5f8b173 --- /dev/null +++ b/compression/shared.h @@ -0,0 +1,94 @@ +// Copyright 2023 Google LLC +// SPDX-License-Identifier: Apache-2.0 +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Definitions shared between the public compress-inl.h interface and the +// sfp-inl.h and nuq-inl.h implementation details. 
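+// Currently that is: the BF16 alias, SfpStream and its kMaxSFP limit, and the
+// PackedSpan view with the MakeSpan/MakeConstSpan/MakeConst helpers.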
+
+#ifndef THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
+#define THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
+
+#include <stdint.h>
+
+#include "hwy/base.h"  // hwy::bfloat16_t
+
+namespace gcpp {
+
+using BF16 = hwy::bfloat16_t;
+
+// Switching Floating Point: a hybrid 8-bit float representation of bf16/f32
+// inputs that combines the advantages of e4m3 and e5m2 into a single format.
+// It supports seeking at a granularity of 1 and decoding to bf16/f32.
+//
+// Characteristics:
+// - 24-bit dynamic range, with max exponent 2^0.
+// - 3-bit mantissa for values >= 2^-7, otherwise 2.
+//
+// A pointer to this is the *start* of an SFP stream. Values are stored
+// in-order to enable vector-length agnostic seeking, because streams may be
+// written to disk for loading on other CPUs.
+//
+// This is faster to decode than a straightforward implementation of eXmY, in
+// part because SFP does not require subnormals. Unlike OCP MX, it also does
+// not require side information (shared exponents).
+//
+// Although the representation could probably be shrunk to 6-7 bits, more
+// savings can be had by non-uniform clustering - see nuq.h.
+#pragma pack(push, 1)
+struct SfpStream {
+  uint8_t byte;
+};
+#pragma pack(pop)
+
+// Largest possible input magnitude: 1.111 * 2^0. This could be increased by
+// shifting the value range (exponent bias).
+constexpr float kMaxSFP = 1.875f;
+
+// Non-owning view of packed elements. Shortens argument lists.
+//
+// Callers typically also pass an `ofs` starting offset. This is not folded
+// into `ptr` because NUQ consists of two separate streams. To discourage
+// direct use of `ptr` without that offset, we define a separate class instead
+// of reusing `hwy::Span`.
+template <typename Packed>
+struct PackedSpan {
+  void BoundsCheck(size_t packed_ofs, size_t num) const {
+    HWY_DASSERT(packed_ofs + num <= size);
+    (void)size;
+  }
+
+  Packed* HWY_RESTRICT ptr;
+  size_t size;  // for BoundsCheck and nuq-inl.h HWY_ASSERT.
+};
+
+// Avoids spelling out the template parameter in every call.
+template <typename Packed>
+HWY_INLINE PackedSpan<Packed> MakeSpan(Packed* ptr, size_t size) {
+  return {ptr, size};
+}
+
+template <typename Packed>
+HWY_INLINE PackedSpan<const Packed> MakeConstSpan(Packed* ptr, size_t size) {
+  return {ptr, size};
+}
+
+// "Implicit" conversion from a PackedSpan<T> to PackedSpan<const T>, used in
+// `RMSNormInplace` and compression tests.
+template <typename Packed>
+HWY_INLINE PackedSpan<const Packed> MakeConst(PackedSpan<Packed> packed) {
+  return {packed.ptr, packed.size};
+}
+
+}  // namespace gcpp
+#endif  // THIRD_PARTY_GEMMA_CPP_COMPRESSION_SHARED_H_
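
A minimal usage sketch for the new view (hypothetical element count kNum; in
real callers the packed buffer comes from the compressor):

    constexpr size_t kNum = 512;
    std::vector<SfpStream> packed(CompressedArrayElements<SfpStream>(kNum));
    const PackedSpan<const SfpStream> span =
        MakeConstSpan(packed.data(), packed.size());
    span.BoundsCheck(/*packed_ofs=*/0, kNum);  // HWY_DASSERT in debug builds
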
diff --git a/gemma/activations.h b/gemma/activations.h
index f8cb4dd..b2113da 100644
--- a/gemma/activations.h
+++ b/gemma/activations.h
@@ -55,11 +55,6 @@ struct Activations {
   // Rope
   RowVectorBatch<float> inv_timescale;
 
-  // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
-  // per-thread storage.
-  // TODO: remove once MatVec is no longer used.
-  RowVectorBatch<float> even_odd;
-
   MatMulEnv env;
 
   // Multi-Head Attention?
@@ -123,9 +118,6 @@ struct Activations {
 
     inv_timescale = CreateInvTimescale();
 
-    const size_t num_lp = pools.NumLP();
-    even_odd = RowVectorBatch<float>(1, kModelDim * num_lp);
-
     env = MatMulEnv(pools);
   }
 };
diff --git a/gemma/gemma-inl.h b/gemma/gemma-inl.h
index cee376f..3dab525 100644
--- a/gemma/gemma-inl.h
+++ b/gemma/gemma-inl.h
@@ -192,8 +192,7 @@ HWY_NOINLINE void GriffinRecurrent(
     float* out_ptr = activations.att_sums.Batch(batch_idx);
     MatVecAdd(
         layer_weights->griffin.linear_out_w, 0, x,
-        layer_weights->griffin.linear_out_biases.data_scale1(),
-        activations.even_odd.All(), out_ptr, pool);
+        layer_weights->griffin.linear_out_biases.data_scale1(), out_ptr, pool);
   }
 }
 
@@ -283,8 +282,8 @@ class GemmaAttention {
         float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
         // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
         MatVec(
-            layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
-            activations_.even_odd.All(), kv, pool_);
+            layer_weights_.qkv_einsum_w, kHeads * kQKVDim * kModelDim, x, kv,
+            pool_);
       }
     }
   }
diff --git a/ops/fp_arith-inl.h b/ops/fp_arith-inl.h
new file mode 100644
index 0000000..609f1fe
--- /dev/null
+++ b/ops/fp_arith-inl.h
@@ -0,0 +1,152 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+
+// Building blocks for floating-point arithmetic.
+
+// Include guard for (potentially) SIMD code.
+#if defined(THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE) == defined(HWY_TARGET_TOGGLE)
+#ifdef THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
+#undef THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
+#else
+#define THIRD_PARTY_GEMMA_CPP_FP_ARITH_TOGGLE
+#endif
+
+#include "hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace gcpp {
+namespace HWY_NAMESPACE {
+namespace hn = hwy::HWY_NAMESPACE;
+
+//------------------------------------------------------------------------------
+// Exact multiplication
+
+namespace detail {
+
+// Returns non-overlapping `x` and `y` such that `x + y == a` and |x| >= |y|.
+// Notation from Algorithm 3.1 in Handbook of Floating-Point Arithmetic. 4 ops.
+template <class DF, class VF = hn::Vec<DF>>
+static HWY_INLINE void VeltkampSplit(DF df, VF a, VF& x, VF& y) {
+  using TF = hn::TFromD<DF>;
+  constexpr int t = hwy::MantissaBits<TF>() + 1;  // = -log2(epsilon)
+  constexpr int s = hwy::DivCeil(t, 2);
+  const VF factor = hn::Set(df, hwy::ConvertScalarTo<TF>((1ULL << s) + 1));
+  const VF c = hn::Mul(factor, a);
+  x = hn::Sub(c, hn::Sub(c, a));
+  y = hn::Sub(a, x);
+}
+
+}  // namespace detail
+
+// Returns `prod` and `err` such that `prod + err` is exactly equal to `a * b`,
+// despite floating-point rounding, assuming that `err` is not subnormal, i.e.,
+// the sum of exponents >= min exponent + mantissa bits. 2..17 ops. Useful for
+// compensated dot products and polynomial evaluation.
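+// (Worked f32 example: a = b = 1 + 2^-23. The exact product is
+// 1 + 2^-22 + 2^-46; `prod` rounds to 1 + 2^-22, and `err` = 2^-46 recovers
+// the discarded low bit.)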
+template <class DF, class VF = hn::Vec<DF>>
+static HWY_INLINE VF TwoProducts(DF df, VF a, VF b, VF& err) {
+  const VF prod = hn::Mul(a, b);
+  if constexpr (HWY_NATIVE_FMA) {
+    err = hn::MulSub(a, b, prod);
+  } else {
+    // Non-FMA fallback: we assume these calculations do not overflow.
+    VF a1, a2, b1, b2;
+    detail::VeltkampSplit(df, a, a1, a2);
+    detail::VeltkampSplit(df, b, b1, b2);
+    const VF m = hn::Sub(prod, hn::Mul(a1, b1));
+    const VF n = hn::Sub(m, hn::Mul(a2, b1));
+    const VF o = hn::Sub(n, hn::Mul(a1, b2));
+    err = hn::Sub(hn::Mul(a2, b2), o);
+  }
+  return prod;
+}
+
+//------------------------------------------------------------------------------
+// Exact addition
+
+// Returns `sum` and `err` such that `sum + err` is exactly equal to `a + b`,
+// despite floating-point rounding. `sum` is already the best estimate for the
+// addition, so do not directly add `err` to it. `UpdateCascadedSums` instead
+// accumulates multiple `err`, which are then later added to the total `sum`.
+//
+// Knuth98/Moller65. Unlike FastTwoSums, this does not require any relative
+// ordering of the exponents of a and b. 6 ops.
+template <class DF, class VF = hn::Vec<DF>>
+static HWY_INLINE VF TwoSums(DF /*df*/, VF a, VF b, VF& err) {
+  const VF sum = hn::Add(a, b);
+  const VF a2 = hn::Sub(sum, b);
+  const VF b2 = hn::Sub(sum, a2);
+  const VF err_a = hn::Sub(a, a2);
+  const VF err_b = hn::Sub(b, b2);
+  err = hn::Add(err_a, err_b);
+  return sum;
+}
+
+// As above, but only exact if the exponent of `a` >= that of `b`, which is the
+// case if |a| >= |b|. Dekker71, also used in Kahan65 compensated sum. 3 ops.
+template <class DF, class VF = hn::Vec<DF>>
+static HWY_INLINE VF FastTwoSums(DF /*df*/, VF a, VF b, VF& err) {
+  const VF sum = hn::Add(a, b);
+  const VF b2 = hn::Sub(sum, a);
+  err = hn::Sub(b, b2);
+  return sum;
+}
+
+//------------------------------------------------------------------------------
+// Cascaded summation (twice working precision)
+
+// Accumulates numbers with about twice the precision of T using 7 * n FLOPS.
+// Rump/Ogita/Oishi08, Algorithm 6.11 in Handbook of Floating-Point Arithmetic.
+//
+// Because vectors generally cannot be wrapped in a class, we use functions.
+// `sum` and `sum_err` must be initially zero. Each lane is an independent sum.
+// To reduce them into a single result, use `ReduceCascadedSums`.
+template <class DF, class VF = hn::Vec<DF>>
+void UpdateCascadedSums(DF df, VF v, VF& sum, VF& sum_err) {
+  VF err;
+  sum = TwoSums(df, sum, v, err);
+  sum_err = hn::Add(sum_err, err);
+}
+
+// Combines two cascaded sum vectors, typically from unrolling/parallelization.
+template <class DF, class VF = hn::Vec<DF>>
+void AssimilateCascadedSums(DF df, const VF& other_sum,
+                            const VF& other_sum_err, VF& sum, VF& sum_err) {
+  UpdateCascadedSums(df, other_sum, sum, sum_err);
+  sum_err = hn::Add(sum_err, other_sum_err);
+}
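+
+// Usage sketch (assuming `p` points to `n` floats, with `n` a multiple of the
+// vector length):
+//
+//   const hn::ScalableTag<float> df;
+//   hn::Vec<decltype(df)> sum = hn::Zero(df), sum_err = hn::Zero(df);
+//   for (size_t i = 0; i < n; i += hn::Lanes(df)) {
+//     UpdateCascadedSums(df, hn::LoadU(df, p + i), sum, sum_err);
+//   }
+//   const float total = ReduceCascadedSums(df, sum, sum_err);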
+
+// Reduces cascaded sums to a single value. Slow; call outside of loops.
+template <class DF, class VF = hn::Vec<DF>>
+hn::TFromD<DF> ReduceCascadedSums(DF df, const VF sum, VF sum_err) {
+  const size_t N = hn::Lanes(df);
+  using TF = hn::TFromD<DF>;
+  TF total = TF{0.0};
+  TF total_err = TF{0.0};
+  for (size_t i = 0; i < N; ++i) {
+    TF err;
+    total = TwoSum(total, hn::ExtractLane(sum, i), err);
+    total_err += err;
+  }
+  return total + total_err + hn::ReduceSum(df, sum_err);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+}  // namespace HWY_NAMESPACE
+}  // namespace gcpp
+HWY_AFTER_NAMESPACE();
+
+#endif  // NOLINT
diff --git a/ops/gemma_matvec_test.cc b/ops/gemma_matvec_test.cc
index d923204..012b24c 100644
--- a/ops/gemma_matvec_test.cc
+++ b/ops/gemma_matvec_test.cc
@@ -115,15 +115,13 @@ void TestMatVecAdd() {
       GenerateMat<kOuter, kInner>(0, pool);
   hwy::AlignedFreeUniquePtr<float[]> vec = GenerateVec<kInner>(0);
   hwy::AlignedFreeUniquePtr<float[]> add = GenerateVec<kOuter>(0);
-  hwy::AlignedFreeUniquePtr<float[]> even_odd =
-      hwy::AllocateAligned<float>(kInner * pool.NumWorkers());
   hwy::AlignedFreeUniquePtr<float[]> expected_out =
       SimpleMatVecAdd(mat, vec, add);
   hwy::AlignedFreeUniquePtr<float[]> actual_out =
       hwy::AllocateAligned<float>(kOuter);
-  HWY_ASSERT(vec && add && even_odd && expected_out && actual_out);
-  MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), even_odd.get(),
-                            actual_out.get(), pool);
+  HWY_ASSERT(vec && add && expected_out && actual_out);
+  MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), actual_out.get(),
+                            pool);
   AssertClose(actual_out, expected_out);
 }
diff --git a/ops/matvec-inl.h b/ops/matvec-inl.h
index 3d6bbcd..a1c9e57 100644
--- a/ops/matvec-inl.h
+++ b/ops/matvec-inl.h
@@ -21,8 +21,6 @@
 #include <stddef.h>
 #include <stdint.h>
 
-#include "compression/compress.h"
-#include "compression/sfp.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/profiler.h"
@@ -48,37 +46,6 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
-HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
-                             const size_t size, float* HWY_RESTRICT out) {
-  const hn::ScalableTag<float> df;
-  const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf16;
-
-  HWY_DASSERT(size % hn::Lanes(dbf16) == 0);
-  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
-
-  for (size_t i = 0; i < size; i += hn::Lanes(dbf16)) {
-    const auto interleaved = hn::LoadU(dbf16, vec_aligned + i);
-    hn::Store(hn::PromoteEvenTo(df, interleaved), df, out + i);
-    hn::Store(hn::PromoteOddTo(df, interleaved), df, out + i + hn::Lanes(df));
-  }
-}
-
-HWY_INLINE void ToEvenOddF32(const float* HWY_RESTRICT vec_aligned,
-                             const size_t size, float* HWY_RESTRICT out) {
-  const hn::ScalableTag<float> df;
-  using VF = hn::Vec<decltype(df)>;
-
-  HWY_DASSERT(size % (hn::Lanes(df) * 2) == 0);
-  HWY_DASSERT(hn::IsAligned(df, vec_aligned));
-
-  VF vec0, vec1;
-  for (size_t i = 0; i < size; i += hn::Lanes(df) * 2) {
-    hn::LoadInterleaved2(df, vec_aligned + i, vec0, vec1);
-    hn::Store(vec0, df, out + i);
-    hn::Store(vec1, df, out + i + hn::Lanes(df));
-  }
-}
-
 // Simple version without tiling or threading, but two offsets/outputs and
 // always with addition.
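 // For each row r, this computes out0[r] = add0[r] + Dot(row r at mat_ofs0,
 // vec_aligned), and likewise out1[r] from mat_ofs1/add1.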
template df; for (size_t idx_row = 0; idx_row < kOuter; ++idx_row) { const size_t row_ofs0 = mat_ofs0 + (idx_row)*kInner; const size_t row_ofs1 = mat_ofs1 + (idx_row)*kInner; out0[idx_row] = hwy::ConvertScalarTo(add0[idx_row]) + - Dot(df, mat, row_ofs0, vec_aligned, kInner); + Dot(df, mat, row_ofs0, vec_aligned, kInner); out1[idx_row] = hwy::ConvertScalarTo(add1[idx_row]) + - Dot(df, mat, row_ofs1, vec_aligned, kInner); + Dot(df, mat, row_ofs1, vec_aligned, kInner); } } @@ -125,7 +91,7 @@ namespace detail { // For each i = [0, num_rows), compute partial (length `num_cols`) dot product // of row i with `vec_aligned` and add into `out[i]`. The upper-left // coordinate of the tile is r0, c0. -template +template HWY_INLINE void AccumulatePartialDotProducts( DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0, size_t c0, size_t num_rows, size_t num_cols, @@ -133,15 +99,14 @@ HWY_INLINE void AccumulatePartialDotProducts( for (size_t idx_row = 0; idx_row < num_rows; ++idx_row) { const size_t row_ofs = mat_ofs + (r0 + idx_row) * mat_stride; out[idx_row] += - Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); + Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); } } // Same as AccumulatePartialDotProducts, but sets out[i] to the first partial // dot product + init (if kInit), which avoids having to zero-initialize and // accumulate. -template +template HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0, size_t c0, @@ -154,10 +119,10 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat, if constexpr (kInit) { out[idx_row] = hwy::ConvertScalarTo(init[idx_row + r0]) + - Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); + Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); } else { out[idx_row] = - Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); + Dot(df, mat, row_ofs + c0, vec_aligned + c0, num_cols); } } } @@ -166,8 +131,7 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat, // horizontal strip of the entire matrix); the result is the full dot product // for rows r in [r0, r0 + num_rows) + optionally the add vector, which we // store into in out[r - r0]. -template +template HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0, size_t num_rows, @@ -176,56 +140,25 @@ HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat, float* HWY_RESTRICT out) { // Tall and skinny: set `out` to the single dot product. if (mat_stride < MaxCols()) { - SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, - 0, num_rows, mat_stride, - vec_aligned, add, out); + SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, + num_rows, mat_stride, vec_aligned, add, + out); return; } // We have at least MaxCols, so start by setting `out` to that: - SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, - num_rows, MaxCols(), vec_aligned, - add, out); + SetFirstPartialDotProducts(df, mat, mat_ofs, mat_stride, r0, 0, + num_rows, MaxCols(), vec_aligned, add, out); // For further multiples of MaxCols, accumulate. Remainders handled below. 
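  // (E.g., if MaxCols() were 2048 and mat_stride == 3000: the first call
  // covers cols [0, 2048), the loop below is skipped, and the final partial
  // accumulation covers [2048, 3000).)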
size_t c0 = MaxCols(); for (; c0 <= mat_stride - MaxCols(); c0 += MaxCols()) { - AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, - num_rows, MaxCols(), vec_aligned, out); + AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows, + MaxCols(), vec_aligned, out); } if (c0 < mat_stride) { // Final cols - AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, - num_rows, mat_stride - c0, vec_aligned, - out); - } -} - -template -HWY_INLINE void MatVecAddInner(const ArrayT& mat, const size_t mat_ofs, - const VecT* HWY_RESTRICT const vec_aligned, - const AddT* HWY_RESTRICT const add, - float* HWY_RESTRICT out, hwy::ThreadPool& pool) { - const hn::ScalableTag df; - constexpr size_t kRowsPerStrip = RowsPerStrip(); - constexpr size_t kNumStrips = kOuter / kRowsPerStrip; - - // For each entire strip. - pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { - PROFILER_ZONE("MatVec.lambda"); - const size_t r0 = strip * kRowsPerStrip; - detail::FullDotProductsForStrip( - df, mat, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add, - out + r0); - }); - - // Remaining rows - const size_t r0 = kNumStrips * kRowsPerStrip; - if (r0 < kOuter) { - PROFILER_ZONE("MatVec remainder"); - const size_t num_rows = kOuter - r0; - detail::FullDotProductsForStrip( - df, mat, mat_ofs, kInner, r0, num_rows, vec_aligned, add, out + r0); + AccumulatePartialDotProducts(df, mat, mat_ofs, mat_stride, r0, c0, num_rows, + mat_stride - c0, vec_aligned, out); } } @@ -238,28 +171,30 @@ template float does not benefit enough to recoup the cost of ToEvenOddF32. - if constexpr (CompressTraits::kSupportsEvenOdd && - hwy::IsSameEither() && - !(hwy::IsSame() && - hwy::IsSame())) { - ToEvenOddF32(vec_aligned, kInner, even_odd); - detail::MatVecAddInner( - mat, mat_ofs, even_odd, add, out, pool); - return; - } -#else - (void)even_odd; -#endif + const hn::ScalableTag df; + constexpr size_t kRowsPerStrip = RowsPerStrip(); + constexpr size_t kNumStrips = kOuter / kRowsPerStrip; - detail::MatVecAddInner( - mat, mat_ofs, vec_aligned, add, out, pool); + // For each entire strip. + pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { + PROFILER_ZONE("MatVec.lambda"); + const size_t r0 = strip * kRowsPerStrip; + detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, + kRowsPerStrip, vec_aligned, add, + out + r0); + }); + + // Remaining rows + const size_t r0 = kNumStrips * kRowsPerStrip; + if (r0 < kOuter) { + PROFILER_ZONE("MatVec remainder"); + const size_t num_rows = kOuter - r0; + detail::FullDotProductsForStrip(df, mat, mat_ofs, kInner, r0, + num_rows, vec_aligned, add, out + r0); + } } // With addition @@ -268,21 +203,19 @@ template (mat, mat_ofs, vec_aligned, add, - even_odd, out, pool); + out, pool); } // Without addition template HWY_INLINE void MatVec(const ArrayT& mat, const size_t mat_ofs, const VecT* HWY_RESTRICT const vec_aligned, - float* HWY_RESTRICT even_odd, float* HWY_RESTRICT out, - hwy::ThreadPool& pool) { + float* HWY_RESTRICT out, hwy::ThreadPool& pool) { MatVecT(mat, mat_ofs, vec_aligned, /*add=*/static_cast(nullptr), - even_odd, out, pool); + out, pool); } // Two matrices, same vector @@ -300,18 +233,17 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT& mat0, const ArrayT& mat1, const hn::ScalableTag df; constexpr size_t kRowsPerStrip = RowsPerStrip(); constexpr size_t kNumStrips = kOuter / kRowsPerStrip; - constexpr bool kVecIsEvenOdd = false; // For each entire strip. 
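  // (The pool processes kOuter / kRowsPerStrip full strips in parallel; rows
  // in any final partial strip are handled afterwards on the calling thread.)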
pool.Run(0, kNumStrips, [&](const uint64_t strip, size_t thread) HWY_ATTR { PROFILER_ZONE("TwoMatVec.lambda"); const size_t r0 = strip * kRowsPerStrip; - detail::FullDotProductsForStrip( - df, mat0, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add0, - out0 + r0); - detail::FullDotProductsForStrip( - df, mat1, mat_ofs, kInner, r0, kRowsPerStrip, vec_aligned, add1, - out1 + r0); + detail::FullDotProductsForStrip(df, mat0, mat_ofs, kInner, r0, + kRowsPerStrip, vec_aligned, add0, + out0 + r0); + detail::FullDotProductsForStrip(df, mat1, mat_ofs, kInner, r0, + kRowsPerStrip, vec_aligned, add1, + out1 + r0); }); // Remaining rows @@ -319,9 +251,9 @@ HWY_NOINLINE void TwoMatVecT(const ArrayT& mat0, const ArrayT& mat1, if (r0 < kOuter) { PROFILER_ZONE("TwoMatVec remainder"); const size_t num_rows = kOuter - r0; - detail::FullDotProductsForStrip( + detail::FullDotProductsForStrip( df, mat0, mat_ofs, kInner, r0, num_rows, vec_aligned, add0, out0 + r0); - detail::FullDotProductsForStrip( + detail::FullDotProductsForStrip( df, mat1, mat_ofs, kInner, r0, num_rows, vec_aligned, add1, out1 + r0); } }
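
With even_odd gone, call sites simply drop the scratch argument. A
before/after sketch based on TestMatVecAdd above (same buffers and
kOuter/kInner as in that test):

    // Before: callers had to allocate kInner * pool.NumWorkers() floats.
    MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), even_odd.get(),
                              actual_out.get(), pool);
    // After: no deinterleaving, no scratch.
    MatVecAdd<kOuter, kInner>(mat, 0, vec.get(), add.get(), actual_out.get(),
                              pool);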