mirror of https://github.com/google/gemma.cpp.git
Remove allocation from GEMM_4x4_Tile when decoding compressed weights by implementing SfpCodec::Dec2F and CompressTraits<T>::Decompress2 for all supported types.

This also allows removing one of the GEMM_4x4_Tile specializations: a single function now handles compressed MatB. As before, computations use 32-bit registers even when MatA is bf16.

Measurements for a 2b-it sfp-encoded model on an AMD Ryzen Threadripper PRO 3945WX (12 cores):

baseline:
```
32.6254 prefill tokens / sec
8.91429 tokens / sec
115 milliseconds time to first token
```

this change:
```
54.3045 prefill tokens / sec
16.8191 tokens / sec
56 milliseconds time to first token
```

That is roughly 1.7x prefill and 1.9x generation throughput, and about half the time to first token.

PiperOrigin-RevId: 651369694
parent f519ab6693 · commit 3e92088595
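The heart of the change is the new CompressTraits<T>::Decompress2, which decodes exactly two vectors' worth of values (2*N floats for a vector width of N lanes) from the compressed stream directly into registers, so callers no longer need a heap-allocated scratch tile. A minimal usage sketch, assuming the usual Highway per-target context (`namespace hn = hwy::HWY_NAMESPACE;` inside a per-target block) and this repo's CompressTraits; `DotDecoded` is a hypothetical helper, not part of the commit:

```cpp
// Hypothetical usage sketch, not part of this commit. Computes
// dot(a, decode(b)) without materializing the decoded weights in memory.
template <typename MatT>
HWY_INLINE float DotDecoded(const float* HWY_RESTRICT a,
                            const MatT* HWY_RESTRICT b, size_t b_ofs,
                            size_t num) {
  const hn::ScalableTag<float> df;
  const size_t N = hn::Lanes(df);
  using VF = hn::Vec<decltype(df)>;
  VF sum0 = hn::Zero(df), sum1 = hn::Zero(df);
  // Decompress2 yields exactly 2*N values per call, straight into registers;
  // num is assumed to be a multiple of 2*N, as in the GEMM tile loop below.
  for (size_t i = 0; i < num; i += 2 * N) {
    VF b0, b1;
    CompressTraits<MatT>::Decompress2(df, b, b_ofs + i, b0, b1);
    sum0 = hn::MulAdd(hn::LoadU(df, a + i), b0, sum0);
    sum1 = hn::MulAdd(hn::LoadU(df, a + i + N), b1, sum1);
  }
  return hn::ReduceSum(df, hn::Add(sum0, sum1));
}
```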
```diff
@@ -25,6 +25,7 @@ cc_library(
     hdrs = ["gemma/ops.h"],
     deps = [
         "//compression:compress",
+        "//compression:sfp",
         "@hwy//:algo",
         "@hwy//:dot",
         "@hwy//:hwy",
@@ -49,6 +50,7 @@ cc_test(
+        "//compression:compress",
         "@hwy//:hwy",
         "@hwy//:hwy_test_util",
         "@hwy//:nanobenchmark",
         "@hwy//:thread_pool",
     ],
 )
```
```diff
@@ -79,6 +79,15 @@ struct CompressTraits<float> {
     }
   }

+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
+                                     size_t in_ofs, hn::Vec<DF>& f0,
+                                     hn::Vec<DF>& f1) {
+    const size_t N = hn::Lanes(df);
+    f0 = hn::LoadU(df, in + in_ofs);
+    f1 = hn::LoadU(df, in + in_ofs + N);
+  }
+
   template <class DF, HWY_IF_F32_D(DF)>
   static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
                                     const MatT* HWY_RESTRICT in, size_t in_ofs,
@@ -88,8 +97,8 @@ struct CompressTraits<float> {
     HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);

     for (size_t i = 0; i < num; i += 2 * N) {
-      const VF in0 = hn::LoadU(df, in + in_ofs + i);
-      const VF in1 = hn::LoadU(df, in + in_ofs + i + N);
+      VF in0, in1;
+      Decompress2(df, in, in_ofs + i, in0, in1);
       hn::StoreU(in0, df, out + i);
       hn::StoreU(in1, df, out + i + N);
     }
@@ -174,6 +183,17 @@ struct CompressTraits<hwy::bfloat16_t> {
     }
   }

+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
+                                     size_t in_ofs, hn::Vec<DF>& f0,
+                                     hn::Vec<DF>& f1) {
+    const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
+    using VBF = hn::Vec<decltype(dbf)>;
+    const VBF in16 = hn::LoadU(dbf, in + in_ofs);
+    f0 = hn::PromoteLowerTo(df, in16);
+    f1 = hn::PromoteUpperTo(df, in16);
+  }
+
   template <class DF, HWY_IF_F32_D(DF)>
   static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
                                     const MatT* HWY_RESTRICT in, size_t in_ofs,
@@ -186,9 +206,8 @@ struct CompressTraits<hwy::bfloat16_t> {
     size_t i = 0;
     if (num >= N16) {
       for (i = 0; i <= num - N16; i += N16) {
-        const VBF in16 = hn::LoadU(dbf, in + in_ofs + i);
-        const VF in0 = hn::PromoteLowerTo(df, in16);
-        const VF in1 = hn::PromoteUpperTo(df, in16);
+        VF in0, in1;
+        Decompress2(df, in, in_ofs + i, in0, in1);
         hn::StoreU(in0, df, out + i);
         hn::StoreU(in1, df, out + i + N16 / 2);
       }
@@ -296,6 +315,16 @@ struct CompressTraits<SfpStream> {
     }
   }

+  template <class DF, HWY_IF_F32_D(DF)>
+  static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
+                                     size_t in_ofs, hn::Vec<DF>& f0,
+                                     hn::Vec<DF>& f1) {
+    const hn::Twice<hn::Rebind<uint8_t, DF>> d8;
+    using V8 = hn::Vec<decltype(d8)>;
+    const V8 packed = hn::LoadU(d8, &in->byte + in_ofs);
+    SfpCodec::Dec2F(df, packed, f0, f1);
+  }
+
   template <class D, typename OutT>
   static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/,
                                     const MatT* HWY_RESTRICT in, size_t in_ofs,
```
```diff
@@ -535,6 +535,18 @@ class SfpCodec {
     }
   }

+  template <class DF, HWY_IF_F32_D(DF),
+            class V8 = hn::Vec<hn::Twice<hn::Rebind<uint8_t, DF>>>>
+  static HWY_INLINE void Dec2F(DF df, V8 packed, hn::Vec<DF>& f0,
+                               hn::Vec<DF>& f1) {
+    const hn::Rebind<hwy::bfloat16_t, DF> dbf;
+    using VBF = hn::Vec<decltype(dbf)>;
+    VBF bf0, bf1;
+    Dec2B(dbf, packed, bf0, bf1);
+    f0 = hn::PromoteTo(df, bf0);
+    f1 = hn::PromoteTo(df, bf1);
+  }
+
  private:
   // Wrappers to avoid code duplication across float/bf16 input types and
   // the main loop/remainder.
```
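SfpCodec::Dec2F builds on the existing byte-to-bf16 decoder: 2*N packed SFP bytes are expanded by Dec2B into two bf16 vectors and then widened to f32, all in registers. A sketch of the path taken by one CompressTraits<SfpStream>::Decompress2 call, under the same Highway-context assumption (`DecodeSfpPair` and `packed_bytes` are illustrative names, not repo code):

```cpp
// Sketch, not verbatim repo code: what one SfpStream Decompress2 call
// amounts to after this change.
template <class DF>  // DF: an f32 tag such as hn::ScalableTag<float>
HWY_INLINE void DecodeSfpPair(DF df, const uint8_t* HWY_RESTRICT packed_bytes,
                              size_t ofs, hn::Vec<DF>& f0, hn::Vec<DF>& f1) {
  // 2*N packed bytes hold the inputs for two f32 vectors of N lanes each.
  const hn::Twice<hn::Rebind<uint8_t, DF>> d8;
  const auto packed = hn::LoadU(d8, packed_bytes + ofs);
  // Inside Dec2F: Dec2B expands the bytes to two bf16 vectors, then
  // PromoteTo widens each to f32. No store/reload of decoded values occurs.
  SfpCodec::Dec2F(df, packed, f0, f1);
}
```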
gemma/ops.h
```diff
@@ -272,9 +272,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
     const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(
-        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    StoreHorizontalSumsAdd<kNumRows>(d, c00, c01, c02, c03, c10, c11, c12, c13,
+                                     c20, c21, c22, c23, c30, c31, c32, c33,
+                                     tile_add, tile_c, stride_c);
   } else {
     StoreHorizontalSums<kNumRows>(
         d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -433,10 +433,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,

   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(
-        df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    const AddT* dd = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
+                                     c20, c21, c22, c23, c30, c31, c32, c33, dd,
+                                     tile_c, stride_c);
   } else {
     StoreHorizontalSums<kNumRows>(
         df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -444,120 +444,31 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   }
 }

-// Same as above, but with mixed Mat types: (f32, compressed).
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1),
-          typename AddT>
-HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
-                              const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C,
-                              const AddT* add,
-                              const size_t idx_tile,
-                              const size_t xtiles, const size_t stride_a,
-                              const size_t stride_b, const size_t stride_c) {
-  constexpr size_t kRegRows = 4;
-  constexpr size_t kRegCols = 4;
-  static_assert(kNumRows <= kRegRows);
-
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
-  const size_t row_a = idx_tile / xtiles * kRegRows;
-  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
-
-  const hn::ScalableTag<MatTA> d;
-  const size_t N = Lanes(d);
-  using V = hn::Vec<decltype(d)>;
-
-  V c00 = hn::Zero(d);
-  V c01 = hn::Zero(d);
-  V c02 = hn::Zero(d);
-  V c03 = hn::Zero(d);
-
-  V c10 = hn::Zero(d);
-  V c11 = hn::Zero(d);
-  V c12 = hn::Zero(d);
-  V c13 = hn::Zero(d);
-
-  V c20 = hn::Zero(d);
-  V c21 = hn::Zero(d);
-  V c22 = hn::Zero(d);
-  V c23 = hn::Zero(d);
-
-  V c30 = hn::Zero(d);
-  V c31 = hn::Zero(d);
-  V c32 = hn::Zero(d);
-  V c33 = hn::Zero(d);
-
-  const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
-
-  hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
-      hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
-  CompressTraits<MatTB>::Decompress(
-      d,
-      /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
-      kRegRows * kColsA_RowsB);
-  const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
-
-  // Loop over columns of A and columns of the transposed B, in steps of N.
-  // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
-    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
-    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
-    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
-    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
-
-    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
-    c00 = hn::MulAdd(a0, b0, c00);
-    c01 = hn::MulAdd(a0, b1, c01);
-    c02 = hn::MulAdd(a0, b2, c02);
-    c03 = hn::MulAdd(a0, b3, c03);
-    if constexpr (kNumRows == 1) continue;
-
-    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
-    c10 = hn::MulAdd(a1, b0, c10);
-    c11 = hn::MulAdd(a1, b1, c11);
-    c12 = hn::MulAdd(a1, b2, c12);
-    c13 = hn::MulAdd(a1, b3, c13);
-    if constexpr (kNumRows == 2) continue;
-
-    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
-    c20 = hn::MulAdd(a2, b0, c20);
-    c21 = hn::MulAdd(a2, b1, c21);
-    c22 = hn::MulAdd(a2, b2, c22);
-    c23 = hn::MulAdd(a2, b3, c23);
-    if constexpr (kNumRows == 3) continue;
-
-    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
-    c30 = hn::MulAdd(a3, b0, c30);
-    c31 = hn::MulAdd(a3, b1, c31);
-    c32 = hn::MulAdd(a3, b2, c32);
-    c33 = hn::MulAdd(a3, b3, c33);
-  }
-
-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
-  if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(
-        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
-  } else {
-    StoreHorizontalSums<kNumRows>(
-        d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_c, stride_c);
-  }
-}
+template <class VF>
+HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
+                              const VF& b01, const VF& b10, const VF& b11,
+                              const VF& b20, const VF& b21, const VF& b30,
+                              const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
+  c0 = MulAdd(a0, b00, c0);
+  c0 = MulAdd(a1, b01, c0);
+  c1 = MulAdd(a0, b10, c1);
+  c1 = MulAdd(a1, b11, c1);
+  c2 = MulAdd(a0, b20, c2);
+  c2 = MulAdd(a1, b21, c2);
+  c3 = MulAdd(a0, b30, c3);
+  c3 = MulAdd(a1, b31, c3);
+}

-// Same as above, but with mixed Mat types: (bf16, compressed).
+// Same as above, for when there exists CompressTraits<MatTA>::Decompress2 and
+// MatTB is compressed.
 template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1),
-          typename AddT>
+          typename MatTB, HWY_IF_T_SIZE(MatTB, 1), typename AddT>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C,
-                              const AddT* add,
-                              const size_t idx_tile,
-                              const size_t xtiles, const size_t stride_a,
-                              const size_t stride_b, const size_t stride_c) {
+                              float* HWY_RESTRICT C, const AddT* add,
+                              const size_t idx_tile, const size_t xtiles,
+                              const size_t stride_a, const size_t stride_b,
+                              const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
   static_assert(kNumRows <= kRegRows);
```
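The new UpdateTileRow folds one decoded A-row pair (a0, a1) together with the four decoded B-row pairs into the four accumulators of that C row, so each cXY gathers partial dot products over 2*N lanes per call. A scalar model of its semantics (illustrative only; the real code above uses Highway MulAdd on vectors):

```cpp
// Scalar model of one UpdateTileRow call (illustrative only, not repo code).
// a0/a1: two N-lane chunks of the current A row; b[r][0..1]: the matching
// chunks of B row r; c[r]: the accumulator for (this A row, B row r).
void UpdateTileRowScalar(size_t N, const float* a0, const float* a1,
                         const float* b[4][2], float* c[4]) {
  for (int r = 0; r < 4; ++r) {
    for (size_t k = 0; k < N; ++k) {
      c[r][k] += a0[k] * b[r][0][k] + a1[k] * b[r][1][k];
    }
  }
  // Horizontal sums over k later collapse each c[r] into a single C entry.
}
```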
```diff
@@ -567,13 +478,8 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   const size_t row_b_col_c = idx_tile % xtiles * kRegCols;

   const hn::ScalableTag<float> d32;
-  const size_t N = Lanes(d32);
+  const size_t N = hn::Lanes(d32);
   using V = hn::Vec<decltype(d32)>;
-  // TODO: Using half-vectors for now, it might be faster to
-  // PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
-  // accordingly.
-  const hn::Rebind<MatTA, decltype(d32)> d16;
-  HWY_DASSERT(Lanes(d16) == Lanes(d32));

   V c00 = hn::Zero(d32);
   V c01 = hn::Zero(d32);
@@ -595,67 +501,70 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   V c32 = hn::Zero(d32);
   V c33 = hn::Zero(d32);

-  const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
-
-  hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
-      hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
-  CompressTraits<MatTB>::Decompress(
-      d32,
-      /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
-      kRegRows * kColsA_RowsB);
-  const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
+  const size_t tile_a_ofs = stride_a * row_a;
+  const size_t tile_b_ofs = stride_b * row_b_col_c;

-  // Loop over columns of A and columns of the transposed B, in steps of N.
+  // Loop over columns of A and columns of the transposed B, in steps of 2*N
+  // (since we are decoding consecutive bytes at each iteration).
   // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
-    const V b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
-    const V b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
-    const V b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
-    const V b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
+  size_t col_ab = 0;
+
+  HWY_UNROLL(1)
+  for (; col_ab <= kColsA_RowsB - 2 * N; col_ab += 2 * N) {
+    V b00, b01;
+    CompressTraits<MatTB>::Decompress2(
+        d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
+    V b10, b11;
+    CompressTraits<MatTB>::Decompress2(
+        d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
+    V b20, b21;
+    CompressTraits<MatTB>::Decompress2(
+        d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
+    V b30, b31;
+    CompressTraits<MatTB>::Decompress2(
+        d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);

-    const V a0 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
-    c00 = hn::MulAdd(a0, b0, c00);
-    c01 = hn::MulAdd(a0, b1, c01);
-    c02 = hn::MulAdd(a0, b2, c02);
-    c03 = hn::MulAdd(a0, b3, c03);
+    V a00, a01;
+    CompressTraits<MatTA>::Decompress2(
+        d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
+    UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
+                  c02, c03);
+
     if constexpr (kNumRows == 1) continue;

-    const V a1 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
-    c10 = hn::MulAdd(a1, b0, c10);
-    c11 = hn::MulAdd(a1, b1, c11);
-    c12 = hn::MulAdd(a1, b2, c12);
-    c13 = hn::MulAdd(a1, b3, c13);
+    V a10, a11;
+    CompressTraits<MatTA>::Decompress2(
+        d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
+    UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
+                  c12, c13);
+
     if constexpr (kNumRows == 2) continue;

-    const V a2 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
-    c20 = hn::MulAdd(a2, b0, c20);
-    c21 = hn::MulAdd(a2, b1, c21);
-    c22 = hn::MulAdd(a2, b2, c22);
-    c23 = hn::MulAdd(a2, b3, c23);
+    V a20, a21;
+    CompressTraits<MatTA>::Decompress2(
+        d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
+    UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
+                  c22, c23);
+
     if constexpr (kNumRows == 3) continue;

-    const V a3 =
-        hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
-    c30 = hn::MulAdd(a3, b0, c30);
-    c31 = hn::MulAdd(a3, b1, c31);
-    c32 = hn::MulAdd(a3, b2, c32);
-    c33 = hn::MulAdd(a3, b3, c33);
+    V a30, a31;
+    CompressTraits<MatTA>::Decompress2(
+        d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
+    UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
+                  c32, c33);
   }

   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    const AddT* dd = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
+                                     c13, c20, c21, c22, c23, c30, c31, c32,
+                                     c33, dd, tile_c, stride_c);
   } else {
-    StoreHorizontalSums<kNumRows>(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_c, stride_c);
+    StoreHorizontalSums<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12, c13,
+                                  c20, c21, c22, c23, c30, c31, c32, c33,
+                                  tile_c, stride_c);
   }
 }
```
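Note that the consolidated kernel's main loop advances by 2*N with the bound `col_ab <= kColsA_RowsB - 2 * N` and has no remainder path, so the shared dimension is assumed to be a multiple of 2*N at every target's vector width. A guard making that assumption explicit could look like this (hypothetical addition, not in the commit):

```cpp
// Hypothetical guard, not part of this commit: the loop decodes 2*N values
// per step with no remainder path, so the shared dimension must divide
// evenly by 2*N at the current target's vector width.
HWY_DASSERT(kColsA_RowsB % (2 * N) == 0);
```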
```diff
@@ -755,10 +664,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,

   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    const AddT* dd = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
+                                     c13, c20, c21, c22, c23, c30, c31, c32,
+                                     c33, dd, tile_c, stride_c);
   } else {
     StoreHorizontalSums<kNumRows>(
         d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -861,10 +770,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,

   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   if constexpr (kAdd) {
-    const AddT* tile_add = add + row_b_col_c;
-    StoreHorizontalSumsAdd<kNumRows>(
-        d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
-        c30, c31, c32, c33, tile_add, tile_c, stride_c);
+    const AddT* dd = add + row_b_col_c;
+    StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
+                                     c13, c20, c21, c22, c23, c30, c31, c32,
+                                     c33, dd, tile_c, stride_c);
   } else {
     StoreHorizontalSums<kNumRows>(
         d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -906,20 +815,20 @@ HWY_NOINLINE void MatMul_4x4_Batch_Add(
     HWY_ASSERT(num_rows > 0);
     switch (num_rows) {
       case 1:
-        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(
-            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
+                                             kStrideA, kStrideB, kStrideC);
         break;
       case 2:
-        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(
-            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
+                                             kStrideA, kStrideB, kStrideC);
         break;
       case 3:
-        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(
-            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
+                                             kStrideA, kStrideB, kStrideC);
         break;
       default:
-        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(
-            A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
+                                             kStrideA, kStrideB, kStrideC);
     }
   });
 }
```