mirror of https://github.com/google/gemma.cpp.git
Implement the matmul op with oneDNN to leverage AMX optimization.
PiperOrigin-RevId: 683370269
This commit is contained in:
parent
2c28b18eb0
commit
029f2d3e98
|
|
@ -83,6 +83,8 @@ cc_library(
|
||||||
"@hwy//:matvec",
|
"@hwy//:matvec",
|
||||||
"@hwy//:profiler",
|
"@hwy//:profiler",
|
||||||
"@hwy//:thread_pool",
|
"@hwy//:thread_pool",
|
||||||
|
"//third_party/intel_dnnl:dnnl",
|
||||||
|
"//third_party/tbb",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -121,7 +121,7 @@ struct RuntimeConfig {
|
||||||
const ImageTokens *image_tokens = nullptr;
|
const ImageTokens *image_tokens = nullptr;
|
||||||
|
|
||||||
// Whether to use thread spinning to reduce barrier synchronization latency.
|
// Whether to use thread spinning to reduce barrier synchronization latency.
|
||||||
bool use_spinning = true;
|
bool use_spinning = false;
|
||||||
|
|
||||||
// End-of-sequence token.
|
// End-of-sequence token.
|
||||||
int eos_id = EOS_ID;
|
int eos_id = EOS_ID;
|
||||||
|
|
|
||||||
157
ops/matmul-inl.h
157
ops/matmul-inl.h
|
|
@ -31,11 +31,20 @@
|
||||||
// After highway.h
|
// After highway.h
|
||||||
#include "compression/compress-inl.h"
|
#include "compression/compress-inl.h"
|
||||||
#include "hwy/contrib/math/math-inl.h"
|
#include "hwy/contrib/math/math-inl.h"
|
||||||
|
#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl.hpp"
|
||||||
|
#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl_common.hpp"
|
||||||
|
#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl_types.h"
|
||||||
|
|
||||||
|
|
||||||
HWY_BEFORE_NAMESPACE();
|
HWY_BEFORE_NAMESPACE();
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
namespace HWY_NAMESPACE {
|
namespace HWY_NAMESPACE {
|
||||||
namespace hn = hwy::HWY_NAMESPACE;
|
namespace hn = hwy::HWY_NAMESPACE;
|
||||||
|
using namespace dnnl;
|
||||||
|
using tag = memory::format_tag;
|
||||||
|
using dt = memory::data_type;
|
||||||
|
using dnnl::primitive_attr;
|
||||||
|
using dnnl::reorder;
|
||||||
|
|
||||||
// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
|
// The MatMul result C[r,c] is Dot(A.Row(r), B.Col(c)). To reduce the number of
|
||||||
// loads, we reuse the same A row for several B columns, which are also loaded
|
// loads, we reuse the same A row for several B columns, which are also loaded
|
||||||
|
|
@ -66,7 +75,9 @@ constexpr size_t kRegRows = kRegCols;
|
||||||
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
|
// expensive to add each `sum0` and `sum1`, hence we only 'decompress' A and B
|
||||||
// to bf16 if the native op is available. This will actually demote f32
|
// to bf16 if the native op is available. This will actually demote f32
|
||||||
// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
|
// activations to bf16. Otherwise, we decompress to f32 and use normal FMA.
|
||||||
using MulT = hwy::If<HWY_NATIVE_DOT_BF16, BF16, float>;
|
// Update the MulT, so Highway matmul always converts inputs to bf16, which is
|
||||||
|
// matched with the dnnl matmul logic.
|
||||||
|
using MulT = BF16;
|
||||||
|
|
||||||
// Loads two vectors at a time with element type MulT from a row of transposed
|
// Loads two vectors at a time with element type MulT from a row of transposed
|
||||||
// B. Called in a loop over col_ab. No bounds checking because `kRow` is
|
// B. Called in a loop over col_ab. No bounds checking because `kRow` is
|
||||||
|
|
@ -450,7 +461,7 @@ HWY_INLINE void MatMulTile(const size_t batch_size, const Mat<const MatTA>& A,
|
||||||
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
|
// Typically `batch_size` is 1..512, `A.cols` and `C.cols` are 3k or 24k.
|
||||||
// Must not be called concurrently with the same `env`.
|
// Must not be called concurrently with the same `env`.
|
||||||
template <bool kAdd, typename MatTA, typename MatTB>
|
template <bool kAdd, typename MatTA, typename MatTB>
|
||||||
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
|
HWY_NOINLINE void MatMul_hwy(const size_t batch_size, const Mat<const MatTA>& A,
|
||||||
const Mat<const MatTB>& B, const float scale,
|
const Mat<const MatTB>& B, const float scale,
|
||||||
const float* HWY_RESTRICT add, MatMulEnv& env,
|
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||||
const Mat<float>& C) {
|
const Mat<float>& C) {
|
||||||
|
|
@ -499,6 +510,148 @@ HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static memory::data_type DnnType();
|
||||||
|
|
||||||
|
/// Instantiation for float type. Add similar instantiations for other
|
||||||
|
/// type if needed.
|
||||||
|
template <>
|
||||||
|
memory::data_type DnnType<float>() {
|
||||||
|
return memory::data_type::f32;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
memory::data_type DnnType<BF16>() {
|
||||||
|
return memory::data_type::bf16;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <>
|
||||||
|
memory::data_type DnnType<SfpStream>() {
|
||||||
|
fprintf(stderr, "DnnType SfpStream is not supported\n");
|
||||||
|
return memory::data_type::bf16;
|
||||||
|
}
|
||||||
|
template <typename MatT>
|
||||||
|
memory convert_to_bf16(dnnl::engine engine, dnnl::stream engine_stream,
|
||||||
|
const Mat<const MatT>& source, memory::dims dims,
|
||||||
|
tag format_tag) {
|
||||||
|
// Write the input matrix to dnnl memory.
|
||||||
|
dnnl::memory::desc src_md(dims, DnnType<MatT>(), format_tag);
|
||||||
|
auto source_mem = memory(src_md, engine);
|
||||||
|
source_mem.set_data_handle(const_cast<MatT*>(source.ptr + source.ofs));
|
||||||
|
if (std::is_same<MatT, BF16>::value) {
|
||||||
|
return source_mem;
|
||||||
|
}
|
||||||
|
// When the input is float, convert it to BF16.
|
||||||
|
if (std::is_same<MatT, float>::value) {
|
||||||
|
auto dst_md = memory::desc(source_mem.get_desc().get_dims(),
|
||||||
|
dnnl::memory::data_type::bf16, format_tag);
|
||||||
|
dnnl::memory dst_mem(dst_md, engine);
|
||||||
|
auto reorder_pd =
|
||||||
|
reorder::primitive_desc(engine, source_mem.get_desc(), engine, dst_md);
|
||||||
|
auto reorder_prim = reorder(reorder_pd);
|
||||||
|
reorder_prim.execute(engine_stream,
|
||||||
|
{{DNNL_ARG_FROM, source_mem}, {DNNL_ARG_TO, dst_mem}});
|
||||||
|
return dst_mem;
|
||||||
|
}
|
||||||
|
fprintf(stderr, "Unsupported type\n");
|
||||||
|
return source_mem;
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool kAdd, typename MatTA, typename MatTB>
|
||||||
|
HWY_NOINLINE void MatMul_dnnl(const size_t batch_size,
|
||||||
|
const Mat<const MatTA>& A,
|
||||||
|
const Mat<const MatTB>& B, const float scale,
|
||||||
|
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||||
|
const Mat<float>& C) {
|
||||||
|
dnnl::engine engine = env.engine;
|
||||||
|
dnnl::stream engine_stream = env.engine_stream;
|
||||||
|
|
||||||
|
// First stage: process the input data.
|
||||||
|
// OneDNN mandates that input and output data be managed using the
|
||||||
|
// dnnl::memory, a practice that can lead to enhanced performance.
|
||||||
|
// Create memory dims for inputs and outputs.
|
||||||
|
const memory::dim kRowsAC = batch_size, kColsARowsB = A.cols,
|
||||||
|
kColsBC = C.cols;
|
||||||
|
memory::dims src_dims = {kRowsAC, kColsARowsB};
|
||||||
|
memory::dims weights_dims = {kColsARowsB, kColsBC};
|
||||||
|
memory::dims bias_dims = {1, kColsBC};
|
||||||
|
memory::dims dst_dims = {kRowsAC, kColsBC};
|
||||||
|
|
||||||
|
auto src_format_tag = tag::ab;
|
||||||
|
// `B` is a transposed matrix.
|
||||||
|
auto weights_format_tag = tag::ba;
|
||||||
|
auto dest_format_tag = tag::ab;
|
||||||
|
|
||||||
|
// Create memory descriptors for inputs and outputs.
|
||||||
|
auto src_md = memory::desc(src_dims, DnnType<MulT>(), src_format_tag);
|
||||||
|
auto weights_md =
|
||||||
|
memory::desc(weights_dims, DnnType<MulT>(), weights_format_tag);
|
||||||
|
auto scale_md = memory::desc({{1}, dt::f32, tag::x});
|
||||||
|
auto dst_md = memory::desc(dst_dims, dt::f32, dest_format_tag);
|
||||||
|
|
||||||
|
auto src_mem =
|
||||||
|
convert_to_bf16(engine, engine_stream, A, src_dims, src_format_tag);
|
||||||
|
auto weights_mem = convert_to_bf16(engine, engine_stream, B, weights_dims,
|
||||||
|
weights_format_tag);
|
||||||
|
auto dst_mem = memory(dst_md, engine);
|
||||||
|
|
||||||
|
// Second stage: Create the matmul primitive/operation.
|
||||||
|
// Define matmul Primitive arguments.
|
||||||
|
std::unordered_map<int, memory> matmul_args;
|
||||||
|
matmul_args.insert({DNNL_ARG_SRC, src_mem});
|
||||||
|
matmul_args.insert({DNNL_ARG_WEIGHTS, weights_mem});
|
||||||
|
matmul_args.insert({DNNL_ARG_DST, dst_mem});
|
||||||
|
|
||||||
|
// Apply the scaling factor to the weights, enabling us to apply it
|
||||||
|
// to the multiplication results before incorporating the bias.
|
||||||
|
auto scale_mem = memory(scale_md, engine, const_cast<float*>(&scale));
|
||||||
|
matmul_args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, scale_mem});
|
||||||
|
dnnl::primitive_attr attr;
|
||||||
|
attr.set_scales_mask(DNNL_ARG_WEIGHTS, 0);
|
||||||
|
|
||||||
|
// Create primitive descriptor.
|
||||||
|
matmul::primitive_desc matmul_pd;
|
||||||
|
|
||||||
|
// When there is bias, add it to the matmul_args.
|
||||||
|
if (kAdd) {
|
||||||
|
auto bias_md = memory::desc(bias_dims, dt::f32, tag::ab);
|
||||||
|
auto bias_mem = memory(bias_md, engine);
|
||||||
|
bias_mem.set_data_handle(const_cast<float*>(add));
|
||||||
|
matmul_args.insert({DNNL_ARG_BIAS, bias_mem});
|
||||||
|
matmul_pd = matmul::primitive_desc(engine, src_md, weights_md, bias_md,
|
||||||
|
dst_md, attr);
|
||||||
|
} else {
|
||||||
|
matmul_pd =
|
||||||
|
matmul::primitive_desc(engine, src_md, weights_md, dst_md, attr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Third stage: Execute the matmul primitive/operation.
|
||||||
|
auto matmul_prim = matmul(matmul_pd);
|
||||||
|
matmul_prim.execute(engine_stream, matmul_args);
|
||||||
|
engine_stream.wait();
|
||||||
|
|
||||||
|
// Copy the output from dnnl memory to the output matrix.
|
||||||
|
// Adding padding when the C.stride is more than the C.cols.
|
||||||
|
auto c_mem_ptr = static_cast<float*>(dst_mem.get_data_handle());
|
||||||
|
const hn::ScalableTag<float> df;
|
||||||
|
for (int row = 0; row < batch_size; ++row) {
|
||||||
|
hn::StoreU(hn::Zero(df), df, C.ptr + row * C.stride);
|
||||||
|
std::copy(c_mem_ptr + row * C.cols, c_mem_ptr + (row + 1) * C.cols,
|
||||||
|
C.ptr + row * C.stride);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template <bool kAdd, typename MatTA, typename MatTB>
|
||||||
|
HWY_NOINLINE void MatMul(const size_t batch_size, const Mat<const MatTA>& A,
|
||||||
|
const Mat<const MatTB>& B, const float scale,
|
||||||
|
const float* HWY_RESTRICT add, MatMulEnv& env,
|
||||||
|
const Mat<float>& C) {
|
||||||
|
MatMul_dnnl<kAdd>(batch_size, A, B, scale, add, env, C);
|
||||||
|
|
||||||
|
// Enable the hwy matmul and disable the dnnl matmul, when we need to
|
||||||
|
// benchmark the hwy matmul.
|
||||||
|
// MatMul_hwy<kAdd>(batch_size, A, B, scale, add, env, C);
|
||||||
|
}
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
} // namespace HWY_NAMESPACE
|
} // namespace HWY_NAMESPACE
|
||||||
} // namespace gcpp
|
} // namespace gcpp
|
||||||
|
|
|
||||||
17
ops/matmul.h
17
ops/matmul.h
|
|
@ -18,11 +18,16 @@
|
||||||
|
|
||||||
#include <stddef.h>
|
#include <stddef.h>
|
||||||
|
|
||||||
|
#include "tbb/global_control.h"
|
||||||
#include "util/allocator.h" // RowVectorBatch
|
#include "util/allocator.h" // RowVectorBatch
|
||||||
#include "util/threading.h" // PerClusterPools
|
#include "util/threading.h" // PerClusterPools
|
||||||
#include "hwy/base.h"
|
#include "hwy/base.h"
|
||||||
#include "hwy/contrib/thread_pool/thread_pool.h"
|
#include "hwy/contrib/thread_pool/thread_pool.h"
|
||||||
#include "hwy/per_target.h"
|
#include "hwy/per_target.h"
|
||||||
|
#include "third_party/intel_dnnl/include/oneapi/dnnl/dnnl.hpp"
|
||||||
|
|
||||||
|
using namespace dnnl;
|
||||||
|
using namespace tbb;
|
||||||
|
|
||||||
namespace gcpp {
|
namespace gcpp {
|
||||||
|
|
||||||
|
|
@ -81,8 +86,18 @@ class MatMulEnv {
|
||||||
const size_t num_lp = pools.NumLP();
|
const size_t num_lp = pools.NumLP();
|
||||||
const size_t NF = hwy::VectorBytes() / sizeof(float);
|
const size_t NF = hwy::VectorBytes() / sizeof(float);
|
||||||
buf_ = RowVectorBatch<float>(num_lp, 16 * NF);
|
buf_ = RowVectorBatch<float>(num_lp, 16 * NF);
|
||||||
|
setenv("ONEDNN_MAX_CPU_ISA", "AVX512_CORE_AMX", 1);
|
||||||
|
// Enable verbose logging for dnnl when we need to debug.
|
||||||
|
// setenv("DNNL_VERBOSE", "2", 2);
|
||||||
|
tbb::global_control global_limit(
|
||||||
|
tbb::global_control::max_allowed_parallelism, 128);
|
||||||
|
// Create execution dnnl::engine.
|
||||||
|
engine = dnnl::engine(dnnl::engine::kind::cpu, 0);
|
||||||
|
// Create dnnl::stream.
|
||||||
|
engine_stream = dnnl::stream(engine);
|
||||||
}
|
}
|
||||||
|
dnnl::stream engine_stream;
|
||||||
|
dnnl::engine engine;
|
||||||
float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); }
|
float* HWY_RESTRICT Buf(size_t lp) { return buf_.Batch(lp); }
|
||||||
PerClusterPools& Pools() const { return *pools_; }
|
PerClusterPools& Pools() const { return *pools_; }
|
||||||
hwy::ThreadPool& Pool() const { return pools_->Inner(0); }
|
hwy::ThreadPool& Pool() const { return pools_->Inner(0); }
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@
|
||||||
// See the License for the specific language governing permissions and
|
// See the License for the specific language governing permissions and
|
||||||
// limitations under the License.
|
// limitations under the License.
|
||||||
|
|
||||||
|
#include "compression/shared.h"
|
||||||
#ifndef HWY_DISABLED_TARGETS
|
#ifndef HWY_DISABLED_TARGETS
|
||||||
// Exclude HWY_SCALAR due to 2x bf16 -> f32.
|
// Exclude HWY_SCALAR due to 2x bf16 -> f32.
|
||||||
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
#define HWY_DISABLED_TARGETS HWY_SCALAR
|
||||||
|
|
@ -145,7 +146,7 @@ void AssertClose(size_t rows_ac, size_t cols_ab, size_t cols_c_rows_b,
|
||||||
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
|
const double norm = MaxColAbsSum(a.get(), rows_ac, cols_ab) *
|
||||||
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
|
MaxColAbsSum(b_trans.get(), cols_c_rows_b, cols_ab);
|
||||||
// Dot(float,BF16) rounds both to BF16.
|
// Dot(float,BF16) rounds both to BF16.
|
||||||
using RefType = hwy::If<IsF32<MatTA>() && IsF32<MatTB>(), float, BF16>;
|
using RefType = BF16;
|
||||||
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
|
const double epsilon = hwy::ConvertScalarTo<double>(hwy::Epsilon<RefType>());
|
||||||
const double tolerance = 200.0 * norm * epsilon;
|
const double tolerance = 200.0 * norm * epsilon;
|
||||||
|
|
||||||
|
|
@ -233,8 +234,13 @@ void TestMatMul(MatMulEnv& env) {
|
||||||
std::unique_ptr<CompressedArray<float, kRowsAC * kColsBC>> c_slow =
|
std::unique_ptr<CompressedArray<float, kRowsAC * kColsBC>> c_slow =
|
||||||
GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
|
GenerateZeroMat<float, kRowsAC, kColsBC>(pool);
|
||||||
const double start_slow = hwy::platform::Now();
|
const double start_slow = hwy::platform::Now();
|
||||||
MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
|
// Compare the dnnl matmul results with the hwy matmul results.
|
||||||
kAdd ? add->data() : nullptr, env, c_slow->data());
|
MatMul_hwy<kAdd>(kRowsAC, ConstMat(a->data(), kColsARowsB),
|
||||||
|
ConstMat(b_trans->data(), kColsARowsB), scale,
|
||||||
|
kAdd ? add->data_scale1() : nullptr, env,
|
||||||
|
MutableMat(c_slow->data(), kColsBC));
|
||||||
|
// MatMulSlow(kRowsAC, kColsARowsB, kColsBC, a->data(), b_trans->data(), scale,
|
||||||
|
// kAdd ? add->data() : nullptr, c_slow->data());
|
||||||
if (want_bench) {
|
if (want_bench) {
|
||||||
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
|
PrintSpeed("MatMulSlow", kRowsAC, kColsARowsB, kColsBC,
|
||||||
hwy::platform::Now() - start_slow);
|
hwy::platform::Now() - start_slow);
|
||||||
|
|
@ -258,9 +264,9 @@ void TestMatMul(MatMulEnv& env) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void TestAllMatMul() {
|
void TestAllMatMul() {
|
||||||
|
tbb::global_control global_limit(tbb::global_control::max_allowed_parallelism, 128);
|
||||||
// Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
|
// Skip EMU128 (10x slower than SSE4 for SFP) and older x86.
|
||||||
if (HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SSE4 ||
|
if (HWY_TARGET != HWY_AVX3 && HWY_TARGET != HWY_AVX3_SPR) {
|
||||||
HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE2) {
|
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -272,10 +278,10 @@ void TestAllMatMul() {
|
||||||
using SFP = SfpStream;
|
using SFP = SfpStream;
|
||||||
|
|
||||||
// large-scale test: batch_size=128 is better than 64 or 256 for SKX.
|
// large-scale test: batch_size=128 is better than 64 or 256 for SKX.
|
||||||
TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, SFP>(env);
|
TestMatMul<128, 24576, 3072, /*kAdd=*/false, F32, BF16>(env);
|
||||||
TestMatMul<128, 3072, 24576, /*kAdd=*/false, F32, SFP>(env);
|
TestMatMul<128, 3072, 24576, /*kAdd=*/false, BF16>(env);
|
||||||
TestMatMul<1, 24576, 3072, /*kAdd=*/false, F32, F32>(env);
|
TestMatMul<1, 24576, 3072, /*kAdd=*/false, BF16>(env);
|
||||||
TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, F32>(env);
|
TestMatMul<1, 3072, 24576, /*kAdd=*/false, F32, BF16>(env);
|
||||||
|
|
||||||
// medium-sized square test - temporarily disabled for faster testing.
|
// medium-sized square test - temporarily disabled for faster testing.
|
||||||
if constexpr (false) {
|
if constexpr (false) {
|
||||||
|
|
@ -292,32 +298,22 @@ void TestAllMatMul() {
|
||||||
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env);
|
TestMatMul<34, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||||
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
TestMatMul<33, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||||
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
TestMatMul<33, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||||
TestMatMul<31, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
|
||||||
TestMatMul<29, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
|
||||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env);
|
TestMatMul<4, 128, 32, /*kAdd=*/true, F32>(env);
|
||||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env);
|
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16>(env);
|
||||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env);
|
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, BF16>(env);
|
||||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env);
|
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, F32>(env);
|
||||||
TestMatMul<4, 128, 32, /*kAdd=*/true, F32, SFP>(env);
|
|
||||||
TestMatMul<4, 128, 32, /*kAdd=*/false, BF16, SFP>(env);
|
|
||||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env);
|
TestMatMul<3, 128, 32, /*kAdd=*/false, F32>(env);
|
||||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env);
|
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||||
TestMatMul<3, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
|
||||||
TestMatMul<3, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
|
||||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env);
|
TestMatMul<2, 128, 64, /*kAdd=*/true, F32>(env);
|
||||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env);
|
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16>(env);
|
||||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env);
|
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, BF16>(env);
|
||||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env);
|
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, F32>(env);
|
||||||
TestMatMul<2, 128, 64, /*kAdd=*/true, F32, SFP>(env);
|
|
||||||
TestMatMul<2, 128, 64, /*kAdd=*/false, BF16, SFP>(env);
|
|
||||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env);
|
TestMatMul<1, 128, 32, /*kAdd=*/false, F32>(env);
|
||||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env);
|
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16>(env);
|
||||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, BF16>(env);
|
||||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, F32>(env);
|
||||||
TestMatMul<1, 128, 32, /*kAdd=*/false, F32, SFP>(env);
|
|
||||||
TestMatMul<1, 128, 32, /*kAdd=*/true, BF16, SFP>(env);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
// NOLINTNEXTLINE(google-readability-namespace-comments)
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue