// Copyright 2024 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <stddef.h>
#include <stdint.h>

#include "compression/compress.h"  // IWYU pragma: keep, b/conditionally used
#include "ops/matmul.h"            // IWYU pragma: export

// Include guard for (potentially) SIMD code.
#if defined(THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE) == defined(HWY_TARGET_TOGGLE)
#ifdef THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
#undef THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
#else
#define THIRD_PARTY_GEMMA_CPP_MATMUL_TOGGLE
#endif

#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "hwy/contrib/math/math-inl.h"

HWY_BEFORE_NAMESPACE();
namespace gcpp {
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
// loads, we reuse the same A row for several B columns, which are also loaded
// once for several rows of C. Thus we produce one 'tile' of C at a time of
// dimensions `kRegRows` x `kRegCols`. The Reg naming is because these are
// limited by the number of registers: 32 for NEON/SVE/AVX-512. `kRegCols` == 4
// enables the `StoreInterleaved4` transpose in `AddHorizontalSums`. We assume
// and verify that `C.cols % kRegCols == 0`.
constexpr size_t kRegCols = 4;

// Choosing `kRegRows == kRegCols` minimizes the ratio of loads to FMA, because
// we load `kRegCols + kRegRows` vectors per `kRegRows * kRegCols` element tile.
// In general, `batch_size` (the number of C rows) is not a multiple of
// `kRegRows`. Thus functions that load or store a tile are parameterized on
// `kNumRows`, which is generally `kRegRows`, but `batch_size % kRegRows` for
// the last tile row (if that remainder is nonzero).
constexpr size_t kRegRows = kRegCols;

// NEON_BF16/SVE/AVX3_ZEN4 have instructions for bf16 * bf16 + f32 which are
// more efficient than f32 * f32 + f32 because they process twice as many lanes
// at a time. Any combination of A and B can be bf16: activations may already
// be bf16, and weights can be decompressed to bf16.
//
// The corresponding op is `ReorderWidenMulAccumulate`, and it is always
// supported, but only useful if it returns a single vector of pairwise sums
// `a[0] * b[0] + a[1] * b[1]`. On other targets, `ReorderWidenMulAccumulate`
// instead returns `a[1] * b[1]` in its `sum1` output. We cannot afford to keep
// a `sum1` for each of the `kRegRows * kRegCols` C vectors, and it would be
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
// to bf16 if the native op is available. This will actually demote f32
// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
// Targets with native bf16 multiply-accumulate, per the list above.
constexpr bool kNativeBF16MulAcc =
    HWY_TARGET == HWY_NEON_BF16 || HWY_TARGET == HWY_AVX3_ZEN4 ||
    HWY_TARGET == HWY_SVE || HWY_TARGET == HWY_SVE2 ||
    HWY_TARGET == HWY_SVE_256 || HWY_TARGET == HWY_SVE2_128;
using MulT = hwy::If<kNativeBF16MulAcc, hwy::bfloat16_t, float>;

// Loads two vectors at a time with element type MulT from a row of the
// transposed B. Called in a loop over col_ab. No bounds checking because
// `kRow` is actually from B columns, which we checked is a multiple of
// `kRegCols`.
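//
// For illustration, a sketch of the intended usage (mirroring `MatMulTile`
// below); `dm` is an `hn::ScalableTag<MulT>` and `col_ab` advances by
// `2 * hn::Lanes(dm)` because each call fills two vectors:
//
//   const BRow<0, MatTB> b_row0(B, row_b_col_c);
//   VM b00, b01;
//   b_row0.Load2(dm, col_ab, b00, b01);  // two vectors of B row row_b_col_c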
template <size_t kRow, typename MatTB>
class BRow {
  static_assert(kRow < kRegRows);  // which of the unrolled instances we are
  using TraitsB = CompressTraits<MatTB>;

 public:
  BRow(const Mat<MatTB>& B, size_t row_b)
      : B_(B.ptr), B_ofs_(B.Row(row_b + kRow)) {}

  template <class DM, class VM = hn::Vec<DM>>
  HWY_INLINE void Load2(DM d, size_t col_ab, VM& b0, VM& b1) const {
    static_assert(hwy::IsSame<hn::TFromD<DM>, MulT>());
    TraitsB::Decompress2(d, B_, B_ofs_ + col_ab, b0, b1);
  }

 private:
  const MatTB* HWY_RESTRICT B_;
  const size_t B_ofs_;
};

// Loads *two* row vectors from A via `Decompress2`, multiplies element-wise
// with `kRegRows` x 2 row vectors from transposed B, and adds them to
// `kRegRows` x `kRegCols` C vectors. The lanes of `C[r,c]` are thus a subset
// of the terms of the dot products that make up the MatMul result at `r,c`.
// No-op for the bottom-most tile where kRow >= kNumRows.
//
// This approach is atypical because it requires a horizontal sum, for which we
// introduce a fast and new(?) vector-length agnostic 'transpose', see
// `AddHorizontalSums`. Most MatMul implementations instead broadcast one
// element from A and multiply it with one element from each of N columns of B
// to obtain N columns of C. That is a poor fit for our setting:
// - `CompressTraits` decompresses two vectors at a time;
// - B is column-major, so unit-stride SIMD loads return a column, not values
//   from different columns, i.e. a row.
// Both could be fixed in a packing stage, which is not implemented yet, and
// might not be necessary otherwise. However, `ReorderWidenMulAccumulate` is
// important for bf16 performance and incompatible with the conventional
// approach, because its pairwise adds would add together unrelated terms.
// By contrast, pairwise adds are fine when our C lanes are the terms of a
// single dot product, which can be reordered or pre-reduced.
template <size_t kRow, typename MatTA>
class ALoadAccumulate {
  static_assert(kRow < kRegRows);  // which of the unrolled instances we are
  using TraitsA = CompressTraits<MatTA>;

 public:
  ALoadAccumulate(const Mat<MatTA>& A, size_t row_ac)
      : A_(A.ptr), A_ofs_(A.Row(row_ac + kRow)) {}

  // First iteration, col_ab = 0: initialize C0..3 instead of updating them.
  template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
            HWY_IF_F32_D(DM)>
  HWY_INLINE void First(DM dm,  //
                        const VM b00, const VM b01, const VM b10, const VM b11,
                        const VM b20, const VM b21, const VM b30, const VM b31,
                        VM& C0, VM& C1, VM& C2, VM& C3) const {
    static_assert(kNumRows <= kRegRows);  // How many rows are actually present.
    if constexpr (kRow < kNumRows) {
      VM a0, a1;
      TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1);

      static_assert(kRegCols == 4);
      C0 = hn::Mul(a0, b00);
      C1 = hn::Mul(a0, b10);
      C2 = hn::Mul(a0, b20);
      C3 = hn::Mul(a0, b30);
      C0 = hn::MulAdd(a1, b01, C0);
      C1 = hn::MulAdd(a1, b11, C1);
      C2 = hn::MulAdd(a1, b21, C2);
      C3 = hn::MulAdd(a1, b31, C3);
    }
  }

  // Same as above, only called if MulT == BF16.
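  // For example (a sketch of the bf16 semantics, not code from this file): on
  // a target with the native op, given bf16 lanes a = {a0, a1, a2, a3, ...}
  // and b = {b0, b1, b2, b3, ...}, `ReorderWidenMulAccumulate(df, a, b, C,
  // sum1)` adds `a0*b0 + a1*b1` and `a2*b2 + a3*b3` (widened to f32, possibly
  // in permuted order) to lanes of C, so each f32 C vector accumulates two
  // bf16 products per lane and `sum1` remains zero.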
  template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
            HWY_IF_BF16_D(DM), class DF = hn::Repartition<float, DM>,
            class VF = hn::Vec<DF>>
  HWY_INLINE void First(DM dm,  //
                        const VM b00, const VM b01, const VM b10, const VM b11,
                        const VM b20, const VM b21, const VM b30, const VM b31,
                        VF& C0, VF& C1, VF& C2, VF& C3) const {
    static_assert(kNumRows <= kRegRows);  // How many rows are actually present.
    if constexpr (kRow < kNumRows) {
      VM a0, a1;
      TraitsA::Decompress2(dm, A_, A_ofs_, a0, a1);

      const DF df;
      VF unused_sum1 = hn::Zero(df);
      static_assert(kRegCols == 4);
      C0 = hn::WidenMulPairwiseAdd(df, a0, b00);
      C1 = hn::WidenMulPairwiseAdd(df, a0, b10);
      C2 = hn::WidenMulPairwiseAdd(df, a0, b20);
      C3 = hn::WidenMulPairwiseAdd(df, a0, b30);
      C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
      C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
      C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
      C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
      // Ensure sum1 was indeed unused.
      HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
    }
  }

  // Non-first iteration: accumulate into C0..3.
  template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
            HWY_IF_F32_D(DM)>
  HWY_INLINE void Next(DM dm, size_t col_ab,  //
                       const VM b00, const VM b01, const VM b10, const VM b11,
                       const VM b20, const VM b21, const VM b30, const VM b31,
                       VM& C0, VM& C1, VM& C2, VM& C3) const {
    static_assert(kNumRows <= kRegRows);  // How many rows are actually present.
    HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm));  // Should not be first iteration.
    if constexpr (kRow < kNumRows) {
      VM a0, a1;
      TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);

      static_assert(kRegCols == 4);
      C0 = hn::MulAdd(a0, b00, C0);
      C1 = hn::MulAdd(a0, b10, C1);
      C2 = hn::MulAdd(a0, b20, C2);
      C3 = hn::MulAdd(a0, b30, C3);
      C0 = hn::MulAdd(a1, b01, C0);
      C1 = hn::MulAdd(a1, b11, C1);
      C2 = hn::MulAdd(a1, b21, C2);
      C3 = hn::MulAdd(a1, b31, C3);
    }
  }

  // Same as above, only called if MulT == BF16.
  template <size_t kNumRows, class DM, class VM = hn::Vec<DM>,
            HWY_IF_BF16_D(DM), class DF = hn::Repartition<float, DM>,
            class VF = hn::Vec<DF>>
  HWY_INLINE void Next(DM dm, size_t col_ab,  //
                       const VM b00, const VM b01, const VM b10, const VM b11,
                       const VM b20, const VM b21, const VM b30, const VM b31,
                       VF& C0, VF& C1, VF& C2, VF& C3) const {
    static_assert(kNumRows <= kRegRows);  // How many rows are actually present.
    HWY_DASSERT(col_ab >= 2 * hn::Lanes(dm));  // Should not be first iteration.
    if constexpr (kRow < kNumRows) {
      VM a0, a1;
      TraitsA::Decompress2(dm, A_, A_ofs_ + col_ab, a0, a1);

      const DF df;
      hn::Vec<DF> unused_sum1 = hn::Zero(df);
      static_assert(kRegCols == 4);
      C0 = hn::ReorderWidenMulAccumulate(df, a0, b00, C0, unused_sum1);
      C1 = hn::ReorderWidenMulAccumulate(df, a0, b10, C1, unused_sum1);
      C2 = hn::ReorderWidenMulAccumulate(df, a0, b20, C2, unused_sum1);
      C3 = hn::ReorderWidenMulAccumulate(df, a0, b30, C3, unused_sum1);
      C0 = hn::ReorderWidenMulAccumulate(df, a1, b01, C0, unused_sum1);
      C1 = hn::ReorderWidenMulAccumulate(df, a1, b11, C1, unused_sum1);
      C2 = hn::ReorderWidenMulAccumulate(df, a1, b21, C2, unused_sum1);
      C3 = hn::ReorderWidenMulAccumulate(df, a1, b31, C3, unused_sum1);
      // Ensure sum1 was indeed unused.
      HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
    }
  }

 private:
  const MatTA* HWY_RESTRICT A_;
  const size_t A_ofs_;
};  // ALoadAccumulate

// Sets a `kRegRows` x `kRegCols` tile of C to `add[add_ofs + c]` if kAdd,
// otherwise 0.
// `add` has no scale and is a row vector with C.cols entries if `kAdd`,
// otherwise nullptr. In the latter case, adding `add_ofs` to it would be UB,
// hence we pass it as a separate argument.
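//
// In scalar terms, `InitC` seeds each element of the tile as sketched below;
// `AddHorizontalSums` later accumulates `scale` * (dot product) on top:
//
//   pos_c[r * stride_c + c] = kAdd ? add[add_ofs + c] : 0.0f;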
template <size_t kNumRows, bool kAdd>
HWY_INLINE void InitC(const float* HWY_RESTRICT add, size_t add_ofs,
                      float* HWY_RESTRICT pos_c, size_t stride_c) {
  const hn::FixedTag<float, kRegCols> d4;
  for (size_t r = 0; r < HWY_MIN(kNumRows, kRegRows); ++r) {
    if constexpr (kAdd) {
      hn::StoreU(hn::LoadU(d4, add + add_ofs), d4, pos_c + r * stride_c);
    } else {
      hn::StoreU(hn::Zero(d4), d4, pos_c + r * stride_c);
    }
  }
}

// Accumulates into a tile of C.
template <size_t kNumRows>
class AddHorizontalSums {
  // These helper functions hoist if() out of the main code below. They have no
  // effect if kRow >= kNumRows.
  template <size_t kRow, class DF, class VF = hn::Vec<DF>>
  static void MaybeStoreInterleaved4(DF df, size_t N, VF Cr0, VF Cr1, VF Cr2,
                                     VF Cr3, float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      hn::StoreInterleaved4(Cr0, Cr1, Cr2, Cr3, df, buf + 4 * kRow * N);
    }
  }

  // Note: N is the number of lanes in the StoreInterleaved4 vectors, not V4.
  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static V4 MaybeLoad(D4 d4, size_t N, const float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      return hn::Load(d4, buf + 4 * kRow * N);
    } else {
      return hn::Zero(d4);
    }
  }

  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static V4 MaybeAdd(D4 d4, size_t N, V4 sum, const float* HWY_RESTRICT buf) {
    if constexpr (kRow < kNumRows) {
      return hn::Add(sum, hn::Load(d4, buf + 4 * kRow * N));
    } else {
      return sum;
    }
  }

  template <size_t kRow, class D4, class V4 = hn::Vec<D4>>
  static void MaybeMulAdd(D4 d4, V4 sum, V4 scale, float* HWY_RESTRICT tile_c,
                          const size_t stride_c) {
    if constexpr (kRow < kNumRows) {
      const V4 prev_c = hn::LoadU(d4, tile_c + kRow * stride_c);
      hn::StoreU(hn::MulAdd(sum, scale, prev_c), d4, tile_c + kRow * stride_c);
    }
  }

 public:
  // Adds the contribution from `Crc` accumulators to the 4x4 tile of C whose
  // top left is `tile_c`, after multiplying by `scale`, which is the product
  // of the scales of A and B. C is always f32 to ensure sufficient precision.
  //
  // `Crc` are the 16 combinations of an A row vector indexed by `r`, times a
  // B column vector indexed by `c`. Their elements are thus a subset of the
  // terms of the dot product constituting the final `C[r, c]` result. Thus we
  // compute the horizontal sums of each `Crc`. The elements may be permuted
  // because we multiply bf16 via `ReorderWidenMulAccumulate`, but this does
  // not change their horizontal sum. `buf` is thread-local space for 16 `VF`.
  template <class DF, class VF = hn::Vec<DF>>
  HWY_INLINE void operator()(DF df, float scale,          //
                             VF C00, VF C01, VF C02, VF C03,  //
                             VF C10, VF C11, VF C12, VF C13,  //
                             VF C20, VF C21, VF C22, VF C23,  //
                             VF C30, VF C31, VF C32, VF C33,  //
                             float* HWY_RESTRICT buf,
                             float* HWY_RESTRICT tile_c,
                             size_t stride_c) const {
    const size_t N = hn::Lanes(df);
    // Horizontal reductions (`ReduceSum`) are rather expensive, entailing
    // log(N) operations for vectors of length N. Because kRegCols == 4, we can
    // instead use `StoreInterleaved4` for a vector-length agnostic
    // 'transpose': `buf[0, 4 * N)` holds C00[0], C01[0], C02[0], C03[0],
    // C00[1], C01[1], C02[1], C03[1] .. C00[N-1], C01[N-1], C02[N-1],
    // C03[N-1].
    MaybeStoreInterleaved4<0>(df, N, C00, C01, C02, C03, buf);
    MaybeStoreInterleaved4<1>(df, N, C10, C11, C12, C13, buf);
    MaybeStoreInterleaved4<2>(df, N, C20, C21, C22, C23, buf);
    MaybeStoreInterleaved4<3>(df, N, C30, C31, C32, C33, buf);
    // Adding N consecutive V4 yields four horizontal sums of Cr0, Cr1, Cr2 and
    // Cr3 in the elements of one V4. We have four independent rows `r`, hence
    // the code is effectively unrolled, which increases throughput.
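    // Worked example (hypothetical N = 4, i.e. 128-bit f32 vectors): row r of
    // the tile occupies buf[16*r, 16*r + 16), laid out as
    //   {Cr0[0], Cr1[0], Cr2[0], Cr3[0], Cr0[1], Cr1[1], Cr2[1], Cr3[1], ...}.
    // Summing the N groups of four, buf[16*r + 4*i, 16*r + 4*i + 4) for
    // i = 0..N-1, yields a single V4 {sum(Cr0), sum(Cr1), sum(Cr2), sum(Cr3)},
    // i.e. the four horizontal sums for that row.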
    const hn::FixedTag<float, kRegCols> d4;
    using V4 = hn::Vec<decltype(d4)>;
    V4 sum0 = MaybeLoad<0>(d4, N, buf);
    V4 sum1 = MaybeLoad<1>(d4, N, buf);
    V4 sum2 = MaybeLoad<2>(d4, N, buf);
    V4 sum3 = MaybeLoad<3>(d4, N, buf);
    for (size_t i = 1; i < N; ++i) {
      sum0 = MaybeAdd<0>(d4, N, sum0, buf + 4 * i);
      sum1 = MaybeAdd<1>(d4, N, sum1, buf + 4 * i);
      sum2 = MaybeAdd<2>(d4, N, sum2, buf + 4 * i);
      sum3 = MaybeAdd<3>(d4, N, sum3, buf + 4 * i);
    }
    // Scale, then store to four elements per row of `tile_c`.
    const V4 vscale = hn::Set(d4, scale);
    MaybeMulAdd<0>(d4, sum0, vscale, tile_c, stride_c);
    MaybeMulAdd<1>(d4, sum1, vscale, tile_c, stride_c);
    MaybeMulAdd<2>(d4, sum2, vscale, tile_c, stride_c);
    MaybeMulAdd<3>(d4, sum3, vscale, tile_c, stride_c);
  }
};

// Streams a `kNumRows`-high strip of `A` and the transposed `B`, then writes a
// *finished* tile of f32 `C` whose top left is (row_ac, row_b_col_c).
// TODO: loop over sections instead of full rows and accumulate into `tile_c`.
template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
                           const size_t row_ac, const size_t row_b_col_c,
                           const float scale, const float* HWY_RESTRICT add,
                           float* HWY_RESTRICT buf, const Mat<float>& C) {
  // For 'decompressing' A and B into BF16 or float.
  const hn::ScalableTag<MulT> dm;
  using VM = hn::Vec<decltype(dm)>;
  const size_t NM = hn::Lanes(dm);

  static_assert(kRegRows == 4);
  const BRow<0, MatTB> b_row0(B, row_b_col_c);
  const BRow<1, MatTB> b_row1(B, row_b_col_c);
  const BRow<2, MatTB> b_row2(B, row_b_col_c);
  const BRow<3, MatTB> b_row3(B, row_b_col_c);

  const ALoadAccumulate<0, MatTA> a_row0(A, row_ac);
  const ALoadAccumulate<1, MatTA> a_row1(A, row_ac);
  const ALoadAccumulate<2, MatTA> a_row2(A, row_ac);
  const ALoadAccumulate<3, MatTA> a_row3(A, row_ac);

  const hn::Repartition<float, decltype(dm)> df;
  using VF = hn::Vec<decltype(df)>;
  VF C00, C01, C02, C03;
  VF C10, C11, C12, C13;
  VF C20, C21, C22, C23;
  VF C30, C31, C32, C33;

  {  // First iteration initializes the `Crc` vectors.
    VM b00, b01, b10, b11, b20, b21, b30, b31;
    b_row0.Load2(dm, /*col_ab=*/0, b00, b01);
    b_row1.Load2(dm, /*col_ab=*/0, b10, b11);
    b_row2.Load2(dm, /*col_ab=*/0, b20, b21);
    b_row3.Load2(dm, /*col_ab=*/0, b30, b31);

    a_row0.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
                                    C00, C01, C02, C03);
    a_row1.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
                                    C10, C11, C12, C13);
    a_row2.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
                                    C20, C21, C22, C23);
    a_row3.template First<kNumRows>(dm, b00, b01, b10, b11, b20, b21, b30, b31,
                                    C30, C31, C32, C33);
  }

  // `2 * NM` per iteration because `Load2` returns two vectors.
  HWY_UNROLL(1)
  for (size_t col_ab = 2 * NM; col_ab <= A.cols - 2 * NM; col_ab += 2 * NM) {
    VM b00, b01, b10, b11, b20, b21, b30, b31;
    b_row0.Load2(dm, col_ab, b00, b01);
    b_row1.Load2(dm, col_ab, b10, b11);
    b_row2.Load2(dm, col_ab, b20, b21);
    b_row3.Load2(dm, col_ab, b30, b31);

    a_row0.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
                                   b30, b31, C00, C01, C02, C03);
    a_row1.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
                                   b30, b31, C10, C11, C12, C13);
    a_row2.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
                                   b30, b31, C20, C21, C22, C23);
    a_row3.template Next<kNumRows>(dm, col_ab, b00, b01, b10, b11, b20, b21,
                                   b30, b31, C30, C31, C32, C33);
  }

  // TODO: hoist into outer loop.
  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
  InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);

  AddHorizontalSums<kNumRows>()(df, scale, C00, C01, C02, C03, C10, C11, C12,
                                C13, C20, C21, C22, C23, C30, C31, C32, C33,
                                buf, C_tile, C.stride);
}
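// For reference, the scalar equivalent of one `MatMulTile` call (a sketch that
// ignores the decompression of MatTA/MatTB and uses 2D indexing for clarity):
//
//   for (size_t r = 0; r < kNumRows; ++r) {
//     for (size_t c = 0; c < kRegCols; ++c) {
//       float dot = 0.0f;
//       for (size_t k = 0; k < A.cols; ++k) {
//         dot += A[row_ac + r][k] * B[row_b_col_c + c][k];  // B is transposed.
//       }
//       C[row_ac + r][row_b_col_c + c] =
//           (kAdd ? add[row_b_col_c + c] : 0.0f) + scale * dot;
//     }
//   }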
// Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
//
// `A` is a row-major matrix of shape `(batch_size, A.cols)`.
// `B` is transposed; `B.cols`, which must match `A.cols`, denotes the number
// of rows in the original B, and `C.cols` the number of columns in the
// original B.
//
// `scale` allows expanding the smaller range of `SfpStream` to the original
// values. When `A` and/or `B` are from a `CompressedArray`, `scale` should be
// the product of their `.scale()` values, otherwise 1.0f.
//
// If `kAdd` is true, the row-vector `add` is added to each row of `C`,
// otherwise `add` is ignored and can be nullptr. A scale for `add` is not
// supported, so make sure its scale is 1.
//
// `C` is a row-major matrix of size `(batch_size, C.cols)`.
//
// Updates 4x4 tiles of C in parallel using a work-stealing thread pool.
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
// Must not be called concurrently with the same `env`.
template <bool kAdd, typename MatTA, typename MatTB>
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<MatTA>& A,
                         const Mat<MatTB>& B, const float scale,
                         const float* HWY_RESTRICT add, MatMulEnv& env,
                         const Mat<float>& C) {
  // PROFILER_ZONE("Matmul");
  HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
  HWY_DASSERT(A.cols == B.cols);
  // Must be a multiple of two vectors because we Decompress2.
  HWY_DASSERT(A.cols % (2 * hn::Lanes(hn::ScalableTag<MulT>())) == 0);
  HWY_DASSERT(C.cols % kRegCols == 0);

  // We currently write C directly, which touches more memory than fits in L3.
  // TODO: add another level of loops to finish L3-sized pieces of C at a time.
  const size_t tilesY = hwy::DivCeil(batch_size, kRegRows);
  const size_t tilesX = C.cols / kRegCols;

  env.Pool().Run(
      0, tilesX * tilesY, [&](const uint64_t idx_tile, size_t thread) HWY_ATTR {
        // TODO: when using PerClusterPool, compute lp from outer and inner.
        float* HWY_RESTRICT buf = env.Buf(thread);
        const size_t tx = idx_tile % tilesX;
        const size_t ty = idx_tile / tilesX;
        const size_t row_ac = ty * kRegRows;
        const size_t row_b_col_c = tx * kRegCols;
        // How many rows of C are left to compute. If more than 4, this tile
        // still only computes 4 rows.
        const size_t num_rows = batch_size - row_ac;
        HWY_DASSERT(num_rows != 0);
        switch (num_rows) {
          case 1:
            MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
            break;
          case 2:
            MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
            break;
          case 3:
            MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
            break;
          default:
            MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, buf, C);
        }
      });
}

// NOLINTNEXTLINE(google-readability-namespace-comments)
}  // namespace HWY_NAMESPACE
}  // namespace gcpp
HWY_AFTER_NAMESPACE();

#endif  // NOLINT