gemma.cpp/ops/matmul_test.cc

439 lines
18 KiB
C++

// Copyright 2023 Google LLC
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// End to end test of MatMul, comparing against a reference implementation.
#include "hwy/detect_compiler_arch.h"
#ifndef HWY_DISABLED_TARGETS
// Exclude HWY_SCALAR due to 2x bf16 -> f32, and Armv7 NEON because we require
// double-precision support.
#if HWY_ARCH_ARM_V7
#define HWY_DISABLED_TARGETS (HWY_SCALAR | HWY_NEON)
#else
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#endif
#include <stddef.h>
#include <stdio.h>
#include <memory>
#include "compression/compress.h"
#include "compression/shared.h"
#include "ops/matmul.h"
#include "util/allocator.h"
#include "util/basics.h"
#include "util/threading.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "ops/matmul_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
// After highway.h
#include "compression/compress-inl.h"
#include "ops/dot-inl.h"
#include "ops/matmul-inl.h"
#include "hwy/tests/test_util-inl.h"
HWY_BEFORE_NAMESPACE();
namespace gcpp {
// Highway target (an HWY_TARGET value) for which TestTiny actually runs.
// 0 means no target has called TestTiny yet; the first caller records its own
// target so that TestTiny runs only once across all compiled targets.
// Declared here for every per-target pass; defined once within HWY_ONCE.
extern int64_t first_target;
namespace HWY_NAMESPACE {
namespace hn = hwy::HWY_NAMESPACE;

// Owning handle for a heap-allocated matrix, as returned by the generators.
template <class MatT>
using MatStoragePtr = std::unique_ptr<MatStorageT<MatT>>;

// Owning, aligned buffer of f32 values (e.g. uncompressed matrix contents).
using FloatPtr = hwy::AlignedFreeUniquePtr<float[]>;
// Generates a deterministic `extents.rows` x `extents.cols` input matrix.
// The element at row-major index i has magnitude i * step, where
// step = SfpStream::kMax / NumElements() keeps all values within the SfpStream
// range. Elements where (row + col) is odd are negated (checkerboard), so
// negative values are covered as well. The f32 values are compressed into MatT
// via CompressScaled, then the matrix scale is set to 0.6.
template <typename MatT>
MatStoragePtr<MatT> GenerateMat(const Extents2D& extents,
                                hwy::ThreadPool& pool) {
  const size_t num_cols = extents.cols;
  auto result =
      std::make_unique<MatStorageT<MatT>>("mat", extents.rows, num_cols);
  const size_t num_elements = result->NumElements();
  FloatPtr values = hwy::AllocateAligned<float>(num_elements);
  HWY_ASSERT(values);
  const float step = SfpStream::kMax / (num_elements);

  pool.Run(0, extents.rows, [&](const size_t row, size_t /*thread*/) {
    for (size_t col = 0; col < num_cols; ++col) {
      const size_t idx = row * num_cols + col;
      const float magnitude = static_cast<float>(idx) * step;
      // Checkerboard sign pattern: odd (row + col) entries are negative.
      values[idx] = ((row + col) & 1) ? -magnitude : magnitude;
    }
  });

  gcpp::CompressWorkingSet ws;
  CompressScaled(values.get(), num_elements, ws, *result, pool);
  // Arbitrary non-unit scale so that scaling is exercised by MatMul.
  result->set_scale(0.6f);
  return result;
}
// Generates a deterministic `extents.rows` x `extents.cols` matrix whose
// `extents` describe the already-transposed layout (as MatMul expects for B).
// Element (r, c) gets the value GenerateMat would place at (c, r) for the
// swapped extents: magnitude (c * extents.rows + r) * scale, negated where
// (r + c) is odd. The f32 values are compressed into MatT via CompressScaled.
template <typename MatT>
MatStoragePtr<MatT> GenerateTransposedMat(const Extents2D& extents,
                                          hwy::ThreadPool& pool) {
  gcpp::CompressWorkingSet ws;
  auto mat =
      std::make_unique<MatStorageT<MatT>>("trans", extents.rows, extents.cols);
  FloatPtr content = hwy::AllocateAligned<float>(mat->NumElements());
  // Same as GenerateMat: fail loudly rather than writing through null below.
  HWY_ASSERT(content);
  const float scale = SfpStream::kMax / (mat->NumElements());
  pool.Run(0, extents.rows, [&](const size_t r, size_t /*thread*/) {
    for (size_t c = 0; c < extents.cols; c++) {
      // Index as if (c, r) in a row-major extents.cols x extents.rows matrix.
      float f = static_cast<float>(c * extents.rows + r) * scale;
      if ((r + c) & 1) f = -f;  // Also generate some negative values.
      content[r * extents.cols + c] = f;
    }
  });
  CompressScaled(content.get(), mat->NumElements(), ws, *mat, pool);
  // Arbitrary value, different from 1, must match GenerateMat.
  mat->set_scale(0.6f);
  return mat;
}
// Returns the largest per-row sum of absolute values of `a`, i.e. the maximum
// over rows of each row's 1-norm (the matrix infinity-norm). Returns 0 for an
// empty batch. Used for estimating tolerable numerical differences.
double MaxRowAbsSum(const RowVectorBatch<float>& a) {
  const size_t num_cols = a.Cols();
  double largest = 0.0;
  for (size_t r = 0; r < a.BatchSize(); ++r) {
    const float* row = a.Batch(r);
    double sum = 0.0;
    for (size_t c = 0; c < num_cols; ++c) sum += hwy::ScalarAbs(row[c]);
    largest = HWY_MAX(largest, sum);
  }
  return largest;
}
// Returns the maximum absolute value over all elements of `a`, or 0 if `a` has
// no rows or no columns. Iterates row by row so that each row (a contiguous
// array returned by Batch(r)) is scanned sequentially; the row pointer is
// fetched once per row instead of once per element.
float MaxAbs(const RowVectorBatch<float>& a) {
  float max_abs = 0.0f;
  for (size_t r = 0; r < a.BatchSize(); r++) {
    const float* row = a.Batch(r);
    for (size_t c = 0; c < a.Cols(); c++) {
      max_abs = HWY_MAX(max_abs, hwy::ScalarAbs(row[c]));
    }
  }
  return max_abs;
}
// Checks that `C` (computed by MatMul) matches `C_slow` (computed by
// MatMulSlow) for A (M x K) times B, where B is already transposed (N x K), so
// both C matrices are M x N. All operands are first decompressed to f32.
//
// An element passes if it lies within an absolute `tolerance` derived from the
// input norms, or otherwise if its relative error is within the precision of
// TC. On the first failing element, aborts and reports `line`, which callers
// pass as their own __LINE__.
template <typename TA, typename TB, typename TC>
void AssertClose(const ConstMat<TA>& A, const ConstMat<TB>& B,
                 const RowPtr<TC>& C_slow, const RowPtr<TC>& C, int line) {
  const hn::ScalableTag<float> df;
  const size_t cols = A.extents.cols;
  const size_t B_rows = B.extents.rows;
  // Round up for DecompressAndZeroPad.
  RowVectorBatch<float> a_batch = AllocateAlignedRows<float>(A.extents);
  RowVectorBatch<float> b_trans_batch = AllocateAlignedRows<float>(B.extents);
  RowVectorBatch<float> c_batch =
      AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
  RowVectorBatch<float> c_slow_batch =
      AllocateAlignedRows<float>(Extents2D(A.extents.rows, B_rows));
  // Rows are addressed as `ptr + Row(i)` below, which requires zero offsets.
  HWY_ASSERT(A.ofs == 0 && B.ofs == 0);
  for (size_t m = 0; m < A.extents.rows; ++m) {
    DecompressAndZeroPad(df, MakeSpan(A.ptr + A.Row(m), cols), 0,
                         a_batch.Batch(m), cols);
    DecompressAndZeroPad(df, MakeSpan(C.Row(m), B_rows), 0, c_batch.Batch(m),
                         B_rows);
    DecompressAndZeroPad(df, MakeSpan(C_slow.Row(m), B_rows), 0,
                         c_slow_batch.Batch(m), B_rows);
  }
  for (size_t n = 0; n < B_rows; ++n) {
    DecompressAndZeroPad(df, MakeSpan(B.ptr + B.Row(n), cols), 0,
                         b_trans_batch.Batch(n), cols);
  }

  // MatMul rounds inputs to BF16, so error is proportional to the max input
  // magnitude, but also to f32 accumulation of rows in A and B.
  const double norm = MaxRowAbsSum(a_batch) * MaxRowAbsSum(b_trans_batch);
  const float max_abs = MaxAbs(a_batch) * MaxAbs(b_trans_batch);
  const double eps_bf16 = hwy::ConvertScalarTo<double>(hwy::Epsilon<BF16>());
  const double eps_f32 = hwy::ConvertScalarTo<double>(hwy::Epsilon<float>());
  double tolerance = 12 * norm * eps_f32;
  // Dot() also rounds F32,BF16 to BF16, but not with F32,F32, so increase the
  // tolerance there.
  if (IsF32<TA>() && IsF32<TB>()) {
    tolerance += 4 * max_abs * eps_bf16;
  }
  if (tolerance > 500.0) {
    HWY_WARN("high tolerance %f norm %f maxabs %f\n", tolerance, norm, max_abs);
  }
  const double max_rel = 1.0 + hwy::ConvertScalarTo<double>(hwy::Epsilon<TC>());
  for (size_t r = 0; r < A.extents.rows; r++) {
    const float* expected_row = c_slow_batch.Batch(r);
    const float* actual_row = c_batch.Batch(r);
    for (size_t c = 0; c < B_rows; c++) {
      const double expected_value = static_cast<double>(expected_row[c]);
      const double actual_value = static_cast<double>(actual_row[c]);
      const bool in_range = expected_value - tolerance <= actual_value &&
                            actual_value <= expected_value + tolerance;
      if (!in_range) {
        // Compare magnitudes rather than signed values: a signed max/min
        // ratio is hugely negative when both values are negative, never
        // exceeds max_rel, and so would silently accept mismatches. Sign
        // mismatches also yield a large relative error here.
        const double abs_diff = hwy::ScalarAbs(actual_value - expected_value);
        const double min_magnitude = HWY_MIN(hwy::ScalarAbs(expected_value),
                                             hwy::ScalarAbs(actual_value));
        // For two positive values this equals max / min.
        const double rel = 1.0 + abs_diff / HWY_MAX(min_magnitude, 1E-6);
        if (rel > max_rel) {
          hwy::Abort(__FILE__, line,
                     "(%zu,%zu): expected %f, actual %f, norm %f maxabs %f "
                     "tolerance %f rel %E max_rel %E\n",
                     r, c, expected_value, actual_value, norm, max_abs,
                     tolerance, rel, max_rel);
        }
      }
    }
  }
}
// Reference (slow) implementation: for every row r of A and every row c of B
// (B is already transposed, so its rows are the columns of C), writes
// C[r][c] = add_row[c] + A.scale * B.scale * Dot(B row c, A row r), converted
// to TC. If `add_row` is null, nothing is added.
// Rows of C are statically partitioned across packages; within each package,
// columns of C are partitioned across clusters in multiples of
// Allocator::QuantumBytes() / sizeof(TB).
template <typename TA, typename TB, typename TC>
HWY_INLINE void MatMulSlow(const ConstMat<TA> A, const ConstMat<TB> B,
                           const float* HWY_RESTRICT add_row, MatMulEnv& env,
                           const RowPtr<TC>& C) {
  // A rows are passed to Dot as a plain pointer (`A.ptr + A.Row(r)`), because
  // they are Dot's second argument, which does not support a v_ofs. Pointer
  // arithmetic on TA must therefore yield element addresses, so TA is limited
  // to BF16 or f32 (enforced by this static_assert); compressed types such as
  // SfpStream or NuqStream are only supported for B.
  static_assert(sizeof(TA) >= sizeof(BF16), "A matrix must be BF16/f32");
  const float scale = A.scale * B.scale;
  const hn::ScalableTag<float> df;  // lane type is ignored
  // B is accessed via a span (with an element offset per row) instead.
  const PackedSpan<const TB> b_span =
      MakeSpan(B.ptr, B.ofs + B.Stride() * B.Extents().rows);
  const IndexRange all_rows_c(0, A.Extents().rows);
  const IndexRange all_cols_c(0, C.Cols());

  NestedPools& pools = env.parallel.Pools();
  hwy::ThreadPool& all_packages = pools.AllPackages();
  // One range of C rows per package.
  const IndexRangePartition get_row_c =
      StaticPartition(all_rows_c, all_packages.NumWorkers(), 1);
  ParallelizeOneRange(
      get_row_c, all_packages,
      [&](const IndexRange& rows_c, size_t package_idx) HWY_ATTR {
        hwy::ThreadPool& all_clusters = pools.AllClusters(package_idx);
        // Column ranges are multiples of this many elements.
        const size_t multiple = Allocator::QuantumBytes() / sizeof(TB);
        const IndexRangePartition get_col_c =
            StaticPartition(all_cols_c, all_clusters.NumWorkers(), multiple);
        ParallelizeOneRange(
            get_col_c, all_clusters,
            [&](const IndexRange& cols_c, size_t cluster_idx) HWY_ATTR {
              for (size_t r : rows_c) {
                TC* HWY_RESTRICT C_row = C.Row(r);
                for (size_t c : cols_c) {
                  const float add = add_row ? add_row[c] : 0.0f;
                  // Row c of (transposed) B starts at element c * Stride().
                  C_row[c] = hwy::ConvertScalarTo<TC>(
                      add + scale * Dot(df, b_span, c * B.Stride(),
                                        A.ptr + A.Row(r), A.extents.cols));
                }
              }
            });
      });
}
// Prints to stderr how long `algo` took and its throughput in GFLOPS for a
// product of an A_extents.rows-row A with a B of B_extents.Area() elements.
void PrintSpeed(const char* algo, const Extents2D& A_extents,
                const Extents2D& B_extents, double elapsed) {
  const size_t b_elements = B_extents.Area();
  // Each multiply-add (FMA) counts as 2 floating-point operations.
  const double gflops = 2 * 1E-9 * A_extents.rows * b_elements / elapsed;
  fprintf(stderr, " %10s: %f seconds, %.1f GFLOPS.\n", algo, elapsed, gflops);
}
template <typename TA, typename TB = TA, typename TC = float>
void TestMatMul(size_t rows_ac, size_t cols_a_rows_b, size_t cols_bc, bool add,
MatMulEnv& env, int line) {
hwy::ThreadPool& pool = env.parallel.Pools().Pool();
fprintf(stderr, "TestMatMul %zu, K=%zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",
rows_ac, cols_a_rows_b, cols_bc, add, TypeName<TA>(), TypeName<TB>(),
TypeName<TC>());
env.print_config = false; // Too verbose.
env.print_best = true;
const Extents2D A_extents(rows_ac, cols_a_rows_b);
const Extents2D B_extents(cols_bc, cols_a_rows_b); // already transposed
const Extents2D C_extents(rows_ac, cols_bc);
MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
HWY_ASSERT(a && b_trans);
std::unique_ptr<MatStorageT<float>> add_storage;
if (add) {
add_storage = GenerateMat<float>(Extents2D(1, cols_bc), pool);
HWY_ASSERT(add_storage);
add_storage->set_scale(1.0f);
}
const auto A = ConstMatFromWeights(*a);
const auto B = ConstMatFromWeights(*b_trans);
const float* add_row = add ? add_storage->data_scale1() : nullptr;
const RowPtr<TC> C_slow = RowPtrFromBatch(c_slow_batch);
const RowPtr<TC> C = RowPtrFromBatch(c_batch);
MatMulSlow(A, B, add_row, env, C_slow);
// A few reps to get coverage of the various autotuned code paths.
for (size_t rep = 0; rep < 16; ++rep) {
MMPerKey* per_key = MatMul(A, B, add_row, env, C);
AssertClose(A, B, C_slow, C, line);
if (per_key->autotune.Best()) break;
}
}
// Short element-type aliases used by the test instantiations.
using SFP = SfpStream;
using F32 = float;
// Sweep all dimensions for a single input type (f32 A, B and C) and a single
// Highway target, to verify the remainder handling: M in [1, 12], K in powers
// of two up to 64, and N from 4 to 64.
void TestTiny() {
  // Only the first target to reach this point runs the sweep.
  if (first_target == 0) first_target = HWY_TARGET;
  if (HWY_TARGET != first_target) return;
  // Repeat with the topology limited to 1 and then 2 packages.
  for (size_t max_packages : {1, 2}) {
    const BoundedTopology topology(BoundedSlice(0, max_packages));
    Allocator::Init(topology, /*enable_bind=*/true);
    const size_t max_threads = 0;  // no limit
    NestedPools pools(topology, max_threads, Tristate::kDefault);
#if GEMMA_DISABLE_TOPOLOGY
    if (max_packages == 2) break;  // we only have one package
#else
    // If less than the limit, we have already tested all num_packages.
    if (topology.FullTopology().packages.size() < max_packages) break;
#endif
    fprintf(stderr, "TestTiny %zu: %s %s\n", max_packages,
            topology.TopologyString(), pools.PinString());
    Tristate use_spinning = Tristate::kDefault;
    pools.MaybeStartSpinning(use_spinning);
    MatMulEnv env(topology, pools);
    for (size_t M = 1; M <= 12; ++M) {
      for (size_t K = 1; K <= 64; K *= 2) {
        // N advances by 4 per package; presumably so that N can be split
        // among packages -- TODO confirm.
        for (size_t N = 4; N <= 64; N += max_packages * 4) {
          TestMatMul<F32, F32, F32>(M, K, N, /*add=*/false, env, __LINE__);
        }
      }
    }
    pools.MaybeStopSpinning(use_spinning);
  }
}
// Runs MatMul on realistic and small non-square shapes for every combination
// of f32/BF16 (and SFP for B) inputs and f32/BF16 outputs, with and without
// the added row. All calls share one MatMulEnv.
void TestAllMatMul() {
  // Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
  if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
      HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
    return;
  }

  const BoundedTopology topology;
  Allocator::Init(topology, /*enable_bind=*/true);
  NestedPools pools(topology);
  Tristate use_spinning = Tristate::kDefault;
  pools.MaybeStartSpinning(use_spinning);

  MatMulEnv env(topology, pools);

  // Sizes seen in gemma_test 2B. Only the first is enabled; the others are too
  // slow for CI, enable on-demand.
  TestMatMul<F32>(1, 2048, 512, /*add=*/false, env, __LINE__);
  // TestMatMul<F32>(1, 2048, 2048, /*add=*/false, env, __LINE__);
  // TestMatMul<F32>(1, 2048, 16384, /*add=*/false, env, __LINE__);
  // TestMatMul<F32>(1, 16384, 2048, /*add=*/false, env, __LINE__);
  // TestMatMul<F32>(1, 2048, 256000, /*add=*/false, env, __LINE__);
  // TestMatMul<F32>(5, 2048, 512, /*add=*/false, env, __LINE__);
  // TestMatMul<F32>(5, 2048, 2048, /*add=*/false, env, __LINE__);
  // TestMatMul<F32>(5, 2048, 16384, /*add=*/false, env, __LINE__);
  // TestMatMul<F32>(5, 16384, 2048, /*add=*/false, env, __LINE__);

