mirror of https://github.com/google/gemma.cpp.git
0.98x prefill: refactor in prep for cache blocking.

Slower because we now init tiles of C and accumulate into them. Also remove
an unused var in optimize_test and use the BF16 typedef.

PiperOrigin-RevId: 662115916
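The core change: the kernel no longer writes a C tile in one shot
(StoreHorizontalSums computed `C = scale * sums [+ add]`). Instead, InitC first
fills the tile with 0 or the `add` vector, and AddHorizontalSums accumulates
with `+=`. The kernel can then be invoked once per block of the shared
dimension and the results compose, which is the precondition for cache
blocking; the extra loads/stores of the C tile explain the 0.98x. A minimal
scalar sketch of the pattern (hypothetical names, not the Highway kernel
itself):

#include <cstddef>

// Initialize a 4x4 tile of C to the bias (or zero), once per tile.
void InitTile(float* c, size_t stride_c, const float* add /* may be null */) {
  for (size_t r = 0; r < 4; ++r) {
    for (size_t col = 0; col < 4; ++col) {
      c[r * stride_c + col] = add ? add[col] : 0.0f;
    }
  }
}

// Accumulate one K-block into the tile; callable repeatedly per tile.
// B is stored transposed, so row `col` of B holds column `col` of B^T.
void AccumulateBlock(const float* a, size_t stride_a, const float* b,
                     size_t stride_b, size_t k0, size_t k1, float scale,
                     float* c, size_t stride_c) {
  for (size_t r = 0; r < 4; ++r) {
    for (size_t col = 0; col < 4; ++col) {
      float sum = 0.0f;  // partial dot product over [k0, k1) only
      for (size_t k = k0; k < k1; ++k) {
        sum += a[r * stride_a + k] * b[col * stride_b + k];
      }
      c[r * stride_c + col] += scale * sum;  // accumulate, do not overwrite
    }
  }
}

With the old store semantics, a second call per tile would overwrite the first
block's contribution; with init + accumulate, splitting the k loop into
cache-sized blocks changes nothing but the traversal order.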
parent 7316ee8f96
commit 8e028632f7
@@ -15,7 +15,6 @@
 #include <stddef.h>
 
 #include <limits>
 #include <random>
 #include <vector>
@@ -112,7 +111,6 @@ TEST(OptimizeTest, GradientDescent) {
   ReverseSequenceSampler training_task({
       0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1});
   size_t steps = 0;
   float prev_loss = std::numeric_limits<float>::max();
-  size_t num_ok;
   for (; steps < 1000000; ++steps) {
     std::mt19937 sgen(42);
@@ -143,7 +141,6 @@ TEST(OptimizeTest, GradientDescent) {
     if (total_loss < 0.5f) {
       break;
     }
     prev_loss = total_loss;
   }
   printf("Num steps: %zu\n", steps);
   printf("Final weights:\n");
@@ -41,8 +41,10 @@
 namespace gcpp {
 
+using BF16 = hwy::bfloat16_t;
+
 static inline const char* TypeName(float) { return "f32"; }
-static inline const char* TypeName(hwy::bfloat16_t) { return "b16"; }
+static inline const char* TypeName(BF16) { return "b16"; }
 
 namespace detail {
 // How many MatT are required to store `capacity` weights. For all but

@@ -177,11 +179,11 @@ struct CompressWorkingSet {
 template <typename MatT>
 hwy::uint128_t CacheKey(const char* name) {
   // Already used/retired: s, S, n, 1
-  const char prefix = hwy::IsSame<MatT, float>() ? 'F'
-                      : hwy::IsSame<MatT, hwy::bfloat16_t>() ? 'B'
-                      : hwy::IsSame<MatT, SfpStream>() ? '$'
-                      : hwy::IsSame<MatT, NuqStream>() ? '2'
-                      : '?';
+  const char prefix = hwy::IsSame<MatT, float>() ? 'F'
+                      : hwy::IsSame<MatT, BF16>() ? 'B'
+                      : hwy::IsSame<MatT, SfpStream>() ? '$'
+                      : hwy::IsSame<MatT, NuqStream>() ? '2'
+                      : '?';
 
   return MakeKey((std::string(1, prefix) + name).c_str());
 }
ops/matmul-inl.h (520 changed lines)
@@ -23,7 +23,7 @@
 
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
-#include "hwy/profiler.h"
+#include "hwy/profiler.h"  // temporarily disabled
 
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
@@ -43,107 +43,65 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
+// A square kernel minimizes the ratio of loads to FMA. 4x 128-bit corresponds
+// to one cache line.
+constexpr size_t kRegRows = 4;
+constexpr size_t kRegCols = 4;
+
+// Initializes a reg-tile of C: if kAdd, `add[add_ofs + c]`; otherwise 0.
+// `add` has no scale, and if `kAdd`, is a row vector with A.cols entries,
+// otherwise nullptr. In the latter case, adding `add_ofs` to it would be UB,
+// hence we pass it as a separate argument.
+template <size_t kNumRows, bool kAdd>
+HWY_INLINE void InitC(const float* HWY_RESTRICT add, size_t add_ofs,
+                      float* HWY_RESTRICT pos_c, size_t stride_c) {
+  for (size_t r = 0; r < HWY_MIN(kNumRows, kRegRows); ++r) {
+    for (size_t c = 0; c < kRegCols; ++c) {
+      if constexpr (kAdd) {
+        pos_c[r * stride_c + c] = add[add_ofs + c];
+      } else {
+        pos_c[r * stride_c + c] = 0.0f;
+      }
+    }
+  }
+}
+
 // c## are partial sums of the products of A and B; their horizontal sums are
 // the final matmul result, stored in C, which is always f32.
 template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSums(DF df,  //
-                                    VF c00, VF c01, VF c02, VF c03,  //
-                                    VF c10, VF c11, VF c12, VF c13,  //
-                                    VF c20, VF c21, VF c22, VF c23,  //
-                                    VF c30, VF c31, VF c32, VF c33,  //
-                                    float scale, float* HWY_RESTRICT tile_c,
-                                    size_t stride_c) {
+HWY_INLINE void AddHorizontalSums(DF df, float scale,  //
+                                  VF c00, VF c01, VF c02, VF c03,  //
+                                  VF c10, VF c11, VF c12, VF c13,  //
+                                  VF c20, VF c21, VF c22, VF c23,  //
+                                  VF c30, VF c31, VF c32, VF c33,  //
+                                  float* HWY_RESTRICT tile_c, size_t stride_c) {
   // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
   // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
   // expensive, but only a fraction of the A.cols/N FMAs.
-  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00);
-  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01);
-  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02);
-  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03);
+  tile_c[stride_c * 0 + 0] += scale * hn::ReduceSum(df, c00);
+  tile_c[stride_c * 0 + 1] += scale * hn::ReduceSum(df, c01);
+  tile_c[stride_c * 0 + 2] += scale * hn::ReduceSum(df, c02);
+  tile_c[stride_c * 0 + 3] += scale * hn::ReduceSum(df, c03);
   if (kNumRows == 1) return;
 
-  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10);
-  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11);
-  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12);
-  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13);
+  tile_c[stride_c * 1 + 0] += scale * hn::ReduceSum(df, c10);
+  tile_c[stride_c * 1 + 1] += scale * hn::ReduceSum(df, c11);
+  tile_c[stride_c * 1 + 2] += scale * hn::ReduceSum(df, c12);
+  tile_c[stride_c * 1 + 3] += scale * hn::ReduceSum(df, c13);
   if (kNumRows == 2) return;
 
-  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20);
-  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21);
-  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22);
-  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23);
+  tile_c[stride_c * 2 + 0] += scale * hn::ReduceSum(df, c20);
+  tile_c[stride_c * 2 + 1] += scale * hn::ReduceSum(df, c21);
+  tile_c[stride_c * 2 + 2] += scale * hn::ReduceSum(df, c22);
+  tile_c[stride_c * 2 + 3] += scale * hn::ReduceSum(df, c23);
   if (kNumRows == 3) return;
 
-  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30);
-  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31);
-  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32);
-  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33);
+  tile_c[stride_c * 3 + 0] += scale * hn::ReduceSum(df, c30);
+  tile_c[stride_c * 3 + 1] += scale * hn::ReduceSum(df, c31);
+  tile_c[stride_c * 3 + 2] += scale * hn::ReduceSum(df, c32);
+  tile_c[stride_c * 3 + 3] += scale * hn::ReduceSum(df, c33);
 }
 
-// As above, but also adds `add[0..3]` to columns 0..3 of `tile_c`. `add` has no
-// scale, and points to a 1D slice of the row vector.
-template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
-                                       VF c00, VF c01, VF c02, VF c03,  //
-                                       VF c10, VF c11, VF c12, VF c13,  //
-                                       VF c20, VF c21, VF c22, VF c23,  //
-                                       VF c30, VF c31, VF c32, VF c33,
-                                       const float scale,
-                                       const float* HWY_RESTRICT add,
-                                       float* HWY_RESTRICT tile_c,
-                                       size_t stride_c) {
-  // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
-  // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
-  // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
-  // expensive, but only a fraction of the A.cols/N FMAs.
-  const float add0 = add[0];
-  // TODO: 4x4 transpose, then 128-bit vector FMA?
-  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0;
-  const float add1 = add[1];
-  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + add1;
-  const float add2 = add[2];
-  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + add2;
-  const float add3 = add[3];
-  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + add3;
-  if (kNumRows == 1) return;
-
-  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + add0;
-  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + add1;
-  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + add2;
-  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + add3;
-  if (kNumRows == 2) return;
-
-  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + add0;
-  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + add1;
-  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + add2;
-  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + add3;
-  if (kNumRows == 3) return;
-
-  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + add0;
-  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + add1;
-  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + add2;
-  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + add3;
-}
-
-// Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
-// sites. If `!kAdd`, `add` is nullptr, so adding `add_offset` to it would be
-// UB, hence we pass it as a separate argument.
-template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSumsMaybeAdd(
-    DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
-    VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
-    const float scale, const float* HWY_RESTRICT add, size_t add_offset,
-    float* HWY_RESTRICT tile_c, size_t stride_c) {
-  if constexpr (kAdd) {
-    StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
-                                     c20, c21, c22, c23, c30, c31, c32, c33,
-                                     scale, add + add_offset, tile_c, stride_c);
-  } else {
-    StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
-                                  c20, c21, c22, c23, c30, c31, c32, c33,
-                                  scale, tile_c, stride_c);
-  }
-}
 
 // Wrapper to simplify call sites. T can be const or non-const.
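A note on the cost comment in AddHorizontalSums: with 16 f32 lanes (e.g.
AVX-512) and A.cols = 2048, each of the 16 c## vectors absorbs 2048 / 16 = 128
FMAs before its single ReduceSum, so the 16 horizontal reductions per tile
remain a small fraction of the work.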
@@ -176,104 +134,8 @@ Mat<T> MakeMat(T* HWY_RESTRICT ptr, size_t cols) {
   return MakeMat(ptr, cols, cols);
 }
 
-#undef GEMMA_NATIVE_BF16
-#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
-                defined(HWY_TARGET_TOGGLE))
-#define GEMMA_NATIVE_BF16 1
-#else
-#define GEMMA_NATIVE_BF16 0
-#endif
-
-#if GEMMA_NATIVE_BF16
-
-// Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
-template <size_t kNumRows, bool kAdd>
-HWY_INLINE void MatMulTile(const Mat<const hwy::bfloat16_t>& A,
-                           const Mat<const hwy::bfloat16_t>& B,
-                           const size_t row_a, const size_t row_b_col_c,
-                           const float scale, const float* HWY_RESTRICT add,
-                           const Mat<float>& C) {
-  const hn::ScalableTag<float> df;
-  using VF = hn::Vec<decltype(df)>;
-  // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
-  // bf16 vectors.
-  const hn::Repartition<hwy::bfloat16_t, decltype(df)> d;
-  const size_t N = Lanes(d);
-  VF unused_sum1 = hn::Zero(df);
-  VF c00 = hn::Zero(df);
-  VF c01 = hn::Zero(df);
-  VF c02 = hn::Zero(df);
-  VF c03 = hn::Zero(df);
-
-  VF c10 = hn::Zero(df);
-  VF c11 = hn::Zero(df);
-  VF c12 = hn::Zero(df);
-  VF c13 = hn::Zero(df);
-
-  VF c20 = hn::Zero(df);
-  VF c21 = hn::Zero(df);
-  VF c22 = hn::Zero(df);
-  VF c23 = hn::Zero(df);
-
-  VF c30 = hn::Zero(df);
-  VF c31 = hn::Zero(df);
-  VF c32 = hn::Zero(df);
-  VF c33 = hn::Zero(df);
-
-  const hwy::bfloat16_t* HWY_RESTRICT A_tile = A.ptr + A.Row(row_a);
-  const hwy::bfloat16_t* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);
-
-  // Loop over columns of A and columns of the transposed B, in steps of N.
-  // Accumulates into the c## vectors.
-  HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < A.cols; col_ab += N) {
-    using V = hn::Vec<decltype(d)>;
-    const V b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
-    const V b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
-    const V b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
-    const V b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
-
-    const V a0 = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
-    c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1);
-    c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
-    c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
-    c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
-    if constexpr (kNumRows == 1) continue;
-
-    const V a1 = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
-    c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
-    c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
-    c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
-    c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
-    if constexpr (kNumRows == 2) continue;
-
-    const V a2 = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
-    c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
-    c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
-    c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
-    c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
-    if constexpr (kNumRows == 3) continue;
-
-    const V a3 = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
-    c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
-    c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
-    c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
-    c33 = hn::ReorderWidenMulAccumulate(df, a3, b3, c33, unused_sum1);
-  }
-
-  // Ensure sum1 was indeed unused.
-  HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
-
-  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
-  StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
-      df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
-}
-
-#endif  // GEMMA_NATIVE_BF16
-
-// The col_ab loop is unrolled 2x, so we have two consecutive a0/a1 and b00/b01
-// etc. Multiplies a[c] with b[r,c] and adds to c[r].
+// Inner loop of the kernel, called once per kRegRows. c[r] += a[c] * b[r,c].
+// The col_ab loop is unrolled 2x, so we have a0/a1 and b00/b01 etc.
 template <class VF>
 HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
                               const VF& b01, const VF& b10, const VF& b11,
@@ -289,12 +151,153 @@ HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
   c3 = hn::MulAdd(a1, b31, c3);
 }
 
+// Special case for the first iteration: c## are zero, so skip the first add.
+template <class VF>
+HWY_INLINE void FirstTileRow(const VF& a0, const VF& a1, const VF& b00,
+                             const VF& b01, const VF& b10, const VF& b11,
+                             const VF& b20, const VF& b21, const VF& b30,
+                             const VF& b31, VF& c0, VF& c1, VF& c2, VF& c3) {
+  c0 = hn::Mul(a0, b00);
+  c1 = hn::Mul(a0, b10);
+  c2 = hn::Mul(a0, b20);
+  c3 = hn::Mul(a0, b30);
+  c0 = hn::MulAdd(a1, b01, c0);
+  c1 = hn::MulAdd(a1, b11, c1);
+  c2 = hn::MulAdd(a1, b21, c2);
+  c3 = hn::MulAdd(a1, b31, c3);
+}
+
+#undef GEMMA_NATIVE_BF16
+#if HWY_IDE || (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
+                defined(HWY_TARGET_TOGGLE))
+#define GEMMA_NATIVE_BF16 1
+#else
+#define GEMMA_NATIVE_BF16 0
+#endif
+
+#if GEMMA_NATIVE_BF16
+
+// Specializations for f32 += bf16 * bf16 that avoid promoting to f32.
+
+// Inner loop as above, but not unrolled. c[r] += a * b[r].
+template <class DF, class VF = hn::Vec<DF>,
+          class VBF16 = hn::Vec<hn::Repartition<BF16, DF>>>
+HWY_INLINE void UpdateTileRow(DF df, const VBF16& a, const VBF16& b0,
+                              const VBF16& b1, const VBF16& b2, const VBF16& b3,
+                              VF& c0, VF& c1, VF& c2, VF& c3) {
+  VF unused_sum1 = hn::Zero(df);
+  c0 = hn::ReorderWidenMulAccumulate(df, a, b0, c0, unused_sum1);
+  c1 = hn::ReorderWidenMulAccumulate(df, a, b1, c1, unused_sum1);
+  c2 = hn::ReorderWidenMulAccumulate(df, a, b2, c2, unused_sum1);
+  c3 = hn::ReorderWidenMulAccumulate(df, a, b3, c3, unused_sum1);
+
+  // Ensure sum1 was indeed unused.
+  HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
+}
+
+// Special case for the first iteration: c## are zero, so skip the first add.
+template <class DF, class VF = hn::Vec<DF>,
+          class VBF16 = hn::Vec<hn::Repartition<BF16, DF>>>
+HWY_INLINE void FirstTileRow(DF df, const VBF16& a, const VBF16& b0,
+                             const VBF16& b1, const VBF16& b2, const VBF16& b3,
+                             VF& c0, VF& c1, VF& c2, VF& c3) {
+  c0 = hn::WidenMulPairwiseAdd(df, a, b0);
+  c1 = hn::WidenMulPairwiseAdd(df, a, b1);
+  c2 = hn::WidenMulPairwiseAdd(df, a, b2);
+  c3 = hn::WidenMulPairwiseAdd(df, a, b3);
+}
+
+template <size_t kNumRows, bool kAdd>
+HWY_INLINE void MatMulTile(const Mat<const BF16>& A, const Mat<const BF16>& B,
+                           const size_t row_ac, const size_t row_b_col_c,
+                           const float scale, const float* HWY_RESTRICT add,
+                           const Mat<float>& C) {
+  const hn::ScalableTag<float> df;
+  using VF = hn::Vec<decltype(df)>;
+  // ReorderWidenMulAccumulate does not use its sum1 arg and we can use full
+  // bf16 vectors.
+  const hn::Repartition<BF16, decltype(df)> d;
+  const size_t N = Lanes(d);
+  using V = hn::Vec<decltype(d)>;
+  V b0, b1, b2, b3;  // one from each row
+  VF c00, c01, c02, c03;
+  VF c10, c11, c12, c13;
+  VF c20, c21, c22, c23;
+  VF c30, c31, c32, c33;
+
+  const BF16* HWY_RESTRICT A_tile = A.ptr + A.Row(row_ac);
+  const BF16* HWY_RESTRICT B_tile = B.ptr + B.Row(row_b_col_c);
+  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
+  InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
+
+  size_t col_ab = 0;
+
+  // First iteration initializes the c## vectors.
+  {
+    b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
+    b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
+    b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
+    b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
+
+    {
+      const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
+      FirstTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03);
+    }
+    if constexpr (kNumRows > 1) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
+      FirstTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13);
+    }
+    if constexpr (kNumRows > 2) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
+      FirstTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23);
+    }
+    if constexpr (kNumRows > 3) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
+      FirstTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33);
+    }
+  }
+
+  // Loop over columns of A and columns of the transposed B, in steps of N.
+  // Accumulates into the c## vectors.
+  HWY_UNROLL(1)
+  for (col_ab += N; col_ab < A.cols; col_ab += N) {
+    b0 = hn::LoadU(d, B_tile + B.stride * 0 + col_ab);
+    b1 = hn::LoadU(d, B_tile + B.stride * 1 + col_ab);
+    b2 = hn::LoadU(d, B_tile + B.stride * 2 + col_ab);
+    b3 = hn::LoadU(d, B_tile + B.stride * 3 + col_ab);
+
+    {
+      const V a = hn::LoadU(d, A_tile + A.stride * 0 + col_ab);
+      UpdateTileRow(df, a, b0, b1, b2, b3, c00, c01, c02, c03);
+    }
+    if constexpr (kNumRows > 1) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 1 + col_ab);
+      UpdateTileRow(df, a, b0, b1, b2, b3, c10, c11, c12, c13);
+    }
+    if constexpr (kNumRows > 2) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 2 + col_ab);
+      UpdateTileRow(df, a, b0, b1, b2, b3, c20, c21, c22, c23);
+    }
+    if constexpr (kNumRows > 3) {
+      const V a = hn::LoadU(d, A_tile + A.stride * 3 + col_ab);
+      UpdateTileRow(df, a, b0, b1, b2, b3, c30, c31, c32, c33);
+    }
+  }
+
+  AddHorizontalSums<kNumRows>(df, scale, c00, c01, c02, c03, c10, c11, c12, c13,
+                              c20, c21, c22, c23, c30, c31, c32, c33, C_tile,
+                              C.stride);
+}
+
+#endif  // GEMMA_NATIVE_BF16
+
 // Streams a `(kNumRows, 4)` strip of `A` and the transposed `B`, then writes a
 // finished tile of `C`.
+// General case: uses CompressTraits to load from A and B.
 template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
 HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
-                           const size_t row_a, const size_t row_b_col_c,
+                           const size_t row_ac, const size_t row_b_col_c,
                            const float scale, const float* HWY_RESTRICT add,
                            const Mat<float>& C) {
   using TraitsA = CompressTraits<hwy::RemoveConst<MatTA>>;
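For readers unfamiliar with the widening ops used above: ReorderWidenMulAccumulate
and WidenMulPairwiseAdd multiply adjacent pairs of bf16 lanes and sum each pair
into one f32 lane; the lane order is unspecified, which is harmless here because
every c## is later collapsed by a horizontal ReduceSum. A scalar model of one
update, assuming the target leaves sum1 untouched as the HWY_DASSERT above
verifies (f32 arrays stand in for bf16 lanes):

#include <cstddef>

// Model of c = ReorderWidenMulAccumulate(df, a, b, c, unused_sum1):
// f32 lane i accumulates the products of bf16 lane pair (2i, 2i+1).
void WidenMulAccModel(const float* a, const float* b, float* c, size_t nf) {
  for (size_t i = 0; i < nf; ++i) {
    c[i] += a[2 * i] * b[2 * i] + a[2 * i + 1] * b[2 * i + 1];
  }
}

Summed over all lanes, the reduction still yields the exact dot product,
whatever the pairing order.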
@@ -303,74 +306,92 @@ HWY_INLINE void MatMulTile(const Mat<MatTA>& A, const Mat<MatTB>& B,
   const hn::ScalableTag<float> d32;
   const size_t N = hn::Lanes(d32);
   using V = hn::Vec<decltype(d32)>;
-  V c00 = hn::Zero(d32);
-  V c01 = hn::Zero(d32);
-  V c02 = hn::Zero(d32);
-  V c03 = hn::Zero(d32);
-
-  V c10 = hn::Zero(d32);
-  V c11 = hn::Zero(d32);
-  V c12 = hn::Zero(d32);
-  V c13 = hn::Zero(d32);
-
-  V c20 = hn::Zero(d32);
-  V c21 = hn::Zero(d32);
-  V c22 = hn::Zero(d32);
-  V c23 = hn::Zero(d32);
-
-  V c30 = hn::Zero(d32);
-  V c31 = hn::Zero(d32);
-  V c32 = hn::Zero(d32);
-  V c33 = hn::Zero(d32);
-
-  const size_t A_ofs = A.Row(row_a);
+  V b00, b01, b10, b11, b20, b21, b30, b31;  // two from each row
+  V c00, c01, c02, c03;
+  V c10, c11, c12, c13;
+  V c20, c21, c22, c23;
+  V c30, c31, c32, c33;
+
+  const size_t A_ofs = A.Row(row_ac);
   const size_t B_ofs = B.Row(row_b_col_c);
+  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_ac) + row_b_col_c;
+  InitC<kNumRows, kAdd>(add, row_b_col_c, C_tile, C.stride);
 
   // Loop over columns of A and columns of the transposed B, in steps of 2*N
   // (since we are decoding consecutive bytes at each iteration).
-  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c,
-  // col_ab) for B. Accumulates into the c## vectors.
+  // Top-left of tile is (row_ac, col_ab) for A, and (row_b_col_c,
+  // col_ab) for B. First iteration initializes the c## vectors.
   size_t col_ab = 0;
 
-  HWY_UNROLL(1)
-  for (; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
-    V b00, b01;
-    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
-    V b10, b11;
-    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
-    V b20, b21;
-    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
-    V b30, b31;
-    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
-
-    V a00, a01;
-    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a00, a01);
-    UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
-                  c02, c03);
-    if constexpr (kNumRows == 1) continue;
-
-    V a10, a11;
-    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a10, a11);
-    UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
-                  c12, c13);
-    if constexpr (kNumRows == 2) continue;
-
-    V a20, a21;
-    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a20, a21);
-    UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
-                  c22, c23);
-    if constexpr (kNumRows == 3) continue;
-
-    V a30, a31;
-    TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a30, a31);
-    UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
-                  c32, c33);
+  {
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
+
+    {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1);
+      FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
+                   c02, c03);
+    }
+    if constexpr (kNumRows > 1) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1);
+      FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
+                   c12, c13);
+    }
+    if constexpr (kNumRows > 2) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1);
+      FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
+                   c22, c23);
+    }
+    if constexpr (kNumRows > 3) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1);
+      FirstTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
+                   c32, c33);
+    }
   }
 
-  float* HWY_RESTRICT C_tile = C.ptr + C.Row(row_a) + row_b_col_c;
-  StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
-      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, scale, add, row_b_col_c, C_tile, C.stride);
+  // Main loop: accumulates into the c## vectors.
+  HWY_UNROLL(1)
+  for (col_ab += 2 * N; col_ab <= A.cols - 2 * N; col_ab += 2 * N) {
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 0 + col_ab, b00, b01);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 1 + col_ab, b10, b11);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 2 + col_ab, b20, b21);
+    TraitsB::Decompress2(d32, B.ptr, B_ofs + B.stride * 3 + col_ab, b30, b31);
+
+    {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 0 + col_ab, a0, a1);
+      UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
+                    c02, c03);
+    }
+    if constexpr (kNumRows > 1) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 1 + col_ab, a0, a1);
+      UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
+                    c12, c13);
+    }
+    if constexpr (kNumRows > 2) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 2 + col_ab, a0, a1);
+      UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
+                    c22, c23);
+    }
+    if constexpr (kNumRows > 3) {
+      V a0, a1;
+      TraitsA::Decompress2(d32, A.ptr, A_ofs + A.stride * 3 + col_ab, a0, a1);
+      UpdateTileRow(a0, a1, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
+                    c32, c33);
+    }
+  }
+
+  AddHorizontalSums<kNumRows>(d32, scale, c00, c01, c02, c03, c10, c11, c12,
+                              c13, c20, c21, c22, c23, c30, c31, c32, c33,
+                              C_tile, C.stride);
 }
 
 // Computes the matrix product `A * B * scale [+ add]` and stores it in `C`.
@@ -395,10 +416,7 @@ HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
                              const Mat<MatTB>& B, const float scale,
                              const float* HWY_RESTRICT add, const Mat<float>& C,
                              hwy::ThreadPool& pool) {
-  PROFILER_ZONE("Matmul");
-  constexpr size_t kRegRows = 4;  // if changing, also update the switch below.
-  constexpr size_t kRegCols = 4;
-
+  // PROFILER_ZONE("Matmul");
   HWY_DASSERT(A.NotEmpty() && B.NotEmpty() && C.NotEmpty());
   HWY_DASSERT(A.cols == B.cols);
 
@@ -417,24 +435,24 @@ HWY_NOINLINE void MatMul_4x4(const size_t batch_size, const Mat<MatTA>& A,
       [&](const uint64_t idx_tile, size_t /*thread*/) HWY_ATTR {
         const size_t tx = idx_tile % tilesX;
         const size_t ty = idx_tile / tilesX;
-        const size_t row_a = ty * kRegRows;
+        const size_t row_ac = ty * kRegRows;
         const size_t row_b_col_c = tx * kRegCols;
         // How many rows of C are left to compute. If more than 4, this
         // tile still only computes 4 rows.
-        const size_t num_rows = batch_size - row_a;
+        const size_t num_rows = batch_size - row_ac;
         HWY_DASSERT(num_rows != 0);
         switch (num_rows) {
           case 1:
-            MatMulTile<1, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
+            MatMulTile<1, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
            break;
           case 2:
-            MatMulTile<2, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
+            MatMulTile<2, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
            break;
           case 3:
-            MatMulTile<3, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
+            MatMulTile<3, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
            break;
           default:
-            MatMulTile<4, kAdd>(A, B, row_a, row_b_col_c, scale, add, C);
+            MatMulTile<4, kAdd>(A, B, row_ac, row_b_col_c, scale, add, C);
         }
       });
 }
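For orientation, the semantics the kernel implements, written as a plain scalar
reference (a sketch inferred from the comments above: B stored row-major
transposed, C always f32, `add` optional):

#include <cstddef>

// C (rows x n) = scale * A (rows x k) * B^T + add, with B stored (n x k).
// C[r][c] is the dot product of row r of A and row c of B.
void ReferenceMatMul(size_t rows, size_t k, size_t n, const float* a,
                     const float* b, float scale, const float* add, float* c) {
  for (size_t r = 0; r < rows; ++r) {
    for (size_t col = 0; col < n; ++col) {
      float dot = 0.0f;
      for (size_t i = 0; i < k; ++i) {
        dot += a[r * k + i] * b[col * k + i];
      }
      c[r * n + col] = scale * dot + (add ? add[col] : 0.0f);
    }
  }
}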
@@ -269,7 +269,6 @@ void TestAllMatMul() {
   }
 
   hwy::ThreadPool pool(4);
-  using BF16 = hwy::bfloat16_t;
   using F32 = float;
   using SFP = SfpStream;