From ad790d89d1fb5b2221c8cd2faf69bea9629c20e0 Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Mon, 17 Jun 2024 04:27:22 -0700 Subject: [PATCH] Fix DASSERT - TiledBatch requires at least 2 vectors. Also use shorthand for weight types. PiperOrigin-RevId: 643958371 --- gemma/ops_test.cc | 83 ++++++++++++++++++++++++----------------------- 1 file changed, 42 insertions(+), 41 deletions(-) diff --git a/gemma/ops_test.cc b/gemma/ops_test.cc index b6b3405..5641efa 100644 --- a/gemma/ops_test.cc +++ b/gemma/ops_test.cc @@ -13,13 +13,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include #include + #ifndef HWY_DISABLED_TARGETS #define HWY_DISABLED_TARGETS HWY_SCALAR #endif #include +#include #include #include @@ -539,48 +540,48 @@ void TestTiledBatchMatMul() { } void TestAllTiledBatchMatMul() { + using BF16 = hwy::bfloat16_t; + using F32 = float; + using SFP = SfpStream; // medium-sized square test - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, hwy::bfloat16_t>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, float>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, SfpStream>(); - TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, - SfpStream>(); + TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>(); + TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>(); + TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>(); + TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>(); + TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>(); + TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>(); - // minimal non-square test - TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, float>(); - TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>(); - TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>(); - TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>(); - TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, float, SfpStream>(); - TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, - SfpStream>(); - TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float>(); - TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t>(); - TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, hwy::bfloat16_t>(); - TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, float>(); - TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, SfpStream>(); - TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, SfpStream>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, SfpStream>(); - TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>(); - TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float>(); - TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t>(); - TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, hwy::bfloat16_t>(); - TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, float>(); - TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, SfpStream>(); - TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, - SfpStream>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, SfpStream>(); - TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>(); + // minimal non-square test. kK must be at least 2 vectors. + TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>(); + TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>(); + TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(); + TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(); + TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(); + TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(); + TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>(); + TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>(); + TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(); + TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(); + TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(); + TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(); + TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>(); + TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>(); + TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(); + TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(); + TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(); + TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(); + TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>(); + TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>(); + TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(); + TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(); + TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(); + TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(); + TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>(); + TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>(); + TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(); + TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(); + TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(); + TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(); // large-scale test // TODO(philculliton): investigate rounding issues with large matrices.