Fix DASSERT - TiledBatch requires at least 2 vectors.

Also use shorthand for weight types.

PiperOrigin-RevId: 643958371
This commit is contained in:
Jan Wassenberg 2024-06-17 04:27:22 -07:00 committed by Copybara-Service
parent 7dbfa44794
commit ad790d89d1
1 changed files with 42 additions and 41 deletions

View File

@ -13,13 +13,14 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <cstdio>
#include <memory> #include <memory>
#ifndef HWY_DISABLED_TARGETS #ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR #define HWY_DISABLED_TARGETS HWY_SCALAR
#endif #endif
#include <stddef.h> #include <stddef.h>
#include <stdio.h>
#include <algorithm> #include <algorithm>
#include <array> #include <array>
@ -539,48 +540,48 @@ void TestTiledBatchMatMul() {
} }
void TestAllTiledBatchMatMul() { void TestAllTiledBatchMatMul() {
using BF16 = hwy::bfloat16_t;
using F32 = float;
using SFP = SfpStream;
// medium-sized square test // medium-sized square test
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float>(); TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t>(); TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, hwy::bfloat16_t>(); TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, float>(); TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, SfpStream>(); TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>();
SfpStream>();
// minimal non-square test // minimal non-square test. kK must be at least 2 vectors.
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, float>(); TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>(); TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>(); TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>(); TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, float, SfpStream>(); TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>();
SfpStream>(); TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float>(); TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t>(); TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, hwy::bfloat16_t>(); TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, float>(); TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, SfpStream>(); TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, SfpStream>(); TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float>(); TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>(); TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>(); TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>(); TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, SfpStream>(); TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>(); TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float>(); TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t>(); TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, hwy::bfloat16_t>(); TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, float>(); TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, SfpStream>(); TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>();
SfpStream>(); TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float>(); TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>(); TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>(); TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>(); TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, SfpStream>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
// large-scale test // large-scale test
// TODO(philculliton): investigate rounding issues with large matrices. // TODO(philculliton): investigate rounding issues with large matrices.