mirror of https://github.com/google/gemma.cpp.git
Implement float * SfpStream matmul by decompressing 4 * kColsA_RowsB-sized chunks of the second matrix.
PiperOrigin-RevId: 642533996
This commit is contained in:
parent 9c869c4655
commit f467670de7
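SfpStream is the 8-bit compressed weight format declared in compression/sfp.h. Rather than adding a compressed-domain kernel, the new overloads below decompress just the four rows of B that each 4x4 output tile consumes (4 * kColsA_RowsB floats) into a temporary aligned buffer, then run the existing float FMA loop unchanged.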
gemma/ops.h
@@ -27,6 +27,7 @@
 #include <type_traits>  // std::enable_if_t

 #include "compression/sfp.h"
+#include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_targets.h"
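(The new hwy/aligned_allocator.h include supplies hwy::AllocateAligned and hwy::AlignedFreeUniquePtr, used below for the per-tile decompression buffer.)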
@@ -344,6 +345,93 @@ HWY_INLINE void GEMM_4x4_Tile(const MatT* HWY_RESTRICT A,
                       c23, c30, c31, c32, c33, tile_c, stride_c);
 }

+// As above, for SfpStream.
+template <size_t kColsA_RowsB>
+HWY_INLINE void GEMM_4x4_Tile(const float* HWY_RESTRICT A,
+                              const SfpStream* HWY_RESTRICT B,
+                              float* HWY_RESTRICT C, const size_t idx_tile,
+                              const size_t xtiles, const size_t stride_a,
+                              const size_t stride_b, const size_t stride_c) {
+  constexpr size_t kRegRows = 4;
+  constexpr size_t kRegCols = 4;
+
+  // Top-left of tile is (row_a, col_ab) for A, and (row_b_col_c, col_ab) for B.
+  const size_t row_a = idx_tile / xtiles * kRegRows;
+  const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
+
+  const hn::ScalableTag<float> d;
+  const size_t N = Lanes(d);
+  using V = hn::Vec<decltype(d)>;
+
+  V c00 = hn::Zero(d);
+  V c01 = hn::Zero(d);
+  V c02 = hn::Zero(d);
+  V c03 = hn::Zero(d);
+
+  V c10 = hn::Zero(d);
+  V c11 = hn::Zero(d);
+  V c12 = hn::Zero(d);
+  V c13 = hn::Zero(d);
+
+  V c20 = hn::Zero(d);
+  V c21 = hn::Zero(d);
+  V c22 = hn::Zero(d);
+  V c23 = hn::Zero(d);
+
+  V c30 = hn::Zero(d);
+  V c31 = hn::Zero(d);
+  V c32 = hn::Zero(d);
+  V c33 = hn::Zero(d);
+
+  const float* HWY_RESTRICT tile_a = A + stride_a * row_a;
+
+  hwy::AlignedFreeUniquePtr<float[]> tile_b_unique_ptr =
+      hwy::AllocateAligned<float>(kRegRows * kColsA_RowsB);
+  CompressTraits<SfpStream>::Decompress(
+      d,
+      /*in_capacity=*/0, B, stride_b * row_b_col_c, tile_b_unique_ptr.get(),
+      kRegRows * kColsA_RowsB);
+  const float* HWY_RESTRICT tile_b = tile_b_unique_ptr.get();
+
+  // Loop over columns of A and columns of the transposed B, in steps of N.
+  // Accumulates into the c## vectors.
+  HWY_UNROLL(1)
+  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
+    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
+    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
+    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
+    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
+
+    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
+    c00 = hn::MulAdd(a0, b0, c00);
+    c01 = hn::MulAdd(a0, b1, c01);
+    c02 = hn::MulAdd(a0, b2, c02);
+    c03 = hn::MulAdd(a0, b3, c03);
+
+    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
+    c10 = hn::MulAdd(a1, b0, c10);
+    c11 = hn::MulAdd(a1, b1, c11);
+    c12 = hn::MulAdd(a1, b2, c12);
+    c13 = hn::MulAdd(a1, b3, c13);
+
+    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
+    c20 = hn::MulAdd(a2, b0, c20);
+    c21 = hn::MulAdd(a2, b1, c21);
+    c22 = hn::MulAdd(a2, b2, c22);
+    c23 = hn::MulAdd(a2, b3, c23);
+
+    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
+    c30 = hn::MulAdd(a3, b0, c30);
+    c31 = hn::MulAdd(a3, b1, c31);
+    c32 = hn::MulAdd(a3, b2, c32);
+    c33 = hn::MulAdd(a3, b3, c33);
+  }
+
+  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
+  StoreHorizontalSums(d, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22,
+                      c23, c30, c31, c32, c33, tile_c, stride_c);
+}
+
 // Same as above, but with mixed Mat types.
 template <size_t kColsA_RowsB, typename MatTA, HWY_IF_F32(MatTA),
           typename MatTB, HWY_IF_BF16(MatTB)>
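A note on the index arithmetic above: idx_tile enumerates 4x4 output tiles row-major, so idx_tile / xtiles selects the tile-row (scaled by kRegRows) and idx_tile % xtiles the tile-column (scaled by kRegCols). A standalone sketch of the same formulas; the 512x512 shape and the printed values are illustrative assumptions, not values from the commit:

#include <cstddef>
#include <cstdio>

int main() {
  constexpr size_t kRegRows = 4;
  constexpr size_t kRegCols = 4;
  constexpr size_t kColsA_RowsB = 512;      // assumed inner dimension
  constexpr size_t kColsC = 512;            // assumed width of C
  const size_t xtiles = kColsC / kRegCols;  // 128 tiles per tile-row of C

  const size_t tiles[] = {0, 1, 127, 128};
  for (const size_t idx_tile : tiles) {
    // Same formulas as GEMM_4x4_Tile above.
    const size_t row_a = idx_tile / xtiles * kRegRows;
    const size_t row_b_col_c = idx_tile % xtiles * kRegCols;
    std::printf("tile %3zu -> C rows [%3zu..%3zu), cols [%3zu..%3zu)\n",
                idx_tile, row_a, row_a + kRegRows, row_b_col_c,
                row_b_col_c + kRegCols);
  }
  // Per-tile scratch for the decompressed rows of B:
  // 4 * 512 * sizeof(float) = 8 KiB.
  std::printf("tile_b bytes: %zu\n", kRegRows * kColsA_RowsB * sizeof(float));
  return 0;
}

At this size each tile reads 4 rows of A and decompresses 4 rows of B into an 8 KiB float scratch buffer, small enough to stay cache-resident while the FMA loop runs.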
@@ -518,6 +606,18 @@ HWY_INLINE void MatMulSlow(const MatTA* HWY_RESTRICT a,
   }
 }

+template <size_t kM, size_t kN, size_t kK>
+HWY_INLINE void MatMulSlow(const float* HWY_RESTRICT a,
+                           const SfpStream* HWY_RESTRICT b_sfp_stream,
+                           float* HWY_RESTRICT out) {
+  const hn::ScalableTag<float> d;
+  hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
+  CompressTraits<SfpStream>::Decompress(d,
+                                        /*in_capacity=*/0, b_sfp_stream, 0,
+                                        b.get(), kK * kN);
+  MatMulSlow<kM, kN, kK>(a, b.get(), out);
+}
+
 HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                              const size_t size, float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> df;
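Unlike the tiled kernel above, which decompresses only the 4 * kColsA_RowsB values each tile needs, this overload decompresses all of B up front before delegating to the float MatMulSlow: for kK = kN = 512 that is 512 * 512 * 4 bytes = 1 MiB of floats, which is fine for MatMulSlow's role as a correctness baseline rather than a production path. The remaining hunks are in the ops test, which exercises these new overloads.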
@@ -293,8 +293,8 @@ struct TestSoftmax {
     for (size_t i = 0; i < count; ++i) {
       sum += x[i];
       double rel = std::abs(x[i] - e[i]) / e[i];
-      ASSERT_LT(rel, 1e-6)
-          << "Mismatch on coordinate " << i << " out of " << count;
+      ASSERT_LT(rel, 1e-6) << "Mismatch on coordinate " << i << " out of "
+                           << count;
     }
     ASSERT_NEAR(sum, 1.0, 1e-6);
   }
@@ -388,7 +388,7 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
       new CompressedArray<MatT, kOuter * kInner>);
   hwy::AlignedFreeUniquePtr<float[]> content =
       hwy::AllocateAligned<float>(kOuter * kInner);
-  const float scale = 1.0f / kInner;
+  const float scale = 1.875f / (kInner * kOuter + offset);
   pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
     for (size_t j = 0; j < kInner; j++) {
       content[i * kInner + j] =
@@ -411,7 +411,7 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
       new CompressedArray<MatT, kOuter * kInner>);
   hwy::AlignedFreeUniquePtr<float[]> content =
       hwy::AllocateAligned<float>(kOuter * kInner);
-  const float scale = 1.0f / kInner;
+  const float scale = 1.875f / (kInner * kOuter + offset);
   pool.Run(0, kOuter, [&](const size_t i, size_t /*thread*/) {
     for (size_t j = 0; j < kInner; j++) {
       content[j * kOuter + i] =
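The identical change in this and the preceding generator replaces scale = 1.0f / kInner with 1.875f / (kInner * kOuter + offset). Assuming the truncated initializers multiply the flattened element index (at most kInner * kOuter + offset) by scale, every generated value is now bounded by roughly 1.875 in magnitude. A plausible reading, inferred from the constant rather than stated in the commit: 1.875 = 1 + 1/2 + 1/4 + 1/8 is the largest mantissa an 8-bit SFP value can carry, so the tighter scale keeps test inputs inside SfpStream's representable range instead of clipping during compression.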
@@ -554,17 +554,17 @@ void TestAllTiledMatMul() {
   TestTiledMatMul<512, 512, 512, float>();
   TestTiledMatMul<512, 512, 512, hwy::bfloat16_t>();
   TestTiledMatMul<512, 512, 512, float, hwy::bfloat16_t>();
+  TestTiledMatMul<512, 512, 512, float, SfpStream>();

   // minimal non-square test
   TestTiledMatMul<4, 128, 4, float>();
   TestTiledMatMul<4, 128, 4, hwy::bfloat16_t>();
   TestTiledMatMul<4, 128, 4, float, hwy::bfloat16_t>();
+  TestTiledMatMul<32, 128, 32, float, SfpStream>();

   // large-scale test
   // TODO(philculliton): investigate rounding issues with large matrices
   TestTiledMatMul<512, 24576, 3072, float>();
-
-  // TODO(janwas): SFP
 }

 void TestMatVecAdd() {
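The two new instantiations cover the float * SfpStream path at 512x512x512 and at 32x128x32, and the removed TODO(janwas) reflects that SFP now has tiled-matmul coverage; note there is still no large-scale SFP case alongside the 512x24576x3072 float test. The SFP variant of the minimal test uses 32x128x32 rather than 4x128x4, presumably (not stated here) because SfpStream decompression operates on coarser granules than a 4-element tile edge.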
@@ -666,6 +666,7 @@ HWY_AFTER_NAMESPACE();

 namespace gcpp {
 HWY_BEFORE_TEST(OpsTest);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllAddFrom);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulBy);
 HWY_EXPORT_AND_TEST_P(OpsTest, TestAllMulByConst);