mirror of https://github.com/google/gemma.cpp.git
Fix DASSERT failure: TiledBatch requires kK to be at least 2 vectors.
Also use shorthand aliases (F32, BF16, SFP) for weight types.
PiperOrigin-RevId: 643958371
This commit is contained in:
parent
7dbfa44794
commit
ad790d89d1
|
|
@@ -13,13 +13,14 @@
|
|||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
|
||||
#ifndef HWY_DISABLED_TARGETS
|
||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||
#endif
|
||||
|
||||
#include <stddef.h>
|
||||
#include <stdio.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
|
|
@@ -539,48 +540,48 @@ void TestTiledBatchMatMul() {
|
|||
}
|
||||
|
||||
void TestAllTiledBatchMatMul() {
|
||||
using BF16 = hwy::bfloat16_t;
|
||||
using F32 = float;
|
||||
using SFP = SfpStream;
|
||||
// medium-sized square test
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, SfpStream>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t,
|
||||
SfpStream>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>();
|
||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>();
|
||||
|
||||
// minimal non-square test
|
||||
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, float>();
|
||||
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, float, SfpStream>();
|
||||
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, hwy::bfloat16_t,
|
||||
SfpStream>();
|
||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float>();
|
||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, SfpStream>();
|
||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, SfpStream>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, SfpStream>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
|
||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float>();
|
||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, SfpStream>();
|
||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t,
|
||||
SfpStream>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, SfpStream>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
|
||||
// minimal non-square test. kK must be at least 2 vectors.
|
||||
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>();
|
||||
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>();
|
||||
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>();
|
||||
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>();
|
||||
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>();
|
||||
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>();
|
||||
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>();
|
||||
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>();
|
||||
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>();
|
||||
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>();
|
||||
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>();
|
||||
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>();
|
||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>();
|
||||
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>();
|
||||
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>();
|
||||
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>();
|
||||
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>();
|
||||
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>();
|
||||
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>();
|
||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>();
|
||||
|
||||
// large-scale test
|
||||
// TODO(philculliton): investigate rounding issues with large matrices.
|
||||
|
|
|
|||
Loading…
Reference in New Issue