Add scale parameter to MatMul.

Add accessor to CompressedArray that asserts the scale is 1 and use it.

PiperOrigin-RevId: 653604840
Authored by Daniel Keysers on 2024-07-18 06:58:15 -07:00; committed by Copybara-Service
parent 5a751a9a44
commit e87e65ca45
5 changed files with 148 additions and 104 deletions
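
In short: the tiled MatMul routines now take a `scale` factor that is applied when each output tile is stored, and CompressedArray gains a stricter accessor for callers that deliberately ignore the scale. Condensed from the diffs below (template parameters and attributes elided), the affected signatures are roughly:

// CompressedArray: new accessor that asserts the scale is 1.0f (or 0, i.e. never set).
const MatT* data_scale1() const;

// MatMul now takes the combined scale of its inputs:
void MatMul_4x4_Batch_Add(size_t batch_size, const MatTA* A, const MatTB* B,
                          float scale, OutT* C, const AddT* add,
                          hwy::ThreadPool& pool);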

View File

@ -80,6 +80,14 @@ class CompressedArray {
// may be different from 1.0f.
MatT* data() { return data_.data(); }
const MatT* data() const { return data_.data(); }
// The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
// calling it means "I am sure the scale is 1 and therefore ignore the scale".
// A scale of 0 indicates that the scale has likely never been set, so is
// "implicitly 1".
const MatT* data_scale1() const {
HWY_ASSERT(scale() == 1.f || scale() == 0.f);
return data_.data();
}
// Decoded elements should be multiplied by this to restore their original
// range. This is required because SfpStream can only encode a limited range
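
A minimal caller-side sketch of the contract (the `weights` array below is hypothetical, not part of this change): keep using data() plus scale() when the scale may matter, and use data_scale1() only where ignoring the scale is intentional.

// weights: some CompressedArray<float, kN> (hypothetical).
const float* raw = weights.data();              // caller must still apply weights.scale()
const float* unscaled = weights.data_scale1();  // HWY_ASSERT fires unless scale() is 1.0f (or 0 == unset)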

View File

@ -81,9 +81,9 @@ HWY_NOINLINE void GriffinRecurrent(
TwoMatVecAdd<kModelDim, kModelDim>(
layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
activations.pre_att_rms_out.Batch(batch_idx),
/*add0=*/layer_weights->griffin.linear_x_biases.data(),
/*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
/*out1=*/y, pool);
/*add0=*/layer_weights->griffin.linear_x_biases.data_scale1(),
/*add1=*/layer_weights->griffin.linear_y_biases.data_scale1(),
/*out0=*/x, /*out1=*/y, pool);
Gelu(y, kModelDim);
}
@ -106,13 +106,13 @@ HWY_NOINLINE void GriffinRecurrent(
for (size_t i = 0; i < kModelDim; i += hn::Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
hn::Load(df, layer_weights->griffin.conv_biases.data_scale1() + i);
auto accum1 = hn::Zero(df);
static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
(kConv1dWidth - 1 - 2 * l) * kModelDim + i);
auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
(kConv1dWidth - 2 - 2 * l) * kModelDim + i);
accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
@ -139,16 +139,18 @@ HWY_NOINLINE void GriffinRecurrent(
TwoOfsMatVecAddLoop<kHeadDim, kHeadDim>(
layer_weights->griffin.gate_w, kMatrixSize * head,
kMatrixSize * (kHeads + head), x + head_offset,
/*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
/*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
/*add0=*/layer_weights->griffin.gate_biases.data_scale1() +
head_offset,
/*add1=*/layer_weights->griffin.gate_biases.data_scale1() +
kModelDim + head_offset,
/*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
Sigmoid(gate_x + head_offset, kHeadDim);
Sigmoid(a + head_offset, kHeadDim);
const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
HWY_ATTR { return hn::Mul(x, gate_x); };
hn::Transform1(D(), a + head_offset, kHeadDim,
layer_weights->griffin.a.data() + head_offset, fn_mul);
layer_weights->griffin.a.data_scale1() + head_offset,
fn_mul);
hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
fn_mul);
// RNN scan
@ -180,7 +182,7 @@ HWY_NOINLINE void GriffinRecurrent(
float* out_ptr = activations.att_post2.Batch(batch_idx);
MatVecAdd<kModelDim, kModelDim>(
layer_weights->griffin.linear_out_w, 0, x,
layer_weights->griffin.linear_out_biases.data(),
layer_weights->griffin.linear_out_biases.data_scale1(),
activations.even_odd.All(), out_ptr, pool);
}
}
@ -222,9 +224,10 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
//
// Compute Q only or QKV (if MHA).
// If MHA, this also computes KV, which we copy to the KV cache below.
const float scale = layer_weights->qkv_einsum_w.scale();
MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
num_tokens_and_queries, activations.pre_att_rms_out.All(),
layer_weights->qkv_einsum_w.data(), activations.q.All(), pool);
layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);
// Compute KV if not MHA.
if constexpr (!kIsMHA) {
@ -342,7 +345,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
layer_weights->attn_vec_einsum_w, 0, att_out,
layer_weights->attention_output_biases.data(),
layer_weights->attention_output_biases.data_scale1(),
activations.even_odd.All(), layer_out, pool);
// Head 1 and following are added to layer_out.
for (size_t head = 1; head < kHeads; ++head) {
@ -384,34 +387,30 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
constexpr size_t kColsB = kFFHiddenDim;
HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize());
const auto A = activations.bf_pre_ffw_rms_out.All();
const float scale = layer_weights->gating_einsum_w.scale();
const auto B1 = layer_weights->gating_einsum_w.data();
const auto B2 = B1 + kColsA * kColsB;
auto C1 = activations.C1.All();
auto C2 = activations.C2.All();
constexpr bool kAddBias = TConfig::kFFBiases;
const auto bias = layer_weights->ffw_gating_biases.data();
const auto bias1 = layer_weights->ffw_gating_biases.data_scale1();
const auto bias2 = bias1 + kFFHiddenDim;
// Will go through GELU.
MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B1, C1,
bias, pool);
MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B1, scale, C1,
bias1, pool);
// What to multiply by.
MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B2, C2,
bias + kFFHiddenDim, pool);
MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B2, scale, C2,
bias2, pool);
// Activation (Gelu) and multiply by gate. Store activations in C1.
Activation<TConfig>(C1, C2, kFFHiddenDim * num_tokens);
// linear_w may have a scale value different from 1, apply that here.
// We multiply all activations by the scale value to compensate for the
// missing scale value in the weights.
if (layer_weights->linear_w.scale() != 1.0f) {
MulByConst(layer_weights->linear_w.scale(), C1, kFFHiddenDim * num_tokens);
}
// Hidden layer -> output layer.
MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
num_tokens, C1, layer_weights->linear_w.data(), activations.ffw_out.All(),
layer_weights->ffw_output_biases.data(), pool);
num_tokens, C1, layer_weights->linear_w.data(),
layer_weights->linear_w.scale(), activations.ffw_out.All(),
layer_weights->ffw_output_biases.data_scale1(), pool);
}
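
Why the removed MulByConst pass over C1 is no longer needed (a sketch of the algebra, not code from this commit): with the scale folded into the tile store, the final MatMul already computes

  ffw_out[r][c] = linear_w.scale() * sum_k C1[r][k] * linear_w[c][k]  (+ ffw_output_biases[c] if kAddBias)

which is exactly what scaling the activations first and then multiplying by the unscaled weights produced.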
// TODO: pass Activations.x instead of Activations.
@ -453,7 +452,7 @@ HWY_NOINLINE void TransformerLayer(
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
RMSNormBatched(num_tokens_and_queries, activations.x.All(),
layer_weights->pre_attention_norm_scale.data(),
layer_weights->pre_attention_norm_scale.data_scale1(),
activations.pre_att_rms_out.All(), kModelDim);
if (type == LayerAttentionType::kGemma) {
Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
@ -472,21 +471,22 @@ HWY_NOINLINE void TransformerLayer(
}
if (TConfig::kPostNorm == PostNormType::Scale) {
RMSNormInplaceBatched(num_tokens_and_queries,
layer_weights->post_attention_norm_scale.data(),
activations.att_post2.All(), kModelDim);
RMSNormInplaceBatched(
num_tokens_and_queries,
layer_weights->post_attention_norm_scale.data_scale1(),
activations.att_post2.All(), kModelDim);
}
ResidualConnection<TConfig>(num_tokens_and_queries,
activations.att_post2.All(), activations.x.All(),
layer_weights, /*is_attention=*/true);
RMSNormBatched(num_tokens_and_queries, activations.x.All(),
layer_weights->pre_ffw_norm_scale.data(),
layer_weights->pre_ffw_norm_scale.data_scale1(),
activations.bf_pre_ffw_rms_out.All(), kModelDim);
FFW<TConfig>(activations, num_tokens_and_queries, layer_weights, pool);
if (TConfig::kPostNorm == PostNormType::Scale) {
RMSNormInplaceBatched(num_tokens_and_queries,
layer_weights->post_ffw_norm_scale.data(),
layer_weights->post_ffw_norm_scale.data_scale1(),
activations.ffw_out.All(), kModelDim);
}
ResidualConnection<TConfig>(num_tokens_and_queries, activations.ffw_out.All(),
@ -564,7 +564,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
}
}
RMSNormInplaceBatched(num_tokens_and_queries, weights.final_norm_scale.data(),
RMSNormInplaceBatched(num_tokens_and_queries,
weights.final_norm_scale.data_scale1(),
activations.x.All(), kModelDim);
if (layers_output) {
for (size_t token_idx = 0; token_idx < num_tokens_and_queries;

View File

@ -27,6 +27,7 @@
#include <random>
#include <type_traits> // std::enable_if_t
#include "compression/compress.h"
#include "compression/sfp.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
@ -106,47 +107,50 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
// Shared between f32 and bf16, which also accumulates into f32 vectors.
template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
HWY_INLINE void StoreHorizontalSums(DF df, //
VF c00, VF c01, VF c02, VF c03, //
VF c10, VF c11, VF c12, VF c13, //
VF c20, VF c21, VF c22, VF c23, //
VF c30, VF c31, VF c32, VF c33,
float* HWY_RESTRICT tile_c,
VF c30, VF c31, VF c32, VF c33, //
float scale, float* HWY_RESTRICT tile_c,
size_t stride_c) {
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
// Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
// the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
// expensive, but only a fraction of the kColsA_RowsB/N FMAs.
tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00);
tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01);
tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02);
tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03);
tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00);
tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01);
tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02);
tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03);
if (kNumRows == 1) return;
tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10);
tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11);
tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12);
tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13);
tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10);
tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11);
tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12);
tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13);
if (kNumRows == 2) return;
tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20);
tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21);
tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22);
tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23);
tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20);
tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21);
tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22);
tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23);
if (kNumRows == 3) return;
tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30);
tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31);
tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32);
tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30);
tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31);
tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32);
tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33);
}
// Completes the tile by summing across the vectors, and adds the biases.
template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
HWY_INLINE void StoreHorizontalSumsAdd(DF df, //
VF c00, VF c01, VF c02, VF c03, //
VF c10, VF c11, VF c12, VF c13, //
VF c20, VF c21, VF c22, VF c23, //
VF c30, VF c31, VF c32, VF c33,
const float* HWY_RESTRICT add,
const float scale,
float* HWY_RESTRICT tile_c,
size_t stride_c) {
// We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
@ -154,31 +158,31 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
// the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
// expensive, but only a fraction of the kColsA_RowsB/N FMAs.
float addon0 = hwy::ConvertScalarTo<float>(add[0]);
tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00) + addon0;
tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + addon0;
float addon1 = hwy::ConvertScalarTo<float>(add[1]);
tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01) + addon1;
tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + addon1;
float addon2 = hwy::ConvertScalarTo<float>(add[2]);
tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02) + addon2;
tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + addon2;
float addon3 = hwy::ConvertScalarTo<float>(add[3]);
tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03) + addon3;
tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + addon3;
if (kNumRows == 1) return;
tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10) + addon0;
tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11) + addon1;
tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12) + addon2;
tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13) + addon3;
tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + addon0;
tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + addon1;
tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + addon2;
tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + addon3;
if (kNumRows == 2) return;
tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20) + addon0;
tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21) + addon1;
tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22) + addon2;
tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23) + addon3;
tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + addon0;
tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + addon1;
tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + addon2;
tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + addon3;
if (kNumRows == 3) return;
tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30) + addon0;
tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31) + addon1;
tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32) + addon2;
tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33) + addon3;
tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + addon0;
tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + addon1;
tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + addon2;
tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + addon3;
}
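
In other words, each stored element is (pseudo-C, with c_rc denoting the accumulator vector for row r, column c):

  tile_c[stride_c * r + c] = scale * ReduceSum(df, c_rc) + ConvertScalarTo<float>(add[c]);

The scale applies only to the A*B product, never to `add`, which is why bias vectors elsewhere in this change are fetched via data_scale1().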
// Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
@ -188,16 +192,16 @@ template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
HWY_INLINE void StoreHorizontalSumsMaybeAdd(
DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
const float* HWY_RESTRICT add, size_t add_offset,
const float* HWY_RESTRICT add, size_t add_offset, const float scale,
float* HWY_RESTRICT tile_c, size_t stride_c) {
if constexpr (kAdd) {
StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
add + add_offset, tile_c, stride_c);
add + add_offset, scale, tile_c, stride_c);
} else {
StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
c20, c21, c22, c23, c30, c31, c32, c33,
tile_c, stride_c);
scale, tile_c, stride_c);
}
}
@ -215,7 +219,9 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
const hwy::bfloat16_t* HWY_RESTRICT B,
float* HWY_RESTRICT C, const float* add,
float* HWY_RESTRICT C,
const float scale,
const float* HWY_RESTRICT add,
const size_t idx_tile, const size_t xtiles,
const size_t stride_a, const size_t stride_b,
const size_t stride_c) {
@ -303,7 +309,7 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
c32, c33, add, row_b_col_c, tile_c, stride_c);
c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
}
#endif // GEMMA_NATIVE_BF16
@ -332,7 +338,9 @@ template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
typename MatTB>
HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
const MatTB* HWY_RESTRICT B,
float* HWY_RESTRICT C, const float* add,
float* HWY_RESTRICT C,
const float scale,
const float* HWY_RESTRICT add,
const size_t idx_tile, const size_t xtiles,
const size_t stride_a, const size_t stride_b,
const size_t stride_c) {
@ -417,18 +425,28 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
c32, c33, add, row_b_col_c, tile_c, stride_c);
c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
}
// C = A * B * scale [+ add].
// Computes the matrix product of A and B and stores this in C.
// If kAdd is true, the row-vector `add` is added to each row of C.
// A is a matrix of size (batch_size, kColsA_RowsB).
// B is passed transposed (column-major), so a matrix of size
// (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
// C is a matrix of size (batch_size, kColsBC).
// The product is scaled by `scale` to support CompressedArray with scale != 1;
// the caller can pass the product of the scales of A and B.
// A scale for `add` is not supported, so make sure its scale is 1.
// Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
// NOTE that batch_size is the number of rows of A and C.
// and kColsBC is 24k or 3k.
// This function processes tiles in parallel with a work-stealing thread pool.
template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
typename MatTB, typename OutT, typename AddT>
HWY_NOINLINE void MatMul_4x4_Batch_Add(
size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
float scale, OutT* HWY_RESTRICT C, const AddT* HWY_RESTRICT add,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Matmul");
// Process reg-sized tiles of C in parallel. We currently write C directly,
// which touches more memory than fits in L3. TODO: add another level of loops
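
A hedged usage sketch of the new signature (the names `pre`, `w`, `w_bias`, and `out` are hypothetical; the pattern mirrors the FFW call above): pass the scale of the compressed weights (or the product of the scales of A and B), and take the bias via data_scale1() because `add` is never scaled.

// pre: float/bf16 activations (implicit scale 1); w, w_bias: CompressedArrays; out: float output.
const float scale = w.scale();
MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, /*kAdd=*/true>(
    batch_size, pre, w.data(), scale, out, w_bias.data_scale1(), pool);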
@ -454,31 +472,36 @@ HWY_NOINLINE void MatMul_4x4_Batch_Add(
HWY_ASSERT(num_rows > 0);
switch (num_rows) {
case 1:
GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
kStrideA, kStrideB, kStrideC);
GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
kTilesX, kStrideA, kStrideB,
kStrideC);
break;
case 2:
GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
kStrideA, kStrideB, kStrideC);
GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
kTilesX, kStrideA, kStrideB,
kStrideC);
break;
case 3:
GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
kStrideA, kStrideB, kStrideC);
GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
kTilesX, kStrideA, kStrideB,
kStrideC);
break;
default:
GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
kStrideA, kStrideB, kStrideC);
GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
kTilesX, kStrideA, kStrideB,
kStrideC);
}
});
}
// As above, without the add.
template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
typename MatTB, typename OutT>
HWY_NOINLINE void MatMul_4x4_Batch(
size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
float scale, OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, /*kAdd=*/false>(
batch_size, A, B, C, /*add=*/static_cast<OutT*>(nullptr), pool);
batch_size, A, B, scale, C, /*add=*/static_cast<OutT*>(nullptr), pool);
}
//------------------------------------------------------------------------------

View File

@ -32,10 +32,11 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/timer.h"
// clang-format off
#undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" //NOLINT
#define HWY_TARGET_INCLUDE "gemma/ops_test.cc" // NOLINT
// clang-format on
#include "hwy/foreach_target.h" // IWYU pragma: keep
#include "hwy/highway.h"
@ -362,7 +363,7 @@ CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
});
Compress(content, ws, mat, pool);
mat.set_scale(1.0f);
mat.set_scale(1.9f); // Arbitrary value, different from 1.
return mat;
}
@ -377,7 +378,7 @@ CompressedArray<MatT, kOuter * kInner> GenerateZeroMat(hwy::ThreadPool& pool) {
});
Compress(content, ws, mat, pool);
mat.set_scale(1.0f);
mat.set_scale(1.2f); // Arbitrary value, different from 1.
return mat;
}
@ -400,7 +401,7 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool);
mat->set_scale(1.0f);
mat->set_scale(0.6f); // Arbitrary value, different from 1.
return mat;
}
@ -423,7 +424,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool);
mat->set_scale(1.0f);
// Arbitrary value, different from 1, must match GenerateMatHeap.
mat->set_scale(0.6f);
return mat;
}
@ -443,7 +445,7 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
pool);
mat->set_scale(1.0f);
mat->set_scale(1.2f); // Arbitrary value, different from 1.
return mat;
}
@ -487,6 +489,7 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
HWY_ASSERT(uncompressed_mat && out);
Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
out[idx_row] = add[idx_row];
for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
@ -513,14 +516,14 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE_GT(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b, const float* add,
float* HWY_RESTRICT out) {
const MatTB* HWY_RESTRICT b, const float scale,
const float* add, float* HWY_RESTRICT out) {
for (size_t i = 0; i < batch_size; ++i) {
for (size_t k = 0; k < kN; ++k) {
for (size_t j = 0; j < kK; ++j) {
const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
out[i * kK + j] += a1 * b1;
out[i * kK + j] += scale * a1 * b1;
}
}
if (add != nullptr) {
@ -537,12 +540,13 @@ template <size_t kN, size_t kK, typename MatTA, typename MatTB,
HWY_IF_T_SIZE(MatTB, 1)>
HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
const MatTB* HWY_RESTRICT b_compr,
const float* add, float* HWY_RESTRICT out) {
const float scale, const float* add,
float* HWY_RESTRICT out) {
const hn::ScalableTag<float> d;
hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
kK * kN);
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), scale, add, out);
}
template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
@ -558,11 +562,13 @@ void TestTiledBatchMatMul() {
GenerateMatHeap<MatTB, kN, kK>(0, pool);
std::unique_ptr<CompressedArray<float, kK>> add =
GenerateMatHeap<float, 1, kK>(0, pool);
add->set_scale(1.0f);
std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
GenerateZeroMatHeap<float, kM, kK>(pool);
const float scale = a->scale() * b->scale();
const double start_slow = hwy::platform::Now();
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(),
MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), scale,
kAdd ? add->data() : nullptr, c_slow->data());
const double slow_matmul_seconds = hwy::platform::Now() - start_slow;
fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds);
@ -572,11 +578,13 @@ void TestTiledBatchMatMul() {
GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
const double start_tiled = hwy::platform::Now();
EXPECT_EQ(scale, a->scale() * b_trans->scale());
if (kAdd) {
MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), c.get(),
add->data(), pool);
MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), scale,
c.get(), add->data(), pool);
} else {
MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), scale, c.get(),
pool);
}
const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
fprintf(stderr, "MatMul_4x4_Batch took %f seconds.\n", tiled_matmul_seconds);

View File

@ -15,6 +15,7 @@
#include "gemma/weights.h"
#include <cstdio>
#include <cstdlib>
#include "compression/compress.h"
@ -96,6 +97,9 @@ class WeightLogger {
public:
template <size_t N>
void operator()(const char* name, const CompressedArray<float, N>& tensor) {
if (tensor.scale() != 1.0f) {
printf("[scale=%f] ", tensor.scale());
}
LogVec(name, tensor.data(), N);
total_weights += N;
}