Fix DASSERT - TiledBatch requires at least 2 vectors.

Also use shorthand for weight types.

PiperOrigin-RevId: 643958371
This commit is contained in:
Jan Wassenberg 2024-06-17 04:27:22 -07:00 committed by Copybara-Service
parent 7dbfa44794
commit ad790d89d1
1 changed file with 42 additions and 41 deletions

View File

@ -13,13 +13,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cstdio>
#include <memory>
#ifndef HWY_DISABLED_TARGETS
#define HWY_DISABLED_TARGETS HWY_SCALAR
#endif
#include <stddef.h>
#include <stdio.h>
#include <algorithm>
#include <array>
@ -539,48 +540,48 @@ void TestTiledBatchMatMul() {
}
void TestAllTiledBatchMatMul() {
using BF16 = hwy::bfloat16_t;
using F32 = float;
using SFP = SfpStream;
// medium-sized square test
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, SfpStream>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t,
SfpStream>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>();
// minimal non-square test
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, float>();
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, float, SfpStream>();
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, hwy::bfloat16_t,
SfpStream>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, SfpStream>();
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, SfpStream>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, SfpStream>();
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t,
SfpStream>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, SfpStream>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
// minimal non-square test. kK must be at least 2 vectors.
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>();
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>();
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>();
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>();
// large-scale test
// TODO(philculliton): investigate rounding issues with large matrices.