  // medium-sized square, f32 vs bf16 for A, B, C; plus add.
  TestMatMul<F32, F32, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<F32, F32, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<F32, BF16, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<F32, BF16, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, F32, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, F32, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, BF16, F32>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, BF16, BF16>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<F32, F32, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
  TestMatMul<F32, F32, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
  TestMatMul<F32, BF16, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
  TestMatMul<F32, BF16, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
  TestMatMul<BF16, F32, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
  TestMatMul<BF16, F32, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);
  TestMatMul<BF16, BF16, F32>(256, 256, 256, /*add=*/true, env, __LINE__);
  TestMatMul<BF16, BF16, BF16>(256, 256, 256, /*add=*/true, env, __LINE__);

  TestMatMul<F32, SFP>(256, 256, 256, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, SFP>(256, 256, 256, /*add=*/true, env, __LINE__);

  // Minimal non-square tests. K (cols_a_rows_b) must be at least 2 vectors.
  TestMatMul<F32>(35, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16>(34, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32, BF16>(33, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, F32>(33, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32, SFP>(31, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, SFP>(29, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32>(4, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<BF16>(4, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<F32, BF16>(4, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<BF16, F32>(4, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<F32, SFP>(4, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<BF16, SFP>(4, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<F32>(3, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16>(3, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32, BF16>(3, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, F32>(3, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32, SFP>(3, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, SFP>(3, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32>(2, 128, 64, /*add=*/true, env, __LINE__);
  TestMatMul<BF16>(2, 128, 64, /*add=*/false, env, __LINE__);
  TestMatMul<F32, BF16>(2, 128, 64, /*add=*/true, env, __LINE__);
  TestMatMul<BF16, F32>(2, 128, 64, /*add=*/false, env, __LINE__);
  TestMatMul<F32, SFP>(2, 128, 64, /*add=*/true, env, __LINE__);
  TestMatMul<BF16, SFP>(2, 128, 64, /*add=*/false, env, __LINE__);
  TestMatMul<F32>(1, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16>(1, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32, BF16>(1, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, F32>(1, 128, 32, /*add=*/true, env, __LINE__);
  TestMatMul<F32, SFP>(1, 128, 32, /*add=*/false, env, __LINE__);
  TestMatMul<BF16, SFP>(1, 128, 32, /*add=*/true, env, __LINE__);
}
// NOLINTNEXTLINE(google-readability-namespace-comments)
} // namespace HWY_NAMESPACE
} // namespace gcpp
HWY_AFTER_NAMESPACE();
// Emitted only once, not per target: this file is re-included for each
// enabled target via HWY_TARGET_INCLUDE and foreach_target.h.
#if HWY_ONCE
namespace gcpp {
// Target on which TestTiny runs; 0 until the first target calls TestTiny.
int64_t first_target = 0;  // none run yet
// Registers TestTiny and TestAllMatMul with the Highway test harness
// (hwy/tests/test_util-inl.h) under the MatMulTest suite.
HWY_BEFORE_TEST(MatMulTest);
HWY_EXPORT_AND_TEST_P(MatMulTest, TestTiny);
HWY_EXPORT_AND_TEST_P(MatMulTest, TestAllMatMul);
HWY_AFTER_TEST();
}  // namespace gcpp
#endif