Add scale parameter to MatMul.

Add accessor to CompressedArray that asserts the scale is 1 and use it.

PiperOrigin-RevId: 653604840
Author: Daniel Keysers, 2024-07-18 06:58:15 -07:00 (committed by Copybara-Service)
parent 5a751a9a44
commit e87e65ca45
5 changed files with 148 additions and 104 deletions
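
In brief, the semantics of the new scale parameter: each output element of the tiled MatMul is the dot product of a row of A and a column of B, multiplied by `scale`, and the optional `add` row is added afterwards without being scaled. This matches the reference MatMulSlowBatch updated in the test changes below. A minimal standalone sketch of that contract (illustrative only; MatMulScaledRef is a hypothetical name, not part of the library):

#include <cstddef>
#include <vector>

// C = A * B * scale [+ add]: A is (batch x N), B is (N x K), add is (K), C is (batch x K).
// Sketch of the contract only; the real kernel operates on compressed/bf16 4x4 tiles in parallel.
void MatMulScaledRef(std::size_t batch, std::size_t N, std::size_t K,
                     const std::vector<float>& A, const std::vector<float>& B,
                     float scale, const float* add, std::vector<float>& C) {
  for (std::size_t i = 0; i < batch; ++i) {
    for (std::size_t j = 0; j < K; ++j) {
      float dot = 0.0f;
      for (std::size_t k = 0; k < N; ++k) dot += A[i * N + k] * B[k * K + j];
      // The bias row is added after scaling; its own scale must already be 1.
      C[i * K + j] = scale * dot + (add != nullptr ? add[j] : 0.0f);
    }
  }
}

Folding the CompressedArray scale into the GEMM epilogue this way replaces the previous special case in FFW, which post-multiplied the activations by linear_w.scale() in a separate MulByConst pass.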


@@ -80,6 +80,14 @@ class CompressedArray {
   // may be different from 1.0f.
   MatT* data() { return data_.data(); }
   const MatT* data() const { return data_.data(); }
+  // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
+  // calling it means "I am sure the scale is 1 and therefore ignore the scale".
+  // A scale of 0 indicates that the scale has likely never been set, so is
+  // "implicitly 1".
+  const MatT* data_scale1() const {
+    HWY_ASSERT(scale() == 1.f || scale() == 0.f);
+    return data_.data();
+  }
   // Decoded elements should be multiplied by this to restore their original
   // range. This is required because SfpStream can only encode a limited range
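
The intended split for the accessor above: tensors whose scale is folded into a MatMul keep using data() together with scale(), while tensors that must not carry a scale (biases, norm weights) switch to data_scale1(), which asserts. A simplified, hypothetical stand-in illustrating just that contract (not the real CompressedArray, which also handles compression):

#include <cassert>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the two access patterns; not the real CompressedArray.
struct ScaledArray {
  std::vector<float> values;
  float scale_value = 0.0f;  // 0 == never set, treated as implicitly 1
  const float* data() const { return values.data(); }
  float scale() const { return scale_value; }
  // "I am sure the scale is 1 and therefore ignore the scale."
  const float* data_scale1() const {
    assert(scale_value == 1.0f || scale_value == 0.0f);
    return values.data();
  }
};

int main() {
  ScaledArray weights{{2.0f, 4.0f}, 0.5f};  // scale folded out of the stored values
  ScaledArray bias{{1.0f, 1.0f}, 1.0f};     // biases are expected to have scale 1
  // Scaled tensor: read raw values and carry the scale into the MatMul epilogue.
  const float w0 = weights.scale() * weights.data()[0];
  // Unscaled tensor: the asserting accessor documents (and checks) the assumption.
  const float b0 = bias.data_scale1()[0];
  std::printf("w0=%f b0=%f\n", w0, b0);
  return 0;
}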


@@ -81,9 +81,9 @@ HWY_NOINLINE void GriffinRecurrent(
     TwoMatVecAdd<kModelDim, kModelDim>(
         layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
         activations.pre_att_rms_out.Batch(batch_idx),
-        /*add0=*/layer_weights->griffin.linear_x_biases.data(),
-        /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
-        /*out1=*/y, pool);
+        /*add0=*/layer_weights->griffin.linear_x_biases.data_scale1(),
+        /*add1=*/layer_weights->griffin.linear_y_biases.data_scale1(),
+        /*out0=*/x, /*out1=*/y, pool);
     Gelu(y, kModelDim);
   }
@@ -106,13 +106,13 @@ HWY_NOINLINE void GriffinRecurrent(
     for (size_t i = 0; i < kModelDim; i += hn::Lanes(df)) {
       auto xv = hn::Load(df, x + i);
       auto accum0 =
-          hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
+          hn::Load(df, layer_weights->griffin.conv_biases.data_scale1() + i);
       auto accum1 = hn::Zero(df);
       static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
-        auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
+        auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
                                 (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
-        auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
+        auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
                                 (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
         accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
         accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
@@ -139,16 +139,18 @@ HWY_NOINLINE void GriffinRecurrent(
     TwoOfsMatVecAddLoop<kHeadDim, kHeadDim>(
         layer_weights->griffin.gate_w, kMatrixSize * head,
         kMatrixSize * (kHeads + head), x + head_offset,
-        /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
-        /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
-            head_offset,
+        /*add0=*/layer_weights->griffin.gate_biases.data_scale1() +
+            head_offset,
+        /*add1=*/layer_weights->griffin.gate_biases.data_scale1() +
+            kModelDim + head_offset,
         /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
     Sigmoid(gate_x + head_offset, kHeadDim);
     Sigmoid(a + head_offset, kHeadDim);
     const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
                             HWY_ATTR { return hn::Mul(x, gate_x); };
     hn::Transform1(D(), a + head_offset, kHeadDim,
-                   layer_weights->griffin.a.data() + head_offset, fn_mul);
+                   layer_weights->griffin.a.data_scale1() + head_offset,
+                   fn_mul);
     hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
                    fn_mul);
     // RNN scan
@@ -180,7 +182,7 @@ HWY_NOINLINE void GriffinRecurrent(
     float* out_ptr = activations.att_post2.Batch(batch_idx);
     MatVecAdd<kModelDim, kModelDim>(
         layer_weights->griffin.linear_out_w, 0, x,
-        layer_weights->griffin.linear_out_biases.data(),
+        layer_weights->griffin.linear_out_biases.data_scale1(),
         activations.even_odd.All(), out_ptr, pool);
   }
 }
@@ -222,9 +224,10 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   //
   // Compute Q only or QKV (if MHA).
   // If MHA, this also computes KV, which we copy to the KV cache below.
+  const float scale = layer_weights->qkv_einsum_w.scale();
   MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
       num_tokens_and_queries, activations.pre_att_rms_out.All(),
-      layer_weights->qkv_einsum_w.data(), activations.q.All(), pool);
+      layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);
   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
@@ -342,7 +345,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
     // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
         layer_weights->attn_vec_einsum_w, 0, att_out,
-        layer_weights->attention_output_biases.data(),
+        layer_weights->attention_output_biases.data_scale1(),
         activations.even_odd.All(), layer_out, pool);
     // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
@@ -384,34 +387,30 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
   constexpr size_t kColsB = kFFHiddenDim;
   HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize());
   const auto A = activations.bf_pre_ffw_rms_out.All();
+  const float scale = layer_weights->gating_einsum_w.scale();
   const auto B1 = layer_weights->gating_einsum_w.data();
   const auto B2 = B1 + kColsA * kColsB;
   auto C1 = activations.C1.All();
   auto C2 = activations.C2.All();
   constexpr bool kAddBias = TConfig::kFFBiases;
-  const auto bias = layer_weights->ffw_gating_biases.data();
+  const auto bias1 = layer_weights->ffw_gating_biases.data_scale1();
+  const auto bias2 = bias1 + kFFHiddenDim;
   // Will go through GELU.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B1, C1,
-                                                 bias, pool);
+  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B1, scale, C1,
+                                                 bias1, pool);
   // What to multiply by.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B2, C2,
-                                                 bias + kFFHiddenDim, pool);
+  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B2, scale, C2,
+                                                 bias2, pool);
   // Activation (Gelu) and multiply by gate. Store activations in C1.
   Activation<TConfig>(C1, C2, kFFHiddenDim * num_tokens);
-  // linear_w may have a scale value different from 1, apply that here.
-  // We multiply all activations by the scale value to compensate for the
-  // missing scale value in the weights.
-  if (layer_weights->linear_w.scale() != 1.0f) {
-    MulByConst(layer_weights->linear_w.scale(), C1, kFFHiddenDim * num_tokens);
-  }
   // Hidden layer -> output layer.
   MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
-      num_tokens, C1, layer_weights->linear_w.data(), activations.ffw_out.All(),
-      layer_weights->ffw_output_biases.data(), pool);
+      num_tokens, C1, layer_weights->linear_w.data(),
+      layer_weights->linear_w.scale(), activations.ffw_out.All(),
+      layer_weights->ffw_output_biases.data_scale1(), pool);
 }

 // TODO: pass Activations.x instead of Activations.
@@ -453,7 +452,7 @@ HWY_NOINLINE void TransformerLayer(
   size_t layer_of_type =
       NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
   RMSNormBatched(num_tokens_and_queries, activations.x.All(),
-                 layer_weights->pre_attention_norm_scale.data(),
+                 layer_weights->pre_attention_norm_scale.data_scale1(),
                  activations.pre_att_rms_out.All(), kModelDim);
   if (type == LayerAttentionType::kGemma) {
     Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
@@ -472,21 +471,22 @@ HWY_NOINLINE void TransformerLayer(
   }
   if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_tokens_and_queries,
-                          layer_weights->post_attention_norm_scale.data(),
-                          activations.att_post2.All(), kModelDim);
+    RMSNormInplaceBatched(
+        num_tokens_and_queries,
+        layer_weights->post_attention_norm_scale.data_scale1(),
+        activations.att_post2.All(), kModelDim);
   }
   ResidualConnection<TConfig>(num_tokens_and_queries,
                               activations.att_post2.All(), activations.x.All(),
                               layer_weights, /*is_attention=*/true);
   RMSNormBatched(num_tokens_and_queries, activations.x.All(),
-                 layer_weights->pre_ffw_norm_scale.data(),
+                 layer_weights->pre_ffw_norm_scale.data_scale1(),
                  activations.bf_pre_ffw_rms_out.All(), kModelDim);
   FFW<TConfig>(activations, num_tokens_and_queries, layer_weights, pool);
   if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched(num_tokens_and_queries,
-                          layer_weights->post_ffw_norm_scale.data(),
+                          layer_weights->post_ffw_norm_scale.data_scale1(),
                           activations.ffw_out.All(), kModelDim);
   }
   ResidualConnection<TConfig>(num_tokens_and_queries, activations.ffw_out.All(),
@@ -564,7 +564,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
     }
   }
-  RMSNormInplaceBatched(num_tokens_and_queries, weights.final_norm_scale.data(),
+  RMSNormInplaceBatched(num_tokens_and_queries,
+                        weights.final_norm_scale.data_scale1(),
                         activations.x.All(), kModelDim);
   if (layers_output) {
     for (size_t token_idx = 0; token_idx < num_tokens_and_queries;


@@ -27,6 +27,7 @@
 #include <random>
 #include <type_traits>  // std::enable_if_t

+#include "compression/compress.h"
 #include "compression/sfp.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -106,47 +107,50 @@ HWY_INLINE constexpr size_t RowsPerStrip() {
 // Shared between f32 and bf16, which also accumulates into f32 vectors.
 template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
+HWY_INLINE void StoreHorizontalSums(DF df,  //
+                                    VF c00, VF c01, VF c02, VF c03,  //
                                     VF c10, VF c11, VF c12, VF c13,  //
                                     VF c20, VF c21, VF c22, VF c23,  //
-                                    VF c30, VF c31, VF c32, VF c33,
-                                    float* HWY_RESTRICT tile_c,
+                                    VF c30, VF c31, VF c32, VF c33,  //
+                                    float scale, float* HWY_RESTRICT tile_c,
                                     size_t stride_c) {
   // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
   // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
   // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
-  tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00);
-  tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01);
-  tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02);
-  tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03);
+  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00);
+  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01);
+  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02);
+  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03);
   if (kNumRows == 1) return;
-  tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10);
-  tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11);
-  tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12);
-  tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13);
+  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10);
+  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11);
+  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12);
+  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13);
   if (kNumRows == 2) return;
-  tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20);
-  tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21);
-  tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22);
-  tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23);
+  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20);
+  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21);
+  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22);
+  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23);
   if (kNumRows == 3) return;
-  tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30);
-  tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31);
-  tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32);
-  tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
+  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30);
+  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31);
+  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32);
+  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33);
 }

 // Completes the tile by summing across the vectors, and adds the biases.
 template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
+HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
+                                       VF c00, VF c01, VF c02, VF c03,  //
                                        VF c10, VF c11, VF c12, VF c13,  //
                                        VF c20, VF c21, VF c22, VF c23,  //
                                        VF c30, VF c31, VF c32, VF c33,
                                        const float* HWY_RESTRICT add,
+                                       const float scale,
                                        float* HWY_RESTRICT tile_c,
                                        size_t stride_c) {
   // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
@@ -154,31 +158,31 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
   // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
   float addon0 = hwy::ConvertScalarTo<float>(add[0]);
-  tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00) + addon0;
+  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + addon0;
   float addon1 = hwy::ConvertScalarTo<float>(add[1]);
-  tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01) + addon1;
+  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + addon1;
   float addon2 = hwy::ConvertScalarTo<float>(add[2]);
-  tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02) + addon2;
+  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + addon2;
   float addon3 = hwy::ConvertScalarTo<float>(add[3]);
-  tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03) + addon3;
+  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + addon3;
   if (kNumRows == 1) return;
-  tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10) + addon0;
-  tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11) + addon1;
-  tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12) + addon2;
-  tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13) + addon3;
+  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + addon0;
+  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + addon1;
+  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + addon2;
+  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + addon3;
   if (kNumRows == 2) return;
-  tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20) + addon0;
-  tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21) + addon1;
-  tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22) + addon2;
-  tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23) + addon3;
+  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + addon0;
+  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + addon1;
+  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + addon2;
+  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + addon3;
   if (kNumRows == 3) return;
-  tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30) + addon0;
-  tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31) + addon1;
-  tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32) + addon2;
-  tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33) + addon3;
+  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + addon0;
+  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + addon1;
+  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + addon2;
+  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + addon3;
 }

 // Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
@@ -188,16 +192,16 @@ template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSumsMaybeAdd(
     DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
    VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
-    const float* HWY_RESTRICT add, size_t add_offset,
+    const float* HWY_RESTRICT add, size_t add_offset, const float scale,
     float* HWY_RESTRICT tile_c, size_t stride_c) {
   if constexpr (kAdd) {
     StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
                                      c20, c21, c22, c23, c30, c31, c32, c33,
-                                     add + add_offset, tile_c, stride_c);
+                                     add + add_offset, scale, tile_c, stride_c);
   } else {
     StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
                                   c20, c21, c22, c23, c30, c31, c32, c33,
-                                  tile_c, stride_c);
+                                  scale, tile_c, stride_c);
   }
 }
@@ -215,7 +219,9 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
 template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
 HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
                               const hwy::bfloat16_t* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const float* add,
+                              float* HWY_RESTRICT C,
+                              const float scale,
+                              const float* HWY_RESTRICT add,
                               const size_t idx_tile, const size_t xtiles,
                               const size_t stride_a, const size_t stride_b,
                               const size_t stride_c) {
@@ -303,7 +309,7 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
       df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, tile_c, stride_c);
+      c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
 }

 #endif  // GEMMA_NATIVE_BF16
@@ -332,7 +338,9 @@ template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
           typename MatTB>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const float* add,
+                              float* HWY_RESTRICT C,
+                              const float scale,
+                              const float* HWY_RESTRICT add,
                               const size_t idx_tile, const size_t xtiles,
                               const size_t stride_a, const size_t stride_b,
                               const size_t stride_c) {
@@ -417,18 +425,28 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
      d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, tile_c, stride_c);
+      c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
 }

+// C = A * B * scale [+ add].
+// Computes the matrix product of A and B and stores this in C.
+// If kAdd is true, the row-vector `add` is added to each row of C.
+// A is a matrix of size (batch_size, kColsA_RowsB).
+// B is passed transposed (column-major), so a matrix of size
+// (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
+// C is a matrix of size (batch_size, kColsBC).
+// The product is scaled by `scale` to support CompressedArray with scale != 1,
+// the caller can pass the product of the scales of A and B.
+// A scale for `add` is not supported, so make sure its scale is 1.
 // Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
-// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
-// NOTE that batch_size is the number of rows of A and C.
+// and kColsBC is 24k or 3k.
 // This function processes tiles in parallel with a work-stealing thread pool.
 template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
           typename MatTB, typename OutT, typename AddT>
 HWY_NOINLINE void MatMul_4x4_Batch_Add(
     size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
+    float scale, OutT* HWY_RESTRICT C, const AddT* HWY_RESTRICT add,
+    hwy::ThreadPool& pool) {
   PROFILER_ZONE("Matmul");
   // Process reg-sized tiles of C in parallel. We currently write C directly,
   // which touches more memory than fits in L3. TODO: add another level of loops
@@ -454,31 +472,36 @@ HWY_NOINLINE void MatMul_4x4_Batch_Add(
     HWY_ASSERT(num_rows > 0);
     switch (num_rows) {
       case 1:
-        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
-                                             kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
+                                             kTilesX, kStrideA, kStrideB,
+                                             kStrideC);
         break;
       case 2:
-        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
-                                             kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
+                                             kTilesX, kStrideA, kStrideB,
+                                             kStrideC);
         break;
       case 3:
-        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
-                                             kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
+                                             kTilesX, kStrideA, kStrideB,
+                                             kStrideC);
         break;
       default:
-        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
-                                             kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
+                                             kTilesX, kStrideA, kStrideB,
+                                             kStrideC);
     }
   });
 }

+// As above, without the add.
 template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
           typename MatTB, typename OutT>
 HWY_NOINLINE void MatMul_4x4_Batch(
     size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+    float scale, OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
   MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, /*kAdd=*/false>(
-      batch_size, A, B, C, /*add=*/static_cast<OutT*>(nullptr), pool);
+      batch_size, A, B, scale, C, /*add=*/static_cast<OutT*>(nullptr), pool);
 }

 //------------------------------------------------------------------------------


@@ -32,10 +32,11 @@
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/timer.h"

 // clang-format off
 #undef HWY_TARGET_INCLUDE
-#define HWY_TARGET_INCLUDE "gemma/ops_test.cc"  //NOLINT
+#define HWY_TARGET_INCLUDE "gemma/ops_test.cc"  // NOLINT
 // clang-format on
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #include "hwy/highway.h"
@@ -362,7 +363,7 @@ CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
   });
   Compress(content, ws, mat, pool);
-  mat.set_scale(1.0f);
+  mat.set_scale(1.9f);  // Arbitrary value, different from 1.
   return mat;
 }
@@ -377,7 +378,7 @@ CompressedArray<MatT, kOuter * kInner> GenerateZeroMat(hwy::ThreadPool& pool) {
   });
   Compress(content, ws, mat, pool);
-  mat.set_scale(1.0f);
+  mat.set_scale(1.2f);  // Arbitrary value, different from 1.
   return mat;
 }
@@ -400,7 +401,7 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(
   Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
            pool);
-  mat->set_scale(1.0f);
+  mat->set_scale(0.6f);  // Arbitrary value, different from 1.
   return mat;
 }
@@ -423,7 +424,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {
   Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
            pool);
-  mat->set_scale(1.0f);
+  // Arbitrary value, different from 1, must match GenerateMatHeap.
+  mat->set_scale(0.6f);
   return mat;
 }
@@ -443,7 +445,7 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(
   Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
            pool);
-  mat->set_scale(1.0f);
+  mat->set_scale(1.2f);  // Arbitrary value, different from 1.
   return mat;
 }
@@ -487,6 +489,7 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
   hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
   HWY_ASSERT(uncompressed_mat && out);
   Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
+  MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
   for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
     out[idx_row] = add[idx_row];
     for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
@@ -513,14 +516,14 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
 template <size_t kN, size_t kK, typename MatTA, typename MatTB,
           HWY_IF_T_SIZE_GT(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
-                                const MatTB* HWY_RESTRICT b, const float* add,
-                                float* HWY_RESTRICT out) {
+                                const MatTB* HWY_RESTRICT b, const float scale,
+                                const float* add, float* HWY_RESTRICT out) {
   for (size_t i = 0; i < batch_size; ++i) {
     for (size_t k = 0; k < kN; ++k) {
       for (size_t j = 0; j < kK; ++j) {
         const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
         const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
-        out[i * kK + j] += a1 * b1;
+        out[i * kK + j] += scale * a1 * b1;
       }
     }
     if (add != nullptr) {
@@ -537,12 +540,13 @@ template <size_t kN, size_t kK, typename MatTA, typename MatTB,
           HWY_IF_T_SIZE(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
                                 const MatTB* HWY_RESTRICT b_compr,
-                                const float* add, float* HWY_RESTRICT out) {
+                                const float scale, const float* add,
+                                float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> d;
   hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
   CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
                                     kK * kN);
-  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
+  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), scale, add, out);
 }

 template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
@@ -558,11 +562,13 @@ void TestTiledBatchMatMul() {
       GenerateMatHeap<MatTB, kN, kK>(0, pool);
   std::unique_ptr<CompressedArray<float, kK>> add =
       GenerateMatHeap<float, 1, kK>(0, pool);
+  add->set_scale(1.0f);
   std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
       GenerateZeroMatHeap<float, kM, kK>(pool);
+  const float scale = a->scale() * b->scale();
   const double start_slow = hwy::platform::Now();
-  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(),
+  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), scale,
                           kAdd ? add->data() : nullptr, c_slow->data());
   const double slow_matmul_seconds = hwy::platform::Now() - start_slow;
   fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds);
@@ -572,11 +578,13 @@ void TestTiledBatchMatMul() {
       GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);
   const double start_tiled = hwy::platform::Now();
+  EXPECT_EQ(scale, a->scale() * b_trans->scale());
   if (kAdd) {
-    MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), c.get(),
-                                       add->data(), pool);
+    MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), scale,
+                                       c.get(), add->data(), pool);
   } else {
-    MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
+    MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), scale, c.get(),
+                             pool);
   }
   const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
   fprintf(stderr, "MatMul_4x4_Batch took %f seconds.\n", tiled_matmul_seconds);


@@ -15,6 +15,7 @@
 #include "gemma/weights.h"

+#include <cstdio>
 #include <cstdlib>

 #include "compression/compress.h"
@@ -96,6 +97,9 @@ class WeightLogger {
  public:
   template <size_t N>
   void operator()(const char* name, const CompressedArray<float, N>& tensor) {
+    if (tensor.scale() != 1.0f) {
+      printf("[scale=%f] ", tensor.scale());
+    }
     LogVec(name, tensor.data(), N);
     total_weights += N;
   }