mirror of https://github.com/google/gemma.cpp.git
parent 419dc34ed5
commit c616abe628

gemma/ops.h | 183

@@ -22,6 +22,7 @@
 #include <array>
 #include <cmath>
+#include <cstdio>
 #include <random>
 #include <type_traits>  // std::enable_if_t
@@ -93,11 +94,179 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
   return kRowsPerStrip;
 }
 
+// Processes a single 4x4 tile of A x B. Shared between static and dynamic
+// versions.
+template <typename MatT, size_t kColsA>
+HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
+                              const MatT* HWY_RESTRICT B, MatT* HWY_RESTRICT C,
+                              size_t tile_num, const int xtiles, const int lda,
+                              const int ldb, const int ldc) {
+  constexpr int RM = 4;
+  constexpr int RN = 4;
+
+  // Calculate chunk start coords.
+  int ii = tile_num / xtiles * RM;
+  int jj = tile_num % xtiles * RN;
+
+  const hn::ScalableTag<MatT> d;
+  const size_t N = Lanes(d);
+  using V = hn::Vec<decltype(d)>;
+
+  V c00 = hn::Zero(d);
+  V c01 = hn::Zero(d);
+  V c02 = hn::Zero(d);
+  V c03 = hn::Zero(d);
+
+  V c10 = hn::Zero(d);
+  V c11 = hn::Zero(d);
+  V c12 = hn::Zero(d);
+  V c13 = hn::Zero(d);
+
+  V c20 = hn::Zero(d);
+  V c21 = hn::Zero(d);
+  V c22 = hn::Zero(d);
+  V c23 = hn::Zero(d);
+
+  V c30 = hn::Zero(d);
+  V c31 = hn::Zero(d);
+  V c32 = hn::Zero(d);
+  V c33 = hn::Zero(d);
+
+  // Steps down the rows of A and B, and across width (kN) in steps of
+  // N (Lanes()). Accumulates into the cache vectors. hn::ReduceSum() is
+  // called on each of the cache vectors to sum the partial sums into C.
+  for (size_t l = 0; l < kColsA; l += N) {
+    V k0 = hn::LoadU(d, B + ldb * (jj + 0) + l);
+    V k1 = hn::LoadU(d, B + ldb * (jj + 1) + l);
+    V k2 = hn::LoadU(d, B + ldb * (jj + 2) + l);
+    V k3 = hn::LoadU(d, B + ldb * (jj + 3) + l);
+
+    V a0 = hn::LoadU(d, A + lda * (ii + 0) + l);
+    c00 = hn::MulAdd(a0, k0, c00);
+    c01 = hn::MulAdd(a0, k1, c01);
+    c02 = hn::MulAdd(a0, k2, c02);
+    c03 = hn::MulAdd(a0, k3, c03);
+
+    V a1 = hn::LoadU(d, A + lda * (ii + 1) + l);
+    c10 = hn::MulAdd(a1, k0, c10);
+    c11 = hn::MulAdd(a1, k1, c11);
+    c12 = hn::MulAdd(a1, k2, c12);
+    c13 = hn::MulAdd(a1, k3, c13);
+
+    V a2 = hn::LoadU(d, A + lda * (ii + 2) + l);
+    c20 = hn::MulAdd(a2, k0, c20);
+    c21 = hn::MulAdd(a2, k1, c21);
+    c22 = hn::MulAdd(a2, k2, c22);
+    c23 = hn::MulAdd(a2, k3, c23);
+
+    V a3 = hn::LoadU(d, A + lda * (ii + 3) + l);
+    c30 = hn::MulAdd(a3, k0, c30);
+    c31 = hn::MulAdd(a3, k1, c31);
+    c32 = hn::MulAdd(a3, k2, c32);
+    c33 = hn::MulAdd(a3, k3, c33);
+  }
+
+  C[ldc * (ii + 0) + (jj + 0)] = hn::ReduceSum(d, c00);
+  C[ldc * (ii + 0) + (jj + 1)] = hn::ReduceSum(d, c01);
+  C[ldc * (ii + 0) + (jj + 2)] = hn::ReduceSum(d, c02);
+  C[ldc * (ii + 0) + (jj + 3)] = hn::ReduceSum(d, c03);
+
+  C[ldc * (ii + 1) + (jj + 0)] = hn::ReduceSum(d, c10);
+  C[ldc * (ii + 1) + (jj + 1)] = hn::ReduceSum(d, c11);
+  C[ldc * (ii + 1) + (jj + 2)] = hn::ReduceSum(d, c12);
+  C[ldc * (ii + 1) + (jj + 3)] = hn::ReduceSum(d, c13);
+
+  C[ldc * (ii + 2) + (jj + 0)] = hn::ReduceSum(d, c20);
+  C[ldc * (ii + 2) + (jj + 1)] = hn::ReduceSum(d, c21);
+  C[ldc * (ii + 2) + (jj + 2)] = hn::ReduceSum(d, c22);
+  C[ldc * (ii + 2) + (jj + 3)] = hn::ReduceSum(d, c23);
+
+  C[ldc * (ii + 3) + (jj + 0)] = hn::ReduceSum(d, c30);
+  C[ldc * (ii + 3) + (jj + 1)] = hn::ReduceSum(d, c31);
+  C[ldc * (ii + 3) + (jj + 2)] = hn::ReduceSum(d, c32);
+  C[ldc * (ii + 3) + (jj + 3)] = hn::ReduceSum(d, c33);
+}
+
+// Tiled 4x4 GEMM. Covers primary M = 4..512, k = 3k/24k, n = 24k/3k use case.
+// This version uses tiling suitable for static scheduling.
+// Note: expects transposed / shuffled B.
+template <size_t kM, size_t kColsA, size_t kK, typename MatT>
+void GEMM_4x4_Static(const MatT* HWY_RESTRICT A, const MatT* HWY_RESTRICT B,
+                     MatT* HWY_RESTRICT C) {
+  const hn::ScalableTag<MatT> d;
+  const size_t N = hn::Lanes(d);  // column step size
+  constexpr int RM = 4;           // tile height
+  constexpr int RN = 4;           // tile width
+
+  HWY_ASSERT(kM % RM == 0);
+  HWY_ASSERT(kColsA % N == 0);
+  HWY_ASSERT(kColsA % RN == 0);
+
+  int lda = kColsA;
+  int ldb = kColsA;  // n instead of k because we're transposing
+  int ldc = kK;
+
+  int ytiles = (kM) / RM;
+  int xtiles = (kK) / RN;  // k instead of n because we're transposing
+  int tiles = xtiles * ytiles;
+
+  for (int job = 0; job < tiles; ++job) {
+    GEMM_4x4_Tile<MatT, kColsA>(A, B, C, job, xtiles, lda, ldb, ldc);
+  }
+}
+
+// Tiled 4x4 GEMM. Covers primary M = 4..512, k = 3k/24k, n = 24k/3k use case.
+// This version uses tiling and pooled threads.
+// Note: expects transposed / shuffled B.
+template <size_t kM, size_t kColsA, size_t kK, typename MatT>
+HWY_NOINLINE void MatMul_4x4_Impl(const MatT* HWY_RESTRICT A,
+                                  const MatT* HWY_RESTRICT B,
+                                  MatT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+  // Process 4x4 chunks of C in parallel. Each pool thread handles a single A x
+  // B tile. Note that C is being addressed directly without a buffer, and that
+  // the cache vectors (c00, c01, etc.) are being summed directly into C. There
+  // may be additional stability / speed gains to be made by using a buffer.
+  const hn::ScalableTag<MatT> d;
+  const size_t N = Lanes(d);
+
+  const int lda = kColsA;
+  const int ldb = kColsA;  // n instead of k because we're transposing
+  const int ldc = kK;
+
+  // 4x4
+  const int RM = 4;
+  const int RN = 4;
+
+  const int ytiles = (kM) / RM;
+  const int xtiles = (kK) / RN;  // k instead of n because we're transposing
+  const int tiles = xtiles * ytiles;
+
+  // 4x4 case requires kM % 4 == 0, kN % N == 0, kK % 4 == 0
+  HWY_ASSERT(kM % RM == 0);
+  HWY_ASSERT(kColsA % N == 0);
+  HWY_ASSERT(kColsA % RN == 0);
+  HWY_ASSERT(kK % RN == 0);
+  HWY_ASSERT(kColsA >= N);
+
+  // Handles a single 4x4 chunk, which is completed and then written into C.
+  pool.Run(0, tiles, [&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
+    GEMM_4x4_Tile<MatT, kColsA>(A, B, C, chunk, xtiles, lda, ldb, ldc);
+  });
+}
+
+// Requires m % 4 == 0, n % Lanes() == 0, k % 4 == 0
+template <size_t kM, size_t kN, size_t kK, typename MatT>
+HWY_INLINE void MatMul_4x4(const MatT* HWY_RESTRICT a,
+                           const MatT* HWY_RESTRICT b, MatT* HWY_RESTRICT out,
+                           hwy::ThreadPool& pool) {
+  return MatMul_4x4_Impl<kM, kN, kK, MatT>(a, b, out, pool);
+}
+
 // Largely unoptimized; reordered innermost loops nets ~5-10X speedup on
 // ops_test across instruction sets.
-template <size_t kM, size_t kN, size_t kK>
-HWY_INLINE void MatMul(const float* HWY_RESTRICT a, const float* HWY_RESTRICT b,
-                       float* HWY_RESTRICT out) {
+template <size_t kM, size_t kN, size_t kK, typename MatT>
+HWY_INLINE void MatMul(const MatT* HWY_RESTRICT a, const MatT* HWY_RESTRICT b,
+                       MatT* HWY_RESTRICT out) {
   int i, j, k;
   for (i = 0; i < kM; ++i) {
     for (k = 0; k < kN; ++k) {
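
In scalar terms, the tile kernel above computes sixteen dot products: C[ii+r][jj+c] = dot(row ii+r of A, row jj+c of the transposed B), for r, c in 0..3. A minimal scalar sketch of the same tile, useful as a mental model for the vector code (ScalarTile4x4 is an illustrative name, not part of the commit):

// Scalar equivalent of one GEMM_4x4_Tile invocation, assuming B already holds
// the transposed right-hand operand (row j of B = column j of the original).
template <typename MatT, size_t kColsA>
void ScalarTile4x4(const MatT* A, const MatT* B, MatT* C, size_t tile_num,
                   int xtiles, int lda, int ldb, int ldc) {
  // Same tile -> coordinate mapping as the SIMD kernel: tiles are numbered
  // row-major across C, e.g. xtiles = 128, tile_num = 130 -> ii = 4, jj = 8.
  const int ii = static_cast<int>(tile_num) / xtiles * 4;
  const int jj = static_cast<int>(tile_num) % xtiles * 4;
  for (int r = 0; r < 4; ++r) {
    for (int c = 0; c < 4; ++c) {
      MatT sum{0};
      for (size_t l = 0; l < kColsA; ++l) {
        sum += A[lda * (ii + r) + l] * B[ldb * (jj + c) + l];
      }
      C[ldc * (ii + r) + (jj + c)] = sum;
    }
  }
}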

@@ -167,8 +336,8 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
 namespace detail {
 
 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
-// of row i with `vec_aligned` and add into `out[i]`. The upper-left coordinate
-// of the tile is r0, c0.
+// of row i with `vec_aligned` and add into `out[i]`. The upper-left
+// coordinate of the tile is r0, c0.
 template <bool kVecEO, class DF, typename ArrayT, typename VecT>
 HWY_INLINE void AccumulatePartialDotProducts(
     DF df, const ArrayT& mat, size_t mat_ofs, size_t mat_stride, size_t r0,

@@ -208,8 +377,8 @@ HWY_INLINE void SetFirstPartialDotProducts(DF df, const ArrayT& mat,
 
 // Adds together partial dot products for all tiles with the same r0 (a
 // horizontal strip of the entire matrix); the result is the full dot product
-// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we store
-// into in out[r - r0].
+// for rows r in [r0, r0 + num_rows) + optionally the add vector, which we
+// store into in out[r - r0].
 template <bool kVecEO, bool kAdd, class DF, typename ArrayT, typename VecT,
           typename AddT>
 HWY_INLINE void FullDotProductsForStrip(DF df, const ArrayT& mat,
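
The rewrapped comments above describe the MatVec tiling scheme: each tile yields partial (per-row) dot products, and all tiles sharing the same r0 are accumulated into full dot products for that horizontal strip. A simplified scalar sketch of that accumulation, using plain float arrays rather than the compressed ArrayT and Highway vectors of the real code:

// Full dot products for rows [r0, r0 + num_rows): zero-initialize, then
// accumulate every column tile of width tile_cols. (The real code instead
// sets from the first tile via SetFirstPartialDotProducts and accumulates
// the rest with AccumulatePartialDotProducts.)
void StripDotProducts(const float* mat, size_t stride, size_t r0,
                      size_t num_rows, size_t num_cols, size_t tile_cols,
                      const float* vec, float* out) {
  for (size_t r = 0; r < num_rows; ++r) out[r] = 0.0f;
  for (size_t c0 = 0; c0 < num_cols; c0 += tile_cols) {
    const size_t c_end = HWY_MIN(c0 + tile_cols, num_cols);
    for (size_t r = 0; r < num_rows; ++r) {
      for (size_t c = c0; c < c_end; ++c) {
        out[r] += mat[(r0 + r) * stride + c] * vec[c];
      }
    }
  }
}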

@@ -350,9 +350,44 @@ CompressedArray<float, kOuter * kInner> GenerateMat(size_t offset) {
   const float scale = 1.0f / kInner;
   for (size_t i = 0; i < kOuter; i++) {
     for (size_t j = 0; j < kInner; j++) {
-      content[i * kInner + j] = static_cast<float>((i + j + offset) * scale);
+      content[i * kInner + j] =
+          static_cast<float>((i * kInner + j + offset) * scale);
     }
   }
+
+  // for (size_t i = 0; i < kOuter; i++) {
+  //   for (size_t j = 0; j < kInner; j++) {
+  //     fprintf(stderr, "content[%lu] = %f\n", i * kInner + j,
+  //             content[i * kInner + j]);
+  //   }
+  // }
+
+  Compress(content, ws, mat, pool);
+  mat.set_scale(1.0f);
+  return mat;
+}
+
+template <size_t kOuter, size_t kInner>
+CompressedArray<float, kOuter * kInner> GenerateTransposeMat(size_t offset) {
+  hwy::ThreadPool pool(0);
+  gcpp::CompressWorkingSet ws;
+  CompressedArray<float, kOuter * kInner> mat;
+  std::array<float, kOuter * kInner> content;
+  const float scale = 1.0f / kInner;
+  for (size_t i = 0; i < kOuter; i++) {
+    for (size_t j = 0; j < kInner; j++) {
+      content[j * kOuter + i] =
+          static_cast<float>((i * kInner + j + offset) * scale);
+    }
+  }
+
+  // for (size_t i = 0; i < kOuter; i++) {
+  //   for (size_t j = 0; j < kInner; j++) {
+  //     fprintf(stderr, "content[%lu] = %f (transpose)\n", i * kInner + j,
+  //             content[i * kInner + j]);
+  //   }
+  // }
+
   Compress(content, ws, mat, pool);
   mat.set_scale(1.0f);
   return mat;
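
Both generators fill `content` with the same value, static_cast<float>((i * kInner + j + offset) * scale), for logical element (i, j); GenerateTransposeMat merely swaps the storage index to j * kOuter + i. A hypothetical check of that invariant over the uncompressed content arrays (illustrative, not part of the commit):

// mat_content / trans_content: the `content` arrays from GenerateMat and
// GenerateTransposeMat before Compress(). Exact equality holds because both
// sides evaluate the identical expression.
template <size_t kOuter, size_t kInner>
void CheckTransposeInvariant(const float* mat_content,
                             const float* trans_content) {
  for (size_t i = 0; i < kOuter; ++i) {
    for (size_t j = 0; j < kInner; ++j) {
      HWY_ASSERT(mat_content[i * kInner + j] == trans_content[j * kOuter + i]);
    }
  }
}

This transposed layout is what lets TestTiledMatMul (further below) pass b_trans to MatMul_4x4, which expects a transposed B, while SimpleMatMul consumes the untransposed b.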

@@ -403,6 +438,7 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatMul(
       }
     }
   }
+
   return out;
 }

@@ -445,7 +481,7 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<MatT[]>& expected,
       double actual_value = hwy::ConvertScalarTo<double>(actual[idx]);
 
       const double tolerance =
-          expected_value * 20 * 1.0 / (1ULL << hwy::MantissaBits<MatT>());
+          expected_value * 21 * 1.0 / (1ULL << hwy::MantissaBits<MatT>());
 
       if (!(expected_value - tolerance <= actual_value &&
             actual_value <= expected_value + tolerance)) {
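
The tolerance constant is bumped from 20 to 21 mantissa-scaled units. For float, hwy::MantissaBits<float>() is 23, so the check permits a relative error of tolerance = expected_value * 21 / 2^23 = expected_value * 21 / 8388608 ≈ expected_value * 2.5e-6; the extra unit of headroom plausibly accommodates the different summation order of the new tiled kernel.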

@@ -456,11 +492,11 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<MatT[]>& expected,
   }
 }
 
-void TestMatMul() {
+void TestTiledMatMul() {
   hwy::ThreadPool pool(0);
-  constexpr size_t kM = 128 * 3;  // 384
-  constexpr size_t kK = 128 * 5;  // 640
-  constexpr size_t kN = 128 * 6;  // 768
+  constexpr size_t kM = 512;  // 384
+  constexpr size_t kN = 512;  // * 5; // 6; // 768
+  constexpr size_t kK = 512;  // * 5; // 640
 
   CompressedArray<float, kM * kN> a1 = GenerateMat<kM, kN>(0);
   CompressedArray<float, kN * kK> b1 = GenerateMat<kN, kK>(0);

@@ -478,6 +514,37 @@ void TestMatMul() {
   hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
   Decompress(compressed_c, 0, c.get(), kM * kK);
 
+  CompressedArray<float, kN * kK> b1_trans = GenerateTransposeMat<kN, kK>(0);
+  hwy::AlignedFreeUniquePtr<float[]> b_trans =
+      hwy::AllocateAligned<float>(kN * kK);
+  Decompress(b1_trans, 0, b_trans.get(), kN * kK);
+  MatMul_4x4<kM, kN, kK>(a.get(), b_trans.get(), c.get(), pool);
+
+  AssertClose(expected_out1, c, kM * kK);
+}
+
+void TestMatMul() {
+  constexpr size_t kM = 512;  // 384
+  constexpr size_t kN = 512;  // * 5; // 6; // 768
+  constexpr size_t kK = 512;  // * 5; // 640
+
+  CompressedArray<float, kM * kN> a1 = GenerateMat<kM, kN>(0);
+  CompressedArray<float, kN * kK> b1 = GenerateMat<kN, kK>(0);
+
+  hwy::AlignedFreeUniquePtr<float[]> a = hwy::AllocateAligned<float>(kM * kN);
+  Decompress(a1, 0, a.get(), kM * kN);
+
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kN * kK);
+  Decompress(b1, 0, b.get(), kN * kK);
+
+  hwy::AlignedFreeUniquePtr<float[]> expected_out1 =
+      SimpleMatMul<kM, kN, kK>(a, b);
+
+  CompressedArray<float, kM * kK> compressed_c = GenerateZeroMat<kM, kK>(0);
+  hwy::AlignedFreeUniquePtr<float[]> c = hwy::AllocateAligned<float>(kM * kK);
+  Decompress(compressed_c, 0, c.get(), kM * kK);
+
+  Decompress(b1, 0, b.get(), kN * kK);
   MatMul<kM, kN, kK>(a.get(), b.get(), c.get());
 
   AssertClose(expected_out1, c, kM * kK);
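
Two small observations on the test hunk above: the second Decompress(b1, 0, b.get(), kN * kK) in the new TestMatMul is redundant (b was already decompressed a few lines earlier) but harmless; and kM = kN = kK = 512 satisfies the tiled kernel's divisibility requirements (multiples of the 4x4 tile size, and of any power-of-two Lanes() count on common targets, since 512 is a power of two).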

@@ -583,6 +650,7 @@ HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConstAndAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllSoftmax);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllCreateDistribution);
+HWY_EXPORT_AND_TEST_P(OpsTest, TestTiledMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatMul);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestMatVecAdd);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestTwoMatVecAdd);