mirror of https://github.com/google/gemma.cpp.git
Add offset arg to MatMul, rename; MatMul for logits = ~1.1x decode speedup
PiperOrigin-RevId: 657167257
This commit is contained in:
parent aaf51898b6
commit 2721f54446
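The gist of the API change, condensed from the diff below for orientation (a paraphrase, not itself part of the commit):

    // Before: two entry points; sub-matrices selectable only via pointer arithmetic.
    //   MatMul_4x4_Batch<kColsA_RowsB, kColsBC>(batch_size, A, B, scale, C, pool);
    //   MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, kAdd>(batch_size, A, B, scale, C, add, pool);
    // After: one entry point with element offsets into A and B; add may be
    // nullptr when kAdd is false.
    //   MatMul_4x4<kColsA_RowsB, kColsBC, kAdd>(batch_size, A, A_ofs, B, B_ofs, scale, C, add, pool);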
@@ -238,9 +238,10 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
   // Compute Q only or QKV (if MHA).
+  // If MHA, this also computes KV, which we copy to the KV cache below.
   const float scale = layer_weights->qkv_einsum_w.scale();
-  MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
-      num_interleaved, activations.pre_att_rms_out.All(),
-      layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);
+  MatMul_4x4<kModelDim, kHeads * kQStride, /*kAdd=*/false>(
+      num_interleaved, activations.pre_att_rms_out.All(), 0,
+      layer_weights->qkv_einsum_w.data(), 0, scale, activations.q.All(),
+      /*add=*/nullptr, pool);
 
   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
@@ -256,7 +257,7 @@ HWY_NOINLINE void GemmaAttention(size_t interleaved_start, size_t num_tokens,
         cache_pos * kCachePosSize + layer * kCacheLayerSize;
     float* HWY_RESTRICT kv = kv_cache.kv_cache.get() + kv_offset;
     // KV structure is [k, v, k, v, ....] = kKVHeads pairs of (k, v).
-    // TODO: requires MatMul support for offsets.
+    // TODO: requires batched KVCache support.
     MatVec<kKVHeads * 2 * kQKVDim, kModelDim>(
         layer_weights->qkv_einsum_w, kHeads * kQKVDim * kModelDim, x,
         activations.even_odd.All(), kv, pool);
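An aside on the MatVec offset above (a gloss, with the weight layout inferred from the kHeads * kQKVDim * kModelDim expression, so treat it as an assumption):

    // qkv_einsum_w appears to store the Q rows first, then the K/V rows, each
    // row of length kModelDim. The offset skips the kHeads * kQKVDim Q rows:
    //   mat_ofs = kHeads * kQKVDim * kModelDim  // first element of the KV rows
    // so this MatVec reads only the kKVHeads * 2 * kQKVDim KV rows. Replacing
    // it with the batched MatMul awaits the KVCache support the TODO names.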
@@ -431,7 +432,6 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
   const auto A = activations.bf_pre_ffw_rms_out.All();
   const float scale = layer_weights->gating_einsum_w.scale();
   const auto B1 = layer_weights->gating_einsum_w.data();
-  const auto B2 = B1 + kColsA * kColsB;
   auto C1 = activations.C1.All();
   auto C2 = activations.C2.All();
   constexpr bool kAddBias = TConfig::kFFBiases;
@@ -444,21 +444,23 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_interleaved,
     output_bias = layer_weights->ffw_output_biases.data_scale1();
   }
 
+  const size_t A_ofs = 0;  // no offset, using the same activations for both.
   // Will go through GELU.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_interleaved, A, B1, scale,
-                                                 C1, bias1, pool);
+  MatMul_4x4<kColsA, kColsB, kAddBias>(num_interleaved, A, A_ofs, B1,
+                                       /*B_ofs=*/0, scale, C1, bias1, pool);
   // What to multiply by.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_interleaved, A, B2, scale,
-                                                 C2, bias2, pool);
+  MatMul_4x4<kColsA, kColsB, kAddBias>(num_interleaved, A, A_ofs, B1,
+                                       /*B_ofs=*/kColsA * kColsB, scale, C2,
+                                       bias2, pool);
 
   // Activation (Gelu) and multiply by gate. Store activations in C1.
   Activation<TConfig>(C1, C2, kFFHiddenDim * num_interleaved);
 
   // Hidden layer -> output layer.
-  MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
-      num_interleaved, C1, layer_weights->linear_w.data(),
-      layer_weights->linear_w.scale(), activations.ffw_out.All(),
-      output_bias, pool);
+  MatMul_4x4<kFFHiddenDim, kModelDim, kAddBias>(
+      num_interleaved, C1, 0, layer_weights->linear_w.data(), 0,
+      layer_weights->linear_w.scale(), activations.ffw_out.All(), output_bias,
+      pool);
 }
 
 // `batch_idx` indicates which row of `x` to write to.
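Why B2 disappears, sketched for clarity (names from the diff; the back-to-back layout is implied by the old B2 computation):

    // gating_einsum_w holds two gating matrices stored back to back, each with
    // kColsA * kColsB elements.
    // Before: const auto B2 = B1 + kColsA * kColsB;  // pointer arithmetic
    // After:  pass B1 plus B_ofs = kColsA * kColsB;  // an explicit offset also
    //         works for compressed types (e.g. NuqStream) that do not support
    //         pointer arithmetic.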
@@ -1003,12 +1005,14 @@ void GenerateT(const ByteStorageT& weights_u8, Activations& activations,
 
     bool all_queries_eos = true;
     PROFILER_ZONE("Gen.Embedding");
+    // Compute logits from last layer activations.
+    MatMul_4x4<TConfig::kModelDim, kVocabSize, /*kAdd=*/false>(
+        num_queries, activations.x.All(), 0,
+        weights.embedder_input_embedding.data(), 0,
+        weights.embedder_input_embedding.scale(), activations.logits.All(),
+        /*add=*/nullptr, pool);
     for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
       float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      // Compute logits from last layer activations. TODO: MatMul
-      MatVec<kVocabSize, TConfig::kModelDim>(
-          weights.embedder_input_embedding, 0, activations.x.Batch(query_idx),
-          activations.even_odd.All(), logits, pool);
       if constexpr (TConfig::kFinalCap > 0.0f) {
         LogitsSoftCap(TConfig::kFinalCap, logits, kVocabSize);
       }
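This hunk is the source of the ~1.1x decode speedup in the title: one batched MatMul over all queries replaces a per-query MatVec loop. A rough shape check (a gloss, not part of the diff):

    // A: activations.x.All(), num_queries x kModelDim, row-major.
    // B: embedder_input_embedding, stored transposed as kVocabSize x kModelDim.
    // C: activations.logits.All(), num_queries x kVocabSize.
    // One call traverses the large embedding matrix once for all queries, and
    // the 4x4 tiling reuses each loaded B row across up to four query rows.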
ops/matmul-inl.h (237 changed lines)
@@ -17,20 +17,14 @@
 #ifndef THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
 #define THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
 
-#include <math.h>
 #include <stddef.h>
 #include <stdint.h>
 #include <stdio.h>
 
-#include <array>
-#include <cmath>
-#include <type_traits>  // std::enable_if_t
-
 #include "compression/compress.h"
-#include "compression/sfp.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
 #include "hwy/detect_targets.h"
 #include "hwy/profiler.h"
 
 #endif  // THIRD_PARTY_GEMMA_CPP_OPS_MATMUL_INL_H_
@@ -53,23 +47,8 @@ namespace gcpp {
 namespace HWY_NAMESPACE {
 namespace hn = hwy::HWY_NAMESPACE;
 
-HWY_INLINE constexpr size_t MaxCols() {
-  // Vec + mat rows should fit into 32 KiB L1.
-  return 2048;
-}
-
-template <size_t kOuter>
-HWY_INLINE constexpr size_t RowsPerStrip() {
-  // Aim for 128 work items to reduce pool overhead. Must be at least one
-  // vector; prefer a power of two for faster division.
-  constexpr size_t kLanes = hn::ScalableTag<float>().MaxLanes();
-  constexpr size_t kRowsPerStrip =
-      kOuter < 128 ? kLanes
-                   : HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(kOuter / 128));
-  return kRowsPerStrip;
-}
-
 // Shared between f32 and bf16, which also accumulates into f32 vectors.
 // c## are partial sums of the products of A and B; their horizontal sums are
 // the final matmul result, stored in C, which is always f32.
 template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSums(DF df,  //
                                     VF c00, VF c01, VF c02, VF c03,  //
@@ -106,7 +85,8 @@ HWY_INLINE void StoreHorizontalSums(DF df,  //
   tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33);
 }
 
-// Completes the tile by summing across the vectors, and adds the biases.
+// As above, but also adds `add[0..3]` to columns 0..3 of `tile_c`. `add` has no
+// scale, and points to a 1D slice of the row vector.
 template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
                                        VF c00, VF c01, VF c02, VF c03,  //
@@ -121,32 +101,33 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
   // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
   // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
-  float addon0 = hwy::ConvertScalarTo<float>(add[0]);
-  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + addon0;
-  float addon1 = hwy::ConvertScalarTo<float>(add[1]);
-  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + addon1;
-  float addon2 = hwy::ConvertScalarTo<float>(add[2]);
-  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + addon2;
-  float addon3 = hwy::ConvertScalarTo<float>(add[3]);
-  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + addon3;
+  const float add0 = add[0];
+  // TODO: 4x4 transpose, then 128-bit vector FMA?
+  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + add0;
+  const float add1 = add[1];
+  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + add1;
+  const float add2 = add[2];
+  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + add2;
+  const float add3 = add[3];
+  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + add3;
   if (kNumRows == 1) return;
 
-  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + addon0;
-  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + addon1;
-  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + addon2;
-  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + addon3;
+  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + add0;
+  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + add1;
+  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + add2;
+  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + add3;
   if (kNumRows == 2) return;
 
-  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + addon0;
-  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + addon1;
-  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + addon2;
-  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + addon3;
+  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + add0;
+  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + add1;
+  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + add2;
+  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + add3;
   if (kNumRows == 3) return;
 
-  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + addon0;
-  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + addon1;
-  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + addon2;
-  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + addon3;
+  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + add0;
+  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + add1;
+  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + add2;
+  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + add3;
 }
 
 // Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
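A scalar model of what these stores compute, for reference (a sketch of the semantics, not code from the repo):

    // For each tile row r < kNumRows and column c in 0..3:
    //   tile_c[stride_c * r + c] = scale * ReduceSum(c_rc) + add[c]
    // add0..add3 are hoisted so each unscaled bias is read once per tile rather
    // than once per row; only the dot products are scaled.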
@@ -180,15 +161,15 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
 #if GEMMA_NATIVE_BF16
 
 // Specialization for f32 += bf16 * bf16 that avoids promoting to f32.
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
+template <size_t kNumRows, bool kAdd>
 HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
+                              const size_t A_ofs,
                               const hwy::bfloat16_t* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C,
-                              const float scale,
-                              const float* HWY_RESTRICT add,
+                              const size_t B_ofs, float* HWY_RESTRICT C,
+                              const float scale, const float* HWY_RESTRICT add,
                               const size_t idx_tile, const size_t xtiles,
-                              const size_t stride_a, const size_t stride_b,
-                              const size_t stride_c) {
+                              const size_t cols_a, const size_t stride_a,
+                              const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
   static_assert(kNumRows <= kRegRows);
@@ -226,41 +207,42 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
   VF c32 = hn::Zero(df);
   VF c33 = hn::Zero(df);
 
-  const hwy::bfloat16_t* HWY_RESTRICT tile_a = A + stride_a * row_a;
-  const hwy::bfloat16_t* HWY_RESTRICT tile_b = B + stride_b * row_b_col_c;
+  const hwy::bfloat16_t* HWY_RESTRICT A_tile = A + A_ofs + stride_a * row_a;
+  const hwy::bfloat16_t* HWY_RESTRICT B_tile =
+      B + B_ofs + stride_b * row_b_col_c;
 
   // Loop over columns of A and columns of the transposed B, in steps of N.
   // Accumulates into the c## vectors.
   HWY_UNROLL(1)
-  for (size_t col_ab = 0; col_ab < kColsA_RowsB; col_ab += N) {
+  for (size_t col_ab = 0; col_ab < cols_a; col_ab += N) {
     using V = hn::Vec<decltype(d)>;
-    const V b0 = hn::LoadU(d, tile_b + stride_b * 0 + col_ab);
-    const V b1 = hn::LoadU(d, tile_b + stride_b * 1 + col_ab);
-    const V b2 = hn::LoadU(d, tile_b + stride_b * 2 + col_ab);
-    const V b3 = hn::LoadU(d, tile_b + stride_b * 3 + col_ab);
+    const V b0 = hn::LoadU(d, B_tile + stride_b * 0 + col_ab);
+    const V b1 = hn::LoadU(d, B_tile + stride_b * 1 + col_ab);
+    const V b2 = hn::LoadU(d, B_tile + stride_b * 2 + col_ab);
+    const V b3 = hn::LoadU(d, B_tile + stride_b * 3 + col_ab);
 
-    const V a0 = hn::LoadU(d, tile_a + stride_a * 0 + col_ab);
+    const V a0 = hn::LoadU(d, A_tile + stride_a * 0 + col_ab);
     c00 = hn::ReorderWidenMulAccumulate(df, a0, b0, c00, unused_sum1);
     c01 = hn::ReorderWidenMulAccumulate(df, a0, b1, c01, unused_sum1);
     c02 = hn::ReorderWidenMulAccumulate(df, a0, b2, c02, unused_sum1);
     c03 = hn::ReorderWidenMulAccumulate(df, a0, b3, c03, unused_sum1);
     if constexpr (kNumRows == 1) continue;
 
-    const V a1 = hn::LoadU(d, tile_a + stride_a * 1 + col_ab);
+    const V a1 = hn::LoadU(d, A_tile + stride_a * 1 + col_ab);
     c10 = hn::ReorderWidenMulAccumulate(df, a1, b0, c10, unused_sum1);
     c11 = hn::ReorderWidenMulAccumulate(df, a1, b1, c11, unused_sum1);
     c12 = hn::ReorderWidenMulAccumulate(df, a1, b2, c12, unused_sum1);
     c13 = hn::ReorderWidenMulAccumulate(df, a1, b3, c13, unused_sum1);
     if constexpr (kNumRows == 2) continue;
 
-    const V a2 = hn::LoadU(d, tile_a + stride_a * 2 + col_ab);
+    const V a2 = hn::LoadU(d, A_tile + stride_a * 2 + col_ab);
     c20 = hn::ReorderWidenMulAccumulate(df, a2, b0, c20, unused_sum1);
     c21 = hn::ReorderWidenMulAccumulate(df, a2, b1, c21, unused_sum1);
     c22 = hn::ReorderWidenMulAccumulate(df, a2, b2, c22, unused_sum1);
     c23 = hn::ReorderWidenMulAccumulate(df, a2, b3, c23, unused_sum1);
     if constexpr (kNumRows == 3) continue;
 
-    const V a3 = hn::LoadU(d, tile_a + stride_a * 3 + col_ab);
+    const V a3 = hn::LoadU(d, A_tile + stride_a * 3 + col_ab);
     c30 = hn::ReorderWidenMulAccumulate(df, a3, b0, c30, unused_sum1);
     c31 = hn::ReorderWidenMulAccumulate(df, a3, b1, c31, unused_sum1);
     c32 = hn::ReorderWidenMulAccumulate(df, a3, b2, c32, unused_sum1);
@@ -270,10 +252,10 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
   // Ensure sum1 was indeed unused.
   HWY_DASSERT(hn::AllTrue(df, hn::Eq(unused_sum1, hn::Zero(df))));
 
-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
+  float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c;
   StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
       df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
+      c32, c33, add, row_b_col_c, scale, C_tile, stride_c);
 }
 
 #endif  // GEMMA_NATIVE_BF16
@@ -295,19 +277,17 @@ HWY_INLINE void UpdateTileRow(const VF& a0, const VF& a1, const VF& b00,
   c3 = hn::MulAdd(a1, b31, c3);
 }
 
-// Accumulates a single kNumRowsx4 tile of A x B into C. B is transposed, so we
-// can iterate over both A and B with consecutive vector loads. kNumRows<=4.
+// Accumulates a single kNumRows (<= 4) x 4 tile of A x B into C. B is
+// transposed, so we iterate over both A and B with consecutive vector loads.
 // General case: uses CompressTraits to load from A and B.
-template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
-          typename MatTB>
-HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
-                              const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C,
-                              const float scale,
+template <size_t kNumRows, bool kAdd, typename MatTA, typename MatTB>
+HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A, const size_t A_ofs,
+                              const MatTB* HWY_RESTRICT B, const size_t B_ofs,
+                              float* HWY_RESTRICT C, const float scale,
                               const float* HWY_RESTRICT add,
                               const size_t idx_tile, const size_t xtiles,
-                              const size_t stride_a, const size_t stride_b,
-                              const size_t stride_c) {
+                              const size_t cols_a, const size_t stride_a,
+                              const size_t stride_b, const size_t stride_c) {
   constexpr size_t kRegRows = 4;
   constexpr size_t kRegCols = 4;
   static_assert(kNumRows <= kRegRows);
@@ -343,8 +323,8 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   V c32 = hn::Zero(d32);
   V c33 = hn::Zero(d32);
 
-  const size_t tile_a_ofs = stride_a * row_a;
-  const size_t tile_b_ofs = stride_b * row_b_col_c;
+  const size_t A_tile_ofs = A_ofs + stride_a * row_a;
+  const size_t B_tile_ofs = B_ofs + stride_b * row_b_col_c;
 
   // Loop over columns of A and columns of the transposed B, in steps of 2*N
   // (since we are decoding consecutive bytes at each iteration).
@@ -352,69 +332,74 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   size_t col_ab = 0;
 
   HWY_UNROLL(1)
-  for (; col_ab <= kColsA_RowsB - 2 * N; col_ab += 2 * N) {
+  for (; col_ab <= cols_a - 2 * N; col_ab += 2 * N) {
     V b00, b01;
-    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 0 + col_ab, b00, b01);
+    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 0 + col_ab, b00, b01);
     V b10, b11;
-    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 1 + col_ab, b10, b11);
+    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 1 + col_ab, b10, b11);
     V b20, b21;
-    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 2 + col_ab, b20, b21);
+    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 2 + col_ab, b20, b21);
     V b30, b31;
-    TraitsB::Decompress2(d32, B, tile_b_ofs + stride_b * 3 + col_ab, b30, b31);
+    TraitsB::Decompress2(d32, B, B_tile_ofs + stride_b * 3 + col_ab, b30, b31);
 
     V a00, a01;
-    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 0 + col_ab, a00, a01);
+    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 0 + col_ab, a00, a01);
     UpdateTileRow(a00, a01, b00, b01, b10, b11, b20, b21, b30, b31, c00, c01,
                   c02, c03);
     if constexpr (kNumRows == 1) continue;
 
     V a10, a11;
-    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 1 + col_ab, a10, a11);
+    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 1 + col_ab, a10, a11);
     UpdateTileRow(a10, a11, b00, b01, b10, b11, b20, b21, b30, b31, c10, c11,
                   c12, c13);
     if constexpr (kNumRows == 2) continue;
 
     V a20, a21;
-    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 2 + col_ab, a20, a21);
+    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 2 + col_ab, a20, a21);
     UpdateTileRow(a20, a21, b00, b01, b10, b11, b20, b21, b30, b31, c20, c21,
                   c22, c23);
     if constexpr (kNumRows == 3) continue;
 
     V a30, a31;
-    TraitsA::Decompress2(d32, A, tile_a_ofs + stride_a * 3 + col_ab, a30, a31);
+    TraitsA::Decompress2(d32, A, A_tile_ofs + stride_a * 3 + col_ab, a30, a31);
     UpdateTileRow(a30, a31, b00, b01, b10, b11, b20, b21, b30, b31, c30, c31,
                   c32, c33);
   }
 
-  float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
+  float* HWY_RESTRICT C_tile = C + stride_c * row_a + row_b_col_c;
   StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
       d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
+      c32, c33, add, row_b_col_c, scale, C_tile, stride_c);
 }
 
-// C = A * B * scale [+ add].
-// Computes the matrix product of A and B and stores this in C.
-// If kAdd is true, the row-vector `add` is added to each row of C.
-// A is a matrix of size (batch_size, kColsA_RowsB).
+// Tiled 4x4 GEMM: C = A * B * scale [+ add].
+// Computes the matrix product of A and B and stores this in C. Processes tiles
+// of 4x4 vectors in parallel with a work-stealing thread pool.
+//
+// If kAdd is true, the row-vector `add` is added to each row of C, otherwise
+// `add` is ignored and can be nullptr.
+// A is a row-major matrix of size (batch_size, kColsA_RowsB).
 // B is passed transposed (column-major), so a matrix of size
 // (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
+// A_ofs and B_ofs are offsets into A and B, respectively; they remain separate
+// from the pointers because some MatTA/B such as NuqStream do not support
+// pointer arithmetic.
 // C is a matrix of size (batch_size, kColsBC).
 // The product is scaled by `scale` to support CompressedArray with scale != 1,
 // the caller can pass the product of the scales of A and B.
 // A scale for `add` is not supported, so make sure its scale is 1.
-// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
-// and kColsBC is 24k or 3k.
-// This function processes tiles in parallel with a work-stealing thread pool.
+// Typically batch_size is 1..512, kColsA_RowsB and kColsBC are 3k or 24k.
 template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
-          typename MatTB, typename OutT, typename AddT>
-HWY_NOINLINE void MatMul_4x4_Batch_Add(
-    size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    float scale, OutT* HWY_RESTRICT C, const AddT* HWY_RESTRICT add,
-    hwy::ThreadPool& pool) {
+          typename MatTB, typename OutT>
+HWY_NOINLINE void MatMul_4x4(const size_t batch_size,
+                             const MatTA* HWY_RESTRICT A, const size_t A_ofs,
+                             const MatTB* HWY_RESTRICT B, const size_t B_ofs,
+                             const float scale, OutT* HWY_RESTRICT C,
+                             const float* HWY_RESTRICT add,
+                             hwy::ThreadPool& pool) {
   PROFILER_ZONE("Matmul");
-  // Process reg-sized tiles of C in parallel. We currently write C directly,
-  // which touches more memory than fits in L3. TODO: add another level of loops
-  // so that we finish one L3-sized piece of C at a time.
+  // We currently write C directly, which touches more memory than fits in L3.
+  // TODO: add another level of loops to finish L3-sized pieces of C at a time.
   const hn::ScalableTag<MatTA> d;
   const size_t N = Lanes(d);
   constexpr size_t kRegRows = 4;
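The documented contract, written out as a naive scalar reference (a sketch for exposition only: it assumes plain float arrays, whereas the real kernel decompresses MatTA/MatTB via CompressTraits and tiles the loops across the pool):

    #include <cstddef>

    // C[r, c] = scale * dot(A row r, stored-B row c) [+ add[c]], where the
    // stored B is already transposed, so its row c is column c of logical B.
    void ReferenceMatMul(size_t batch_size, size_t cols_a_rows_b,
                         size_t cols_bc, const float* A, size_t A_ofs,
                         const float* B, size_t B_ofs, float scale, float* C,
                         const float* add) {
      for (size_t r = 0; r < batch_size; ++r) {
        for (size_t c = 0; c < cols_bc; ++c) {
          float dot = 0.0f;
          for (size_t k = 0; k < cols_a_rows_b; ++k) {
            dot += A[A_ofs + r * cols_a_rows_b + k] *
                   B[B_ofs + c * cols_a_rows_b + k];
          }
          C[r * cols_bc + c] = scale * dot + (add != nullptr ? add[c] : 0.0f);
        }
      }
    }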
@@ -436,38 +421,28 @@ HWY_NOINLINE void MatMul_4x4_Batch_Add(
     HWY_ASSERT(num_rows > 0);
     switch (num_rows) {
       case 1:
-        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
-                                             kTilesX, kStrideA, kStrideB,
-                                             kStrideC);
+        GEMM_4x4_Tile<1, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
+                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
+                               kStrideC);
         break;
       case 2:
-        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
-                                             kTilesX, kStrideA, kStrideB,
-                                             kStrideC);
+        GEMM_4x4_Tile<2, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
+                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
+                               kStrideC);
         break;
       case 3:
-        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
-                                             kTilesX, kStrideA, kStrideB,
-                                             kStrideC);
+        GEMM_4x4_Tile<3, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
+                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
+                               kStrideC);
         break;
       default:
-        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
-                                             kTilesX, kStrideA, kStrideB,
-                                             kStrideC);
+        GEMM_4x4_Tile<4, kAdd>(A, A_ofs, B, B_ofs, C, scale, add, idx_tile,
+                               kTilesX, kColsA_RowsB, kStrideA, kStrideB,
+                               kStrideC);
     }
   });
 }
 
-// As above, without the add.
-template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
-          typename MatTB, typename OutT>
-HWY_NOINLINE void MatMul_4x4_Batch(
-    size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    float scale, OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
-  MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, /*kAdd=*/false>(
-      batch_size, A, B, scale, C, /*add=*/static_cast<OutT*>(nullptr), pool);
-}
-
 //------------------------------------------------------------------------------
 HWY_INLINE void ToEvenOddF32(const hwy::bfloat16_t* HWY_RESTRICT vec_aligned,
                              const size_t size, float* HWY_RESTRICT out) {
@@ -525,6 +500,22 @@ HWY_INLINE void TwoOfsMatVecAddLoop(const ArrayT& mat, const size_t mat_ofs0,
   }
 }
 
+HWY_INLINE constexpr size_t MaxCols() {
+  // Vec + mat rows should fit into 32 KiB L1.
+  return 2048;
+}
+
+template <size_t kOuter>
+HWY_INLINE constexpr size_t RowsPerStrip() {
+  // Aim for 128 work items to reduce pool overhead. Must be at least one
+  // vector; prefer a power of two for faster division.
+  constexpr size_t kLanes = hn::ScalableTag<float>().MaxLanes();
+  constexpr size_t kRowsPerStrip =
+      kOuter < 128 ? kLanes
+                   : HWY_MAX(kLanes, 1ULL << hwy::FloorLog2(kOuter / 128));
+  return kRowsPerStrip;
+}
+
 namespace detail {
 
 // For each i = [0, num_rows), compute partial (length `num_cols`) dot product
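A worked instance of the RowsPerStrip heuristic (my arithmetic, assuming an 8-lane float vector):

    // kOuter = 24576, kLanes = 8:
    //   kOuter / 128 = 192; FloorLog2(192) = 7; 1 << 7 = 128
    //   kRowsPerStrip = HWY_MAX(8, 128) = 128 -> 24576 / 128 = 192 work items,
    //   a power of two, so row / kRowsPerStrip is a cheap shift.
    // kOuter = 64 (< 128): kRowsPerStrip = kLanes = 8, one vector per strip.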
@@ -301,15 +301,10 @@ void TestTiledBatchMatMul() {
 
   const double start_tiled = hwy::platform::Now();
   EXPECT_EQ(scale, a->scale() * b_trans->scale());
-  if (kAdd) {
-    MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), scale,
-                                       c.get(), add->data(), pool);
-  } else {
-    MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), scale, c.get(),
-                             pool);
-  }
+  MatMul_4x4<kN, kK, kAdd>(kM, a->data(), 0, b_trans->data(), 0, scale, c.get(),
+                           add->data(), pool);
   const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
-  fprintf(stderr, "MatMul_4x4_Batch took %f seconds.\n", tiled_matmul_seconds);
+  fprintf(stderr, "MatMul_4x4 took %f seconds.\n", tiled_matmul_seconds);
 
   AssertClose(c_slow->data(), c.get(), kM * kK);
 }