Remove allocation from GEMM_4x4_Tile when decoding compressed weights.

This is done by implementing SfpCodec::Dec2F and CompressTraits<T>::Decompress2 for all supported types. It also allows removing one of the GEMM_4x4_Tile specializations, handling compressed MatB with a single function. As before, even when MatA is bf16, computations use 32-bit registers.
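For illustration, a minimal sketch of the decode-in-registers pattern this change introduces, simplified from the ops.h diff below and written as if inside the per-target namespace of ops.h (the helper name ForEachDecodedPair and its parameters are hypothetical; `hn` is the usual Highway namespace alias):
```
// Sketch only. Instead of decompressing a whole 4-row tile of B into a
// heap-allocated float buffer up front, each iteration decodes 2*N
// consecutive elements of a row of B straight into two float vectors.
template <size_t kColsA_RowsB, typename MatTB>
HWY_INLINE void ForEachDecodedPair(const MatTB* HWY_RESTRICT B,
                                   size_t row_ofs) {
  namespace hn = hwy::HWY_NAMESPACE;
  const hn::ScalableTag<float> d32;
  const size_t N = hn::Lanes(d32);
  for (size_t col_ab = 0; col_ab + 2 * N <= kColsA_RowsB; col_ab += 2 * N) {
    hn::Vec<decltype(d32)> b0, b1;  // decoded in registers, no allocation
    CompressTraits<MatTB>::Decompress2(d32, B, row_ofs + col_ab, b0, b1);
    // ... feed b0/b1 directly into the MulAdd accumulators ...
  }
}
```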

Measurements for a 2b-it sfp-encoded model on an AMD Ryzen Threadripper PRO 3945WX (12 cores):
baseline:
```
32.6254 prefill tokens / sec
8.91429 tokens / sec
115 milliseconds time to first token
```
this change:
```
54.3045 prefill tokens / sec
16.8191 tokens / sec
56 milliseconds time to first token
```
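Relative to baseline, that is roughly 1.66x prefill throughput (54.3 / 32.6), 1.89x decode throughput (16.8 / 8.9), and about half the time to first token (56 ms vs. 115 ms).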
PiperOrigin-RevId: 651369694
Andrey Vlasov 2024-07-11 05:13:03 -07:00 committed by Copybara-Service
parent f519ab6693
commit 3e92088595
4 changed files with 142 additions and 190 deletions


@@ -25,6 +25,7 @@ cc_library(
hdrs = ["gemma/ops.h"],
deps = [
"//compression:compress",
"//compression:sfp",
"@hwy//:algo",
"@hwy//:dot",
"@hwy//:hwy",
@@ -49,6 +50,7 @@ cc_test(
"//compression:compress",
"@hwy//:hwy",
"@hwy//:hwy_test_util",
"@hwy//:nanobenchmark",
"@hwy//:thread_pool",
],
)


@@ -79,6 +79,15 @@ struct CompressTraits<float> {
}
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
size_t in_ofs, hn::Vec<DF>& f0,
hn::Vec<DF>& f1) {
const size_t N = hn::Lanes(df);
f0 = hn::LoadU(df, in + in_ofs);
f1 = hn::LoadU(df, in + in_ofs + N);
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
@@ -88,8 +97,8 @@ struct CompressTraits<float> {
HWY_DASSERT(num >= 2 * N && num % (2 * N) == 0);
for (size_t i = 0; i < num; i += 2 * N) {
const VF in0 = hn::LoadU(df, in + in_ofs + i);
const VF in1 = hn::LoadU(df, in + in_ofs + i + N);
VF in0, in1;
Decompress2(df, in, in_ofs + i, in0, in1);
hn::StoreU(in0, df, out + i);
hn::StoreU(in1, df, out + i + N);
}
@@ -174,6 +183,17 @@ struct CompressTraits<hwy::bfloat16_t> {
}
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
size_t in_ofs, hn::Vec<DF>& f0,
hn::Vec<DF>& f1) {
const hn::Repartition<hwy::bfloat16_t, decltype(df)> dbf;
using VBF = hn::Vec<decltype(dbf)>;
const VBF in16 = hn::LoadU(dbf, in + in_ofs);
f0 = hn::PromoteLowerTo(df, in16);
f1 = hn::PromoteUpperTo(df, in16);
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress(DF df, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,
@@ -186,9 +206,8 @@ struct CompressTraits<hwy::bfloat16_t> {
size_t i = 0;
if (num >= N16) {
for (i = 0; i <= num - N16; i += N16) {
const VBF in16 = hn::LoadU(dbf, in + in_ofs + i);
const VF in0 = hn::PromoteLowerTo(df, in16);
const VF in1 = hn::PromoteUpperTo(df, in16);
VF in0, in1;
Decompress2(df, in, in_ofs + i, in0, in1);
hn::StoreU(in0, df, out + i);
hn::StoreU(in1, df, out + i + N16 / 2);
}
@@ -296,6 +315,16 @@ struct CompressTraits<SfpStream> {
}
}
template <class DF, HWY_IF_F32_D(DF)>
static HWY_INLINE void Decompress2(DF df, const MatT* HWY_RESTRICT in,
size_t in_ofs, hn::Vec<DF>& f0,
hn::Vec<DF>& f1) {
const hn::Twice<hn::Rebind<uint8_t, DF>> d8;
using V8 = hn::Vec<decltype(d8)>;
const V8 packed = hn::LoadU(d8, &in->byte + in_ofs);
SfpCodec::Dec2F(df, packed, f0, f1);
}
template <class D, typename OutT>
static HWY_INLINE void Decompress(D d, size_t /*in_capacity*/,
const MatT* HWY_RESTRICT in, size_t in_ofs,

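Taken together, the three Decompress2 specializations above share one contract: decode the 2*N elements beginning at in_ofs into two consecutive f32 vectors. A hedged sketch of the uniform call site (the pointers and `ofs` are hypothetical stand-ins, assuming the same per-target namespace context):
```
// Sketch: the same call shape covers all three MatT specializations.
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
hn::Vec<decltype(df)> f0, f1;
// f32: two plain unaligned loads.
CompressTraits<float>::Decompress2(df, f32_ptr, ofs, f0, f1);
// bf16: one load of 2*N bf16, then PromoteLowerTo/PromoteUpperTo.
CompressTraits<hwy::bfloat16_t>::Decompress2(df, bf16_ptr, ofs, f0, f1);
// sfp: one load of 2*N packed bytes, then SfpCodec::Dec2F.
CompressTraits<SfpStream>::Decompress2(df, sfp_ptr, ofs, f0, f1);
```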

@@ -535,6 +535,18 @@ class SfpCodec {
}
}
template <class DF, HWY_IF_F32_D(DF),
class V8 = hn::Vec<hn::Twice<hn::Rebind<uint8_t, DF>>>>
static HWY_INLINE void Dec2F(DF df, V8 packed, hn::Vec<DF>& f0,
hn::Vec<DF>& f1) {
const hn::Rebind<hwy::bfloat16_t, DF> dbf;
using VBF = hn::Vec<decltype(dbf)>;
VBF bf0, bf1;
Dec2B(dbf, packed, bf0, bf1);
f0 = hn::PromoteTo(df, bf0);
f1 = hn::PromoteTo(df, bf1);
}
private:
// Wrappers to avoid code duplication across float/bf16 input types and
// the main loop/remainder.

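A usage sketch for the new Dec2F (the `packed_bytes` pointer is a hypothetical stand-in): it decodes 2*N packed SFP bytes to two bf16 vectors via the existing Dec2B, then promotes each half to f32.
```
// Sketch: widen 2*N packed SFP values into two f32 vectors.
namespace hn = hwy::HWY_NAMESPACE;
const hn::ScalableTag<float> df;
const hn::Twice<hn::Rebind<uint8_t, decltype(df)>> d8;
const auto packed = hn::LoadU(d8, packed_bytes);  // 2*N SFP-encoded bytes
hn::Vec<decltype(df)> f0, f1;
SfpCodec::Dec2F(df, packed, f0, f1);
// f0 holds the first N decoded values, f1 the next N.
```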

@@ -272,9 +272,9 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* tile_add = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(
d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_add, tile_c, stride_c);
StoreHorizontalSumsAdd<kNumRows>(d, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
tile_add, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -433,10 +433,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* tile_add = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(
df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_add, tile_c, stride_c);
const AddT* dd = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33, dd,
tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -444,120 +444,31 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
}
}
// Same as above, but with mixed Mat types: (f32, compressed).
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
HWY_IF_F32(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1),
typename AddT>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C,
const AddT* add,
const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
// Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
const size_t row_a = idx_tile / xtiles * kRegRows;
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
const hn::ScalableTag<MatTA> d;
const size_t N = Lanes(d);
using V = hn::Vec<decltype(d)>;
V c00 = hn::Zero(d);
V c01 = hn::Zero(d);
V c02 = hn::Zero(d);
V c03 = hn::Zero(d);
V c10 = hn::Zero(d);
V c11 = hn::Zero(d);
V c12 = hn::Zero(d);
V c13 = hn::Zero(d);
V c20 = hn::Zero(d);
V c21 = hn::Zero(d);
V c22 = hn::Zero(d);
V c23 = hn::Zero(d);
V c30 = hn::Zero(d);
V c31 = hn::Zero(d);
V c32 = hn::Zero(d);
V c33 = hn::Zero(d);
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
CompressTraits<MatTB>::Decompress(
d,
/*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
kRegRows * kColsA_RowsB);
const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
// Loop over columns of A and columns of the transposed B, in steps of N.
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
c00 = hn::MulAdd(a0, b0, c00);
c01 = hn::MulAdd(a0, b1, c01);
c02 = hn::MulAdd(a0, b2, c02);
c03 = hn::MulAdd(a0, b3, c03);
if constexpr (kNumRows == 1) continue;
const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
c10 = hn::MulAdd(a1, b0, c10);
c11 = hn::MulAdd(a1, b1, c11);
c12 = hn::MulAdd(a1, b2, c12);
c13 = hn::MulAdd(a1, b3, c13);
if constexpr (kNumRows == 2) continue;
const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
c20 = hn::MulAdd(a2, b0, c20);
c21 = hn::MulAdd(a2, b1, c21);
c22 = hn::MulAdd(a2, b2, c22);
c23 = hn::MulAdd(a2, b3, c23);
if constexpr (kNumRows == 3) continue;
const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
c30 = hn::MulAdd(a3, b0, c30);
c31 = hn::MulAdd(a3, b1, c31);
c32 = hn::MulAdd(a3, b2, c32);
c33 = hn::MulAdd(a3, b3, c33);
}
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* tile_add = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(
d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_add, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_c, stride_c);
}
template <class VF>
HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
const VF& b01, const VF& b10, const VF& b11,
const VF& b20, const VF& b21, const VF& b30,
const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
c0 = MulAdd(a0, b00, c0);
c0 = MulAdd(a1, b01, c0);
c1 = MulAdd(a0, b10, c1);
c1 = MulAdd(a1, b11, c1);
c2 = MulAdd(a0, b20, c2);
c2 = MulAdd(a1, b21, c2);
c3 = MulAdd(a0, b30, c3);
c3 = MulAdd(a1, b31, c3);
}
// Same as above, but with mixed Mat types: (bf16, compressed).
// Same as above, for when there exists CompressTraits<MatTA>::Decompress2 and
// MatTB is compressed.
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
HWY_IF_BF16(MatTA), typename MatTB, HWY_IF_T_SIZE(MatTB, 1),
typename AddT>
typename MatTB, HWY_IF_T_SIZE(MatTB, 1), typename AddT>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C,
const AddT* add,
const size_t idx_tile,
const size_t xtiles, const size_t stride_a,
const size_t stride_b, const size_t stride_c) {
float* HWY_RESTRICT C, const AddT* add,
const size_t idx_tile, const size_t xtiles,
const size_t stride_a, const size_t stride_b,
const size_t stride_c) {
constexpr size_t kRegRows = 4;
constexpr size_t kRegCols = 4;
static_assert(kNumRows <= kRegRows);
@@ -567,13 +478,8 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
const hn::ScalableTag<float> d32;
const size_t N = Lanes(d32);
const size_t N = hn::Lanes(d32);
using V = hn::Vec<decltype(d32)>;
// TODO: Using half-vectors for now, it might be faster to
// PromoteLower/UpperTo, and more so to PromoteEven/OddTo if we have packed B
// accordingly.
const hn::Rebind<MatTA, decltype(d32)> d16;
HWY_DASSERT(Lanes(d16) == Lanes(d32));
V c00 = hn::Zero(d32);
V c01 = hn::Zero(d32);
@@ -595,67 +501,70 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
V c32 = hn::Zero(d32);
V c33 = hn::Zero(d32);
const MatTA* HWY_RESTRICT tile_a = A + stride_a * row_a;
const size_t tile_a_ofs = stride_a * row_a;
const size_t tile_b_ofs = stride_b * row_b_col_c;
hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
CompressTraits<MatTB>::Decompress(
d32,
/*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
kRegRows * kColsA_RowsB);
const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
// Loop over columns of A and columns of the transposed B, in steps of N.
// Loop over columns of A and columns of the transposed B, in steps of 2*N
// (since we are decoding consecutive bytes at each iteration).
// Accumulates into the c## vectors.
HWY_UNROLL(1)
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
const V b0 = hn::LoadU(d32, tile_b + stride_b * 0 + col_ab);
const V b1 = hn::LoadU(d32, tile_b + stride_b * 1 + col_ab);
const V b2 = hn::LoadU(d32, tile_b + stride_b * 2 + col_ab);
const V b3 = hn::LoadU(d32, tile_b + stride_b * 3 + col_ab);
size_t col_ab = 0;
HWY_UNROLL(1)
for (; col_ab <= kColsA_RowsB - 2 * N; col_ab += 2 * N) {
V b00, b01;
CompressTraits<MatTB>::Decompress2(
d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
V b10, b11;
CompressTraits<MatTB>::Decompress2(
d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
V b20, b21;
CompressTraits<MatTB>::Decompress2(
d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
V b30, b31;
CompressTraits<MatTB>::Decompress2(
d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);
V a00, a01;
CompressTraits<MatTA>::Decompress2(
d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
c02, c03);
const V a0 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 0 + col_ab));
c00 = hn::MulAdd(a0, b0, c00);
c01 = hn::MulAdd(a0, b1, c01);
c02 = hn::MulAdd(a0, b2, c02);
c03 = hn::MulAdd(a0, b3, c03);
if constexpr (kNumRows == 1) continue;
const V a1 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 1 + col_ab));
c10 = hn::MulAdd(a1, b0, c10);
c11 = hn::MulAdd(a1, b1, c11);
c12 = hn::MulAdd(a1, b2, c12);
c13 = hn::MulAdd(a1, b3, c13);
V a10, a11;
CompressTraits<MatTA>::Decompress2(
d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
c12, c13);
if constexpr (kNumRows == 2) continue;
const V a2 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 2 + col_ab));
c20 = hn::MulAdd(a2, b0, c20);
c21 = hn::MulAdd(a2, b1, c21);
c22 = hn::MulAdd(a2, b2, c22);
c23 = hn::MulAdd(a2, b3, c23);
V a20, a21;
CompressTraits<MatTA>::Decompress2(
d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
c22, c23);
if constexpr (kNumRows == 3) continue;
const V a3 =
hn::PromoteTo(d32, hn::LoadU(d16, tile_a + stride_a * 3 + col_ab));
c30 = hn::MulAdd(a3, b0, c30);
c31 = hn::MulAdd(a3, b1, c31);
c32 = hn::MulAdd(a3, b2, c32);
c33 = hn::MulAdd(a3, b3, c33);
V a30, a31;
CompressTraits<MatTA>::Decompress2(
d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
c32, c33);
}
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* tile_add = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_add, tile_c, stride_c);
const AddT* dd = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
c13, c20, c21, c22, c23, c30, c31, c32,
c33, dd, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_c, stride_c);
StoreHorizontalSums<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
tile_c, stride_c);
}
}
@@ -755,10 +664,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* tile_add = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_add, tile_c, stride_c);
const AddT* dd = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
c13, c20, c21, c22, c23, c30, c31, c32,
c33, dd, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -861,10 +770,10 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
if constexpr (kAdd) {
const AddT* tile_add = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
c30, c31, c32, c33, tile_add, tile_c, stride_c);
const AddT* dd = add + row_b_col_c;
StoreHorizontalSumsAdd<kNumRows>(d32, c00, c01, c02, c03, c10, c11, c12,
c13, c20, c21, c22, c23, c30, c31, c32,
c33, dd, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23,
@@ -906,20 +815,20 @@ HWY_NOINLINE void MatMul_4x4_Batch_Add(
HWY_ASSERT(num_rows > 0);
switch (num_rows) {
case 1:
GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(
A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
kStrideA, kStrideB, kStrideC);
break;
case 2:
GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(
A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
kStrideA, kStrideB, kStrideC);
break;
case 3:
GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(
A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
kStrideA, kStrideB, kStrideC);
break;
default:
GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(
A, B, C, add, idx_tile, kTilesX, kStrideA, kStrideB, kStrideC);
GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
kStrideA, kStrideB, kStrideC);
}
});
}
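For context on the dispatch above: num_rows is clamped into the kNumRows template argument, so edge tiles with fewer than four rows prune the missing rows at compile time inside the main loop rather than branching per element. A simplified sketch of that pruning pattern, as used in the tile bodies above:
```
// Sketch: rows beyond kNumRows are eliminated at compile time.
for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += 2 * N) {
  // ... decode B and row 0 of A, accumulate c0x ...
  if constexpr (kNumRows == 1) continue;  // rows 1..3 compiled out
  // ... row 1 ...
  if constexpr (kNumRows == 2) continue;
  // ... row 2, then row 3 likewise ...
}
```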