mirror of https://github.com/google/gemma.cpp.git
Fix DASSERT: TiledBatch requires kK to be at least 2 vectors.
Also use shorthand aliases (F32, BF16, SFP) for weight types. PiperOrigin-RevId: 643958371
This commit is contained in:
parent
7dbfa44794
commit
ad790d89d1
|
|
@ -13,13 +13,14 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
#include <cstdio>
|
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
#include <stdio.h>
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <array>
|
#include <array>
|
||||||
|
|
@ -539,48 +540,48 @@ void TestTiledBatchMatMul() {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestAllTiledBatchMatMul() {
|
void TestAllTiledBatchMatMul() {
|
||||||
|
using BF16 = hwy::bfloat16_t;
|
||||||
|
using F32 = float;
|
||||||
|
using SFP = SfpStream;
|
||||||
// medium-sized square test
|
// medium-sized square test
|
||||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float>();
|
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32>();
|
||||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16>();
|
||||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, BF16>();
|
||||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t, float>();
|
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, F32>();
|
||||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, float, SfpStream>();
|
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/false, F32, SFP>();
|
||||||
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, hwy::bfloat16_t,
|
TestTiledBatchMatMul<512, 512, 512, /*kAdd=*/true, BF16, SFP>();
|
||||||
SfpStream>();
|
|
||||||
|
|
||||||
// minimal non-square test
|
// minimal non-square test. kK must be at least 2 vectors.
|
||||||
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, float>();
|
TestTiledBatchMatMul<35, 128, 32, /*kAdd=*/false, F32>();
|
||||||
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<34, 128, 32, /*kAdd=*/true, BF16>();
|
||||||
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>();
|
||||||
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
|
TestTiledBatchMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>();
|
||||||
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, float, SfpStream>();
|
TestTiledBatchMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>();
|
||||||
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, hwy::bfloat16_t,
|
TestTiledBatchMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>();
|
||||||
SfpStream>();
|
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32>();
|
||||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float>();
|
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16>();
|
||||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>();
|
||||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>();
|
||||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, float>();
|
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>();
|
||||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/true, float, SfpStream>();
|
TestTiledBatchMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>();
|
||||||
TestTiledBatchMatMul<4, 128, 8, /*kAdd=*/false, hwy::bfloat16_t, SfpStream>();
|
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32>();
|
||||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float>();
|
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16>();
|
||||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>();
|
||||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>();
|
||||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
|
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>();
|
||||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/false, float, SfpStream>();
|
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>();
|
||||||
TestTiledBatchMatMul<3, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
|
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32>();
|
||||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float>();
|
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16>();
|
||||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>();
|
||||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>();
|
||||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t, float>();
|
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>();
|
||||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/true, float, SfpStream>();
|
TestTiledBatchMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>();
|
||||||
TestTiledBatchMatMul<2, 128, 16, /*kAdd=*/false, hwy::bfloat16_t,
|
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32>();
|
||||||
SfpStream>();
|
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16>();
|
||||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float>();
|
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>();
|
||||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>();
|
||||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, hwy::bfloat16_t>();
|
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>();
|
||||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, float>();
|
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>();
|
||||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/false, float, SfpStream>();
|
|
||||||
TestTiledBatchMatMul<1, 128, 32, /*kAdd=*/true, hwy::bfloat16_t, SfpStream>();
|
|
||||||
|
|
||||||
// large-scale test
|
// large-scale test
|
||||||
// TODO(philculliton): investigate rounding issues with large matrices.
|
// TODO(philculliton): investigate rounding issues with large matrices.
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue