From 28ca001d5e08cc82be9e4c9271f85e1b91cae139 Mon Sep 17 00:00:00 2001 From: Phil Culliton Date: Fri, 3 May 2024 06:35:57 -0700 Subject: [PATCH] Matmul and test functions PiperOrigin-RevId: 630373984 --- gemma/ops.h | 15 ++++++++ gemma/ops_test.cc | 87 +++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/gemma/ops.h b/gemma/ops.h index df36be5..a520056 100644 --- a/gemma/ops.h +++ b/gemma/ops.h @@ -93,6 +93,21 @@ HWY_INLINE constexpr size_t RowsPerStrip() { return kRowsPerStrip; } +// Largely unoptimized; reordered innermost loops nets ~5-10X speedup on +// ops_test across instruction sets. +template +HWY_INLINE void MatMul(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b, + float* HWY_RESTRICT out) { + int i, j, k; + for (i = 0; i < kM; ++i) { + for (k = 0; k < kN; ++k) { + for (j = 0; j < kK; ++j) { + out[i * kK + j] += a[i * kN + k] * b[k * kK + j]; + } + } + } +} + HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned, const size_t size, float* HWY_RESTRICT out) { const hn::ScalableTag df; diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index 973d598..75a09d0 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -17,6 +17,8 @@ #define HWY_DISABLED_TARGETS HWY_SCALAR #endif +#include + #include #include #include @@ -376,6 +378,25 @@ CompressedArray GenerateMat(size_t offset) { return mat; } +template +CompressedArray GenerateZeroMat(size_t offset) { + hwy::ThreadPool pool(static_cast(std::clamp( + static_cast(std::thread::hardware_concurrency()) - 2, 1, 4))); + gcpp::CompressWorkingSet ws; + CompressedArray mat; + std::array content; + + pool.Run(0, kOuter, [&](const size_t i, size_t thread) { + for (size_t j = 0; j < kInner; j++) { + content[i * kInner + j] = 0.0f; + } + }); + + Compress(content, ws, mat, pool); + mat.set_scale(1.0f); + return mat; +} + template hwy::AlignedFreeUniquePtr GenerateVec(size_t offset) { hwy::AlignedFreeUniquePtr vec = hwy::AllocateAligned(length); @@ -386,6 +407,25 @@ hwy::AlignedFreeUniquePtr GenerateVec(size_t offset) { return vec; } +// A simple matrix multiplication. No optimization / tiling. +template +hwy::AlignedFreeUniquePtr SimpleMatMul( + const hwy::AlignedFreeUniquePtr& a, + const hwy::AlignedFreeUniquePtr& b) { + hwy::AlignedFreeUniquePtr out = hwy::AllocateAligned(kM * kK); + hwy::ZeroBytes(out.get(), kM * kK * sizeof(float)); + + int i, j, k; + for (i = 0; i < kM; ++i) { + for (j = 0; j < kK; ++j) { + for (k = 0; k < kN; ++k) { + out[i * kK + j] += a[i * kN + k] * b[k * kK + j]; + } + } + } + return out; +} + template hwy::AlignedFreeUniquePtr SimpleMatVecAdd( const CompressedArray& mat, @@ -417,6 +457,52 @@ void AssertClose(const hwy::AlignedFreeUniquePtr& a, } } +template +void AssertClose(const hwy::AlignedFreeUniquePtr& expected, + const hwy::AlignedFreeUniquePtr& actual, size_t num) { + for (size_t idx = 0; idx < num; idx++) { + double expected_value = hwy::ConvertScalarTo(expected[idx]); + double actual_value = hwy::ConvertScalarTo(actual[idx]); + + const double tolerance = + expected_value * 20 * 1.0 / (1ULL << hwy::MantissaBits()); + + if (!(expected_value - tolerance <= actual_value && + actual_value <= expected_value + tolerance)) { + fprintf(stderr, "expected[%lu]: %f, actual[%lu]: %f\n", idx, + expected_value, idx, actual_value); + HWY_ASSERT(0); + } + } +} + +void TestMatMul() { + hwy::ThreadPool pool(0); + constexpr size_t kM = 128 * 3; // 384 + constexpr size_t kK = 128 * 5; // 640 + constexpr size_t kN = 128 * 6; // 768 + + CompressedArray a1 = GenerateMat(0); + CompressedArray b1 = GenerateMat(0); + + hwy::AlignedFreeUniquePtr a = hwy::AllocateAligned(kM * kN); + Decompress(a1, 0, a.get(), kM * kN); + + hwy::AlignedFreeUniquePtr b = hwy::AllocateAligned(kN * kK); + Decompress(b1, 0, b.get(), kN * kK); + + hwy::AlignedFreeUniquePtr expected_out1 = + SimpleMatMul(a, b); + + CompressedArray compressed_c = GenerateZeroMat(0); + hwy::AlignedFreeUniquePtr c = hwy::AllocateAligned(kM * kK); + Decompress(compressed_c, 0, c.get(), kM * kK); + + MatMul(a.get(), b.get(), c.get()); + + AssertClose(expected_out1, c, kM * kK); +} + void TestMatVecAdd() { hwy::ThreadPool pool(0); constexpr size_t kOuter = 128 * 3; @@ -518,6 +604,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax); HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution); +HWY_EXPORT_AND_TEST_P(OpsTest, TestMatMul); HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd); HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd); HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoOfsMatVecAddLoop);