mirror of https://github.com/google/gemma.cpp.git
Add scale parameter to MatMul.
Add accessor to CompressedArray that asserts the scale is 1 and use it.

PiperOrigin-RevId: 653604840

commit e87e65ca45 (parent 5a751a9a44)
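The commit threads a per-tensor `scale` (stored in CompressedArray) through the MatMul path instead of applying it in a separate pass, and adds `data_scale1()` for tensors whose scale must stay 1. As an illustrative sketch only (not part of the commit; names are taken from the hunks below and the surrounding weights/activations/pool setup is assumed), a caller of the batched MatMul now looks like this:

    // Hedged sketch: the weight matrix carries its own scale; the caller
    // forwards it so the kernel can fold it into its horizontal-sum epilogue.
    const float scale = layer_weights->qkv_einsum_w.scale();
    MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
        num_tokens_and_queries, activations.pre_att_rms_out.All(),
        layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);

Bias vectors, whose scale is expected to remain 1, are read through the new asserting accessor, e.g. `layer_weights->ffw_gating_biases.data_scale1()`.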
@@ -80,6 +80,14 @@ class CompressedArray {
   // may be different from 1.0f.
   MatT* data() { return data_.data(); }
   const MatT* data() const { return data_.data(); }
+  // The const accessor data_scale1() asserts (!) that the scale is 1.0f, so
+  // calling it means "I am sure the scale is 1 and therefore ignore the scale".
+  // A scale of 0 indicates that the scale has likely never been set, so is
+  // "implicitly 1".
+  const MatT* data_scale1() const {
+    HWY_ASSERT(scale() == 1.f || scale() == 0.f);
+    return data_.data();
+  }

   // Decoded elements should be multiplied by this to restore their original
   // range. This is required because SfpStream can only encode a limited range
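A minimal sketch of the intended use of the two accessors (not part of the commit; `bias` stands for a hypothetical CompressedArray<float, N>):

    const float* p = bias.data();         // caller must still honor bias.scale()
    const float* q = bias.data_scale1();  // HWY_ASSERTs scale() is 1.f (or 0.f, i.e. never set)

Calling data_scale1() documents at the call site that the scale is intentionally ignored.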
@@ -81,9 +81,9 @@ HWY_NOINLINE void GriffinRecurrent(
   TwoMatVecAdd<kModelDim, kModelDim>(
       layer_weights->griffin.linear_x_w, layer_weights->griffin.linear_y_w, 0,
       activations.pre_att_rms_out.Batch(batch_idx),
-      /*add0=*/layer_weights->griffin.linear_x_biases.data(),
-      /*add1=*/layer_weights->griffin.linear_y_biases.data(), /*out0=*/x,
-      /*out1=*/y, pool);
+      /*add0=*/layer_weights->griffin.linear_x_biases.data_scale1(),
+      /*add1=*/layer_weights->griffin.linear_y_biases.data_scale1(),
+      /*out0=*/x, /*out1=*/y, pool);
   Gelu(y, kModelDim);
 }
@@ -106,13 +106,13 @@ HWY_NOINLINE void GriffinRecurrent(
     for (size_t i = 0; i < kModelDim; i += hn::Lanes(df)) {
       auto xv = hn::Load(df, x + i);
      auto accum0 =
-          hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
+          hn::Load(df, layer_weights->griffin.conv_biases.data_scale1() + i);
       auto accum1 = hn::Zero(df);
       static_assert(kConv1dWidth % 2 == 0, "Conv width must be even");
       for (size_t l = 0; 2 * l < kConv1dWidth; l++) {
-        auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data() +
+        auto wv0 = hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
                                 (kConv1dWidth - 1 - 2 * l) * kModelDim + i);
-        auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data() +
+        auto wv1 = hn::Load(df, layer_weights->griffin.conv_w.data_scale1() +
                                 (kConv1dWidth - 2 - 2 * l) * kModelDim + i);
         accum0 = hn::MulAdd(wv0, hn::Load(df, cache[l * 2] + i), accum0);
         accum1 = hn::MulAdd(wv1, hn::Load(df, cache[l * 2 + 1] + i), accum1);
@@ -139,16 +139,18 @@ HWY_NOINLINE void GriffinRecurrent(
     TwoOfsMatVecAddLoop<kHeadDim, kHeadDim>(
         layer_weights->griffin.gate_w, kMatrixSize * head,
         kMatrixSize * (kHeads + head), x + head_offset,
-        /*add0=*/layer_weights->griffin.gate_biases.data() + head_offset,
-        /*add1=*/layer_weights->griffin.gate_biases.data() + kModelDim +
+        /*add0=*/layer_weights->griffin.gate_biases.data_scale1() +
             head_offset,
+        /*add1=*/layer_weights->griffin.gate_biases.data_scale1() +
+            kModelDim + head_offset,
         /*out0=*/gate_x + head_offset, /*out1=*/a + head_offset);
     Sigmoid(gate_x + head_offset, kHeadDim);
     Sigmoid(a + head_offset, kHeadDim);
     const auto fn_mul = [](D d, hn::Vec<D> x, hn::Vec<D> gate_x)
                             HWY_ATTR { return hn::Mul(x, gate_x); };
     hn::Transform1(D(), a + head_offset, kHeadDim,
-                   layer_weights->griffin.a.data() + head_offset, fn_mul);
+                   layer_weights->griffin.a.data_scale1() + head_offset,
+                   fn_mul);
     hn::Transform1(D(), x + head_offset, kHeadDim, gate_x + head_offset,
                    fn_mul);
     // RNN scan
@@ -180,7 +182,7 @@ HWY_NOINLINE void GriffinRecurrent(
     float* out_ptr = activations.att_post2.Batch(batch_idx);
     MatVecAdd<kModelDim, kModelDim>(
         layer_weights->griffin.linear_out_w, 0, x,
-        layer_weights->griffin.linear_out_biases.data(),
+        layer_weights->griffin.linear_out_biases.data_scale1(),
         activations.even_odd.All(), out_ptr, pool);
   }
 }
@@ -222,9 +224,10 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
   //
   // Compute Q only or QKV (if MHA).
   // If MHA, this also computes KV, which we copy to the KV cache below.
+  const float scale = layer_weights->qkv_einsum_w.scale();
   MatMul_4x4_Batch<kModelDim, kHeads * kQStride>(
       num_tokens_and_queries, activations.pre_att_rms_out.All(),
-      layer_weights->qkv_einsum_w.data(), activations.q.All(), pool);
+      layer_weights->qkv_einsum_w.data(), scale, activations.q.All(), pool);

   // Compute KV if not MHA.
   if constexpr (!kIsMHA) {
@@ -342,7 +345,7 @@ HWY_NOINLINE void Attention(size_t batch_and_query_start, size_t num_tokens,
     // attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
     MatVecT</*kAdd=*/TConfig::kSoftmaxAttnOutputBiases, kModelDim, kQKVDim>(
         layer_weights->attn_vec_einsum_w, 0, att_out,
-        layer_weights->attention_output_biases.data(),
+        layer_weights->attention_output_biases.data_scale1(),
         activations.even_odd.All(), layer_out, pool);
     // Head 1 and following are added to layer_out.
     for (size_t head = 1; head < kHeads; ++head) {
@@ -384,34 +387,30 @@ HWY_NOINLINE void FFW(Activations& activations, size_t num_tokens,
   constexpr size_t kColsB = kFFHiddenDim;
   HWY_DASSERT(num_tokens <= activations.bf_pre_ffw_rms_out.BatchSize());
   const auto A = activations.bf_pre_ffw_rms_out.All();
+  const float scale = layer_weights->gating_einsum_w.scale();
   const auto B1 = layer_weights->gating_einsum_w.data();
   const auto B2 = B1 + kColsA * kColsB;
   auto C1 = activations.C1.All();
   auto C2 = activations.C2.All();
   constexpr bool kAddBias = TConfig::kFFBiases;
-  const auto bias = layer_weights->ffw_gating_biases.data();
+  const auto bias1 = layer_weights->ffw_gating_biases.data_scale1();
+  const auto bias2 = bias1 + kFFHiddenDim;

   // Will go through GELU.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B1, C1,
-                                                 bias, pool);
+  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B1, scale, C1,
+                                                 bias1, pool);
   // What to multiply by.
-  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B2, C2,
-                                                 bias + kFFHiddenDim, pool);
+  MatMul_4x4_Batch_Add<kColsA, kColsB, kAddBias>(num_tokens, A, B2, scale, C2,
+                                                 bias2, pool);

   // Activation (Gelu) and multiply by gate. Store activations in C1.
   Activation<TConfig>(C1, C2, kFFHiddenDim * num_tokens);

-  // linear_w may have a scale value different from 1, apply that here.
-  // We multiply all activations by the scale value to compensate for the
-  // missing scale value in the weights.
-  if (layer_weights->linear_w.scale() != 1.0f) {
-    MulByConst(layer_weights->linear_w.scale(), C1, kFFHiddenDim * num_tokens);
-  }
-
   // Hidden layer -> output layer.
   MatMul_4x4_Batch_Add<kFFHiddenDim, kModelDim, kAddBias>(
-      num_tokens, C1, layer_weights->linear_w.data(), activations.ffw_out.All(),
-      layer_weights->ffw_output_biases.data(), pool);
+      num_tokens, C1, layer_weights->linear_w.data(),
+      layer_weights->linear_w.scale(), activations.ffw_out.All(),
+      layer_weights->ffw_output_biases.data_scale1(), pool);
 }

 // TODO: pass Activations.x instead of Activations.
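The deleted MulByConst pass is subsumed by the new scale parameter: the final MatMul now receives layer_weights->linear_w.scale() and folds it into its epilogue, which gives the same result because scale * sum_k C1[i][k] * W[k][j] equals sum_k (scale * C1[i][k]) * W[k][j].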
@@ -453,7 +452,7 @@ HWY_NOINLINE void TransformerLayer(
   size_t layer_of_type =
       NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
   RMSNormBatched(num_tokens_and_queries, activations.x.All(),
-                 layer_weights->pre_attention_norm_scale.data(),
+                 layer_weights->pre_attention_norm_scale.data_scale1(),
                  activations.pre_att_rms_out.All(), kModelDim);
   if (type == LayerAttentionType::kGemma) {
     Attention<TConfig>(pos, num_tokens, num_queries, layer_of_type, activations,
@@ -472,21 +471,22 @@ HWY_NOINLINE void TransformerLayer(
   }

   if (TConfig::kPostNorm == PostNormType::Scale) {
-    RMSNormInplaceBatched(num_tokens_and_queries,
-                          layer_weights->post_attention_norm_scale.data(),
-                          activations.att_post2.All(), kModelDim);
+    RMSNormInplaceBatched(
+        num_tokens_and_queries,
+        layer_weights->post_attention_norm_scale.data_scale1(),
+        activations.att_post2.All(), kModelDim);
   }

   ResidualConnection<TConfig>(num_tokens_and_queries,
                               activations.att_post2.All(), activations.x.All(),
                               layer_weights, /*is_attention=*/true);
   RMSNormBatched(num_tokens_and_queries, activations.x.All(),
-                 layer_weights->pre_ffw_norm_scale.data(),
+                 layer_weights->pre_ffw_norm_scale.data_scale1(),
                  activations.bf_pre_ffw_rms_out.All(), kModelDim);
   FFW<TConfig>(activations, num_tokens_and_queries, layer_weights, pool);
   if (TConfig::kPostNorm == PostNormType::Scale) {
     RMSNormInplaceBatched(num_tokens_and_queries,
-                          layer_weights->post_ffw_norm_scale.data(),
+                          layer_weights->post_ffw_norm_scale.data_scale1(),
                           activations.ffw_out.All(), kModelDim);
   }
   ResidualConnection<TConfig>(num_tokens_and_queries, activations.ffw_out.All(),
@@ -564,7 +564,8 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens,
     }
   }

-  RMSNormInplaceBatched(num_tokens_and_queries, weights.final_norm_scale.data(),
+  RMSNormInplaceBatched(num_tokens_and_queries,
+                        weights.final_norm_scale.data_scale1(),
                         activations.x.All(), kModelDim);
   if (layers_output) {
     for (size_t token_idx = 0; token_idx < num_tokens_and_queries;
gemma/ops.h (135 changed lines)
@@ -27,6 +27,7 @@
 #include <random>
 #include <type_traits>  // std::enable_if_t

+#include "compression/compress.h"
 #include "compression/sfp.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
@@ -106,47 +107,50 @@ HWY_INLINE constexpr size_t RowsPerStrip() {

 // Shared between f32 and bf16, which also accumulates into f32 vectors.
 template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSums(DF df, VF c00, VF c01, VF c02, VF c03,
+HWY_INLINE void StoreHorizontalSums(DF df,  //
+                                    VF c00, VF c01, VF c02, VF c03,  //
                                     VF c10, VF c11, VF c12, VF c13,  //
                                     VF c20, VF c21, VF c22, VF c23,  //
-                                    VF c30, VF c31, VF c32, VF c33,
-                                    float* HWY_RESTRICT tile_c,
+                                    VF c30, VF c31, VF c32, VF c33,  //
+                                    float scale, float* HWY_RESTRICT tile_c,
                                     size_t stride_c) {
   // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
   // Each entry of C[r,c] is a dot product of A.row and B.col, which reside in
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
   // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
-  tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00);
-  tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01);
-  tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02);
-  tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03);
+  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00);
+  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01);
+  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02);
+  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03);
   if (kNumRows == 1) return;

-  tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10);
-  tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11);
-  tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12);
-  tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13);
+  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10);
+  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11);
+  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12);
+  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13);
   if (kNumRows == 2) return;

-  tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20);
-  tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21);
-  tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22);
-  tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23);
+  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20);
+  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21);
+  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22);
+  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23);
   if (kNumRows == 3) return;

-  tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30);
-  tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31);
-  tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32);
-  tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33);
+  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30);
+  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31);
+  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32);
+  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33);
 }

 // Completes the tile by summing across the vectors, and adds the biases.
 template <size_t kNumRows, class DF, class VF = hn::Vec<DF>>
-HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
+HWY_INLINE void StoreHorizontalSumsAdd(DF df,  //
+                                       VF c00, VF c01, VF c02, VF c03,  //
                                        VF c10, VF c11, VF c12, VF c13,  //
                                        VF c20, VF c21, VF c22, VF c23,  //
                                        VF c30, VF c31, VF c32, VF c33,
                                        const float* HWY_RESTRICT add,
+                                       const float scale,
                                        float* HWY_RESTRICT tile_c,
                                        size_t stride_c) {
   // We are computing the product of (4, 4N) * (4N, 4) = (4, 4) tiles.
@@ -154,31 +158,31 @@ HWY_INLINE void StoreHorizontalSumsAdd(DF df, VF c00, VF c01, VF c02, VF c03,
   // the lanes of `c$r$c`, so we store their horizontal sum (ReduceSum). This is
   // expensive, but only a fraction of the kColsA_RowsB/N FMAs.
   float addon0 = hwy::ConvertScalarTo<float>(add[0]);
-  tile_c[stride_c * 0 + 0] = hn::ReduceSum(df, c00) + addon0;
+  tile_c[stride_c * 0 + 0] = scale * hn::ReduceSum(df, c00) + addon0;
   float addon1 = hwy::ConvertScalarTo<float>(add[1]);
-  tile_c[stride_c * 0 + 1] = hn::ReduceSum(df, c01) + addon1;
+  tile_c[stride_c * 0 + 1] = scale * hn::ReduceSum(df, c01) + addon1;
   float addon2 = hwy::ConvertScalarTo<float>(add[2]);
-  tile_c[stride_c * 0 + 2] = hn::ReduceSum(df, c02) + addon2;
+  tile_c[stride_c * 0 + 2] = scale * hn::ReduceSum(df, c02) + addon2;
   float addon3 = hwy::ConvertScalarTo<float>(add[3]);
-  tile_c[stride_c * 0 + 3] = hn::ReduceSum(df, c03) + addon3;
+  tile_c[stride_c * 0 + 3] = scale * hn::ReduceSum(df, c03) + addon3;
   if (kNumRows == 1) return;

-  tile_c[stride_c * 1 + 0] = hn::ReduceSum(df, c10) + addon0;
-  tile_c[stride_c * 1 + 1] = hn::ReduceSum(df, c11) + addon1;
-  tile_c[stride_c * 1 + 2] = hn::ReduceSum(df, c12) + addon2;
-  tile_c[stride_c * 1 + 3] = hn::ReduceSum(df, c13) + addon3;
+  tile_c[stride_c * 1 + 0] = scale * hn::ReduceSum(df, c10) + addon0;
+  tile_c[stride_c * 1 + 1] = scale * hn::ReduceSum(df, c11) + addon1;
+  tile_c[stride_c * 1 + 2] = scale * hn::ReduceSum(df, c12) + addon2;
+  tile_c[stride_c * 1 + 3] = scale * hn::ReduceSum(df, c13) + addon3;
   if (kNumRows == 2) return;

-  tile_c[stride_c * 2 + 0] = hn::ReduceSum(df, c20) + addon0;
-  tile_c[stride_c * 2 + 1] = hn::ReduceSum(df, c21) + addon1;
-  tile_c[stride_c * 2 + 2] = hn::ReduceSum(df, c22) + addon2;
-  tile_c[stride_c * 2 + 3] = hn::ReduceSum(df, c23) + addon3;
+  tile_c[stride_c * 2 + 0] = scale * hn::ReduceSum(df, c20) + addon0;
+  tile_c[stride_c * 2 + 1] = scale * hn::ReduceSum(df, c21) + addon1;
+  tile_c[stride_c * 2 + 2] = scale * hn::ReduceSum(df, c22) + addon2;
+  tile_c[stride_c * 2 + 3] = scale * hn::ReduceSum(df, c23) + addon3;
   if (kNumRows == 3) return;

-  tile_c[stride_c * 3 + 0] = hn::ReduceSum(df, c30) + addon0;
-  tile_c[stride_c * 3 + 1] = hn::ReduceSum(df, c31) + addon1;
-  tile_c[stride_c * 3 + 2] = hn::ReduceSum(df, c32) + addon2;
-  tile_c[stride_c * 3 + 3] = hn::ReduceSum(df, c33) + addon3;
+  tile_c[stride_c * 3 + 0] = scale * hn::ReduceSum(df, c30) + addon0;
+  tile_c[stride_c * 3 + 1] = scale * hn::ReduceSum(df, c31) + addon1;
+  tile_c[stride_c * 3 + 2] = scale * hn::ReduceSum(df, c32) + addon2;
+  tile_c[stride_c * 3 + 3] = scale * hn::ReduceSum(df, c33) + addon3;
 }

 // Wrapper around StoreHorizontalSums and StoreHorizontalSumsAdd to shorten call
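The scale is applied once per tile entry, after the reduction, so the cost is one extra multiply per ReduceSum and is negligible next to the kColsA_RowsB/N FMAs mentioned in the comments. An illustrative scalar model of the new epilogue (not code from the commit):

    // Each tile entry is one reduced dot product, scaled once, plus an optional bias.
    float EpilogueEntry(float horizontal_sum, float scale, float addon) {
      return scale * horizontal_sum + addon;  // addon is 0 for StoreHorizontalSums
    }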
@@ -188,16 +192,16 @@ template <bool kAdd, size_t kNumRows, class DF, class VF = hn::Vec<DF>>
 HWY_INLINE void StoreHorizontalSumsMaybeAdd(
     DF df, VF c00, VF c01, VF c02, VF c03, VF c10, VF c11, VF c12, VF c13,
     VF c20, VF c21, VF c22, VF c23, VF c30, VF c31, VF c32, VF c33,
-    const float* HWY_RESTRICT add, size_t add_offset,
+    const float* HWY_RESTRICT add, size_t add_offset, const float scale,
     float* HWY_RESTRICT tile_c, size_t stride_c) {
   if constexpr (kAdd) {
     StoreHorizontalSumsAdd<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
                                      c20, c21, c22, c23, c30, c31, c32, c33,
-                                     add + add_offset, tile_c, stride_c);
+                                     add + add_offset, scale, tile_c, stride_c);
   } else {
     StoreHorizontalSums<kNumRows>(df, c00, c01, c02, c03, c10, c11, c12, c13,
                                   c20, c21, c22, c23, c30, c31, c32, c33,
-                                  tile_c, stride_c);
+                                  scale, tile_c, stride_c);
   }
 }
@@ -215,7 +219,9 @@ HWY_INLINE void StoreHorizontalSumsMaybeAdd(
 template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd>
 HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
                               const hwy::bfloat16_t* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const float* add,
+                              float* HWY_RESTRICT C,
+                              const float scale,
+                              const float* HWY_RESTRICT add,
                               const size_t idx_tile, const size_t xtiles,
                               const size_t stride_a, const size_t stride_b,
                               const size_t stride_c) {
@@ -303,7 +309,7 @@ HWY_INLINE void GEMM_4x4_Tile(const hwy::bfloat16_t* HWY_RESTRICT A,
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
       df, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, tile_c, stride_c);
+      c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
 }

 #endif  // GEMMA_NATIVE_BF16
@@ -332,7 +338,9 @@ template <size_t kNumRows, size_t kColsA_RowsB, bool kAdd, typename MatTA,
           typename MatTB>
 HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
                               const MatTB* HWY_RESTRICT B,
-                              float* HWY_RESTRICT C, const float* add,
+                              float* HWY_RESTRICT C,
+                              const float scale,
+                              const float* HWY_RESTRICT add,
                               const size_t idx_tile, const size_t xtiles,
                               const size_t stride_a, const size_t stride_b,
                               const size_t stride_c) {
@@ -417,18 +425,28 @@ HWY_INLINE void GEMM_4x4_Tile(const MatTA* HWY_RESTRICT A,
   float* HWY_RESTRICT tile_c = C + stride_c * row_a + row_b_col_c;
   StoreHorizontalSumsMaybeAdd<kAdd, kNumRows>(
       d32, c00, c01, c02, c03, c10, c11, c12, c13, c20, c21, c22, c23, c30, c31,
-      c32, c33, add, row_b_col_c, tile_c, stride_c);
+      c32, c33, add, row_b_col_c, scale, tile_c, stride_c);
 }

+// C = A * B * scale [+ add].
+// Computes the matrix product of A and B and stores this in C.
+// If kAdd is true, the row-vector `add` is added to each row of C.
+// A is a matrix of size (batch_size, kColsA_RowsB).
+// B is passed transposed (column-major), so a matrix of size
+// (kColsBC, kColsA_RowsB), representing a B of size (kColsA_RowsB, kColsBC).
+// C is a matrix of size (batch_size, kColsBC).
+// The product is scaled by `scale` to support CompressedArray with scale != 1,
+// the caller can pass the product of the scales of A and B.
+// A scale for `add` is not supported, so make sure its scale is 1.
 // Tiled 4x4 GEMM. Typically batch_size is 1..512, kColsA_RowsB is 3k or 24k,
-// and kColsBC is 24k or 3k. Note: B is transposed (column-major).
-// NOTE that batch_size is the number of rows of A and C.
+// and kColsBC is 24k or 3k.
 // This function processes tiles in parallel with a work-stealing thread pool.
 template <size_t kColsA_RowsB, size_t kColsBC, bool kAdd, typename MatTA,
           typename MatTB, typename OutT, typename AddT>
 HWY_NOINLINE void MatMul_4x4_Batch_Add(
     size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    OutT* HWY_RESTRICT C, const AddT* add, hwy::ThreadPool& pool) {
+    float scale, OutT* HWY_RESTRICT C, const AddT* HWY_RESTRICT add,
+    hwy::ThreadPool& pool) {
   PROFILER_ZONE("Matmul");
   // Process reg-sized tiles of C in parallel. We currently write C directly,
   // which touches more memory than fits in L3. TODO: add another level of loops
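A usage sketch of the new signature, mirroring the updated test further below (a, b_trans, add, c, kM/kN/kK and pool come from TestTiledBatchMatMul in gemma/ops_test.cc and are assumed to be set up as there):

    const float scale = a->scale() * b_trans->scale();  // product of the input scales
    MatMul_4x4_Batch_Add<kN, kK, /*kAdd=*/true>(kM, a->data(), b_trans->data(),
                                                scale, c.get(), add->data(), pool);

Per the comment above, `add` itself is not scaled, which is why the test forces add->set_scale(1.0f).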
@@ -454,31 +472,36 @@ HWY_NOINLINE void MatMul_4x4_Batch_Add(
     HWY_ASSERT(num_rows > 0);
     switch (num_rows) {
       case 1:
-        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
-                                             kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<1, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
+                                             kTilesX, kStrideA, kStrideB,
+                                             kStrideC);
         break;
       case 2:
-        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
-                                             kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<2, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
+                                             kTilesX, kStrideA, kStrideB,
+                                             kStrideC);
         break;
       case 3:
-        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
-                                             kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<3, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
+                                             kTilesX, kStrideA, kStrideB,
+                                             kStrideC);
         break;
       default:
-        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, add, idx_tile, kTilesX,
-                                             kStrideA, kStrideB, kStrideC);
+        GEMM_4x4_Tile<4, kColsA_RowsB, kAdd>(A, B, C, scale, add, idx_tile,
+                                             kTilesX, kStrideA, kStrideB,
+                                             kStrideC);
     }
   });
 }

+// As above, without the add.
 template <size_t kColsA_RowsB, size_t kColsBC, typename MatTA,
           typename MatTB, typename OutT>
 HWY_NOINLINE void MatMul_4x4_Batch(
     size_t batch_size, const MatTA* HWY_RESTRICT A, const MatTB* HWY_RESTRICT B,
-    OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
+    float scale, OutT* HWY_RESTRICT C, hwy::ThreadPool& pool) {
   MatMul_4x4_Batch_Add<kColsA_RowsB, kColsBC, /*kAdd=*/false>(
-      batch_size, A, B, C, /*add=*/static_cast<OutT*>(nullptr), pool);
+      batch_size, A, B, scale, C, /*add=*/static_cast<OutT*>(nullptr), pool);
 }

 //------------------------------------------------------------------------------
@@ -32,10 +32,11 @@
 #include "hwy/aligned_allocator.h"
 #include "hwy/base.h"
 #include "hwy/contrib/thread_pool/thread_pool.h"
+#include "hwy/timer.h"

 // clang-format off
 #undef HWY_TARGET_INCLUDE
-#define HWY_TARGET_INCLUDE "gemma/ops_test.cc"  //NOLINT
+#define HWY_TARGET_INCLUDE "gemma/ops_test.cc"  // NOLINT
 // clang-format on
 #include "hwy/foreach_target.h"  // IWYU pragma: keep
 #include "hwy/highway.h"
@@ -362,7 +363,7 @@ CompressedArray<MatT, kOuter * kInner> GenerateMat(size_t offset,
   });

   Compress(content, ws, mat, pool);
-  mat.set_scale(1.0f);
+  mat.set_scale(1.9f);  // Arbitrary value, different from 1.
   return mat;
 }
@@ -377,7 +378,7 @@ CompressedArray<MatT, kOuter * kInner> GenerateZeroMat(hwy::ThreadPool& pool) {
   });

   Compress(content, ws, mat, pool);
-  mat.set_scale(1.0f);
+  mat.set_scale(1.2f);  // Arbitrary value, different from 1.
   return mat;
 }
@@ -400,7 +401,7 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateMatHeap(

   Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
            pool);
-  mat->set_scale(1.0f);
+  mat->set_scale(0.6f);  // Arbitrary value, different from 1.
   return mat;
 }
@@ -423,7 +424,8 @@ GenerateTransposeMatHeap(size_t offset, hwy::ThreadPool& pool) {

   Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
            pool);
-  mat->set_scale(1.0f);
+  // Arbitrary value, different from 1, must match GenerateMatHeap.
+  mat->set_scale(0.6f);
   return mat;
 }
@@ -443,7 +445,7 @@ std::unique_ptr<CompressedArray<MatT, kOuter * kInner>> GenerateZeroMatHeap(

   Compress(content.get(), kOuter * kInner, ws, kOuter * kInner, mat->data(), 0,
            pool);
-  mat->set_scale(1.0f);
+  mat->set_scale(1.2f);  // Arbitrary value, different from 1.
   return mat;
 }
@@ -487,6 +489,7 @@ hwy::AlignedFreeUniquePtr<float[]> SimpleMatVecAdd(
   hwy::AlignedFreeUniquePtr<float[]> out = hwy::AllocateAligned<float>(kOuter);
   HWY_ASSERT(uncompressed_mat && out);
   Decompress(mat, 0, uncompressed_mat.get(), kOuter * kInner);
+  MulByConst(mat.scale(), uncompressed_mat.get(), kOuter * kInner);
   for (size_t idx_row = 0; idx_row < kOuter; idx_row++) {
     out[idx_row] = add[idx_row];
     for (size_t idx_col = 0; idx_col < kInner; idx_col++) {
@@ -513,14 +516,14 @@ void AssertClose(const hwy::AlignedFreeUniquePtr<float[]>& a,
 template <size_t kN, size_t kK, typename MatTA, typename MatTB,
           HWY_IF_T_SIZE_GT(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
-                                const MatTB* HWY_RESTRICT b, const float* add,
-                                float* HWY_RESTRICT out) {
+                                const MatTB* HWY_RESTRICT b, const float scale,
+                                const float* add, float* HWY_RESTRICT out) {
   for (size_t i = 0; i < batch_size; ++i) {
     for (size_t k = 0; k < kN; ++k) {
       for (size_t j = 0; j < kK; ++j) {
         const float a1 = hwy::ConvertScalarTo<float>(a[i * kN + k]);
         const float b1 = hwy::ConvertScalarTo<float>(b[k * kK + j]);
-        out[i * kK + j] += a1 * b1;
+        out[i * kK + j] += scale * a1 * b1;
       }
     }
     if (add != nullptr) {
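The reference applies the scale to every product term rather than to the finished dot product; the two are equivalent because scale * a1 * b1 summed over the inner dimension equals scale times the sum, which is exactly what the tiled kernel computes by scaling each ReduceSum in its epilogue.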
@@ -537,12 +540,13 @@ template <size_t kN, size_t kK, typename MatTA, typename MatTB,
           HWY_IF_T_SIZE(MatTB, 1)>
 HWY_INLINE void MatMulSlowBatch(size_t batch_size, const MatTA* HWY_RESTRICT a,
                                 const MatTB* HWY_RESTRICT b_compr,
-                                const float* add, float* HWY_RESTRICT out) {
+                                const float scale, const float* add,
+                                float* HWY_RESTRICT out) {
   const hn::ScalableTag<float> d;
   hwy::AlignedFreeUniquePtr<float[]> b = hwy::AllocateAligned<float>(kK * kN);
   CompressTraits<MatTB>::Decompress(d, /*in_capacity=*/0, b_compr, 0, b.get(),
                                     kK * kN);
-  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), add, out);
+  MatMulSlowBatch<kN, kK>(batch_size, a, b.get(), scale, add, out);
 }

 template <size_t kM, size_t kN, size_t kK, bool kAdd, typename MatTA,
@@ -558,11 +562,13 @@ void TestTiledBatchMatMul() {
       GenerateMatHeap<MatTB, kN, kK>(0, pool);
   std::unique_ptr<CompressedArray<float, kK>> add =
       GenerateMatHeap<float, 1, kK>(0, pool);
+  add->set_scale(1.0f);
   std::unique_ptr<CompressedArray<float, kM * kK>> c_slow =
       GenerateZeroMatHeap<float, kM, kK>(pool);
+  const float scale = a->scale() * b->scale();

   const double start_slow = hwy::platform::Now();
-  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(),
+  MatMulSlowBatch<kN, kK>(kM, a->data(), b->data(), scale,
                           kAdd ? add->data() : nullptr, c_slow->data());
   const double slow_matmul_seconds = hwy::platform::Now() - start_slow;
   fprintf(stderr, "MatMulSlowBatch took %f seconds.\n", slow_matmul_seconds);
@@ -572,11 +578,13 @@ void TestTiledBatchMatMul() {
       GenerateTransposeMatHeap<MatTB, kN, kK>(0, pool);

   const double start_tiled = hwy::platform::Now();
+  EXPECT_EQ(scale, a->scale() * b_trans->scale());
   if (kAdd) {
-    MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), c.get(),
-                                       add->data(), pool);
+    MatMul_4x4_Batch_Add<kN, kK, kAdd>(kM, a->data(), b_trans->data(), scale,
+                                       c.get(), add->data(), pool);
   } else {
-    MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), c.get(), pool);
+    MatMul_4x4_Batch<kN, kK>(kM, a->data(), b_trans->data(), scale, c.get(),
+                             pool);
   }
   const double tiled_matmul_seconds = hwy::platform::Now() - start_tiled;
   fprintf(stderr, "MatMul_4x4_Batch took %f seconds.\n", tiled_matmul_seconds);
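With the generators now setting deliberately non-unit scales (1.9, 1.2, 0.6 above), any code path that drops the scale shows up as a mismatch between the tiled result and MatMulSlowBatch; the added EXPECT_EQ additionally checks that the transposed copy of B reports the same scale as B.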
@@ -15,6 +15,7 @@

 #include "gemma/weights.h"

+#include <cstdio>
 #include <cstdlib>

 #include "compression/compress.h"
@@ -96,6 +97,9 @@ class WeightLogger {
  public:
   template <size_t N>
   void operator()(const char* name, const CompressedArray<float, N>& tensor) {
+    if (tensor.scale() != 1.0f) {
+      printf("[scale=%f] ", tensor.scale());
+    }
     LogVec(name, tensor.data(), N);
     total_weights += N;
   }