1.3x prefill, 0.95x decode: matmul replacing last matvec

Before: prefill 38.28, decode 9.17 (with profiler enabled, prompt = 330 tok)
```
Gen.FFW                                 :      15414 x         4692352 = 24.166318
Gen.Attention.SumHeads                  :      15414 x         1394804 =  7.183451 !!
Gen.Embedding                           :        361 x        49961894 =  6.026297
Gen.Attention.QKV                       :      15414 x         1005125 =  5.176546
Gen.Attention.DotSoftmax                :      15414 x          885480 =  4.560357
RopeAndMulBy                            :     696528 x           11867 =  2.761818
```

After: prefill 49.80, decode 8.68
```
Gen.FFW                                 :      14448 x         5312783 = 25.646868
Gen.Embedding                           :        338 x        63044815 =  7.119845
Gen.Attention.QKV                       :      14448 x         1115003 =  5.382557
Gen.Attention.DotSoftmax                :      14448 x          897577 =  4.332957
RopeAndMulBy                            :     673344 x           11886 =  2.674156
Gen.Attention.SumHeads                  :      14448 x          518291 =  2.501993 !!
```
PiperOrigin-RevId: 662024085
Jan Wassenberg 2024-08-12 03:35:30 -07:00 committed by Copybara-Service
parent 282f73ec2f
commit b831fa8482
6 changed files with 68 additions and 40 deletions
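
The change in a nutshell: SumHeads used to issue one MatVec per head (plus an AddFrom into a shared accumulator) against attn_vec_einsum_w; it now issues a single MatMul over the concatenated head outputs and a reshaped copy of those weights (att_weights), so the [num_interleaved, kModelDim] output already contains the sum over heads. Below is a minimal scalar sketch of that equivalence with toy dimensions and plain loops; it mirrors the index arithmetic of Reshape() in the diffs that follow, but none of the helper code is from the repo and it omits compression, scales and biases.

```
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Toy dimensions; the real configs use much larger values.
constexpr size_t kHeads = 2, kQKVDim = 3, kModelDim = 4;

int main() {
  // attn_vec_einsum_w, stored as per-head [kModelDim, kQKVDim] blocks,
  // i.e. a [kHeads, kModelDim, kQKVDim] array.
  std::vector<float> w(kHeads * kModelDim * kQKVDim);
  // att_out: the concatenated per-head attention outputs for one token.
  std::vector<float> att_out(kHeads * kQKVDim);
  for (size_t i = 0; i < w.size(); ++i) w[i] = 0.01f * i;
  for (size_t i = 0; i < att_out.size(); ++i) att_out[i] = 1.0f + 0.1f * i;

  // Old path: one (scalar) MatVec per head, accumulated into layer_out.
  std::vector<float> layer_out(kModelDim, 0.0f);
  for (size_t h = 0; h < kHeads; ++h) {
    for (size_t m = 0; m < kModelDim; ++m) {
      for (size_t q = 0; q < kQKVDim; ++q) {
        layer_out[m] += w[h * kModelDim * kQKVDim + m * kQKVDim + q] *
                        att_out[h * kQKVDim + q];
      }
    }
  }

  // Reshape [kHeads, kModelDim, kQKVDim] -> [kModelDim, kHeads * kQKVDim],
  // as CompressedLayer::Reshape() does once at startup.
  std::vector<float> att_weights(kModelDim * kHeads * kQKVDim);
  for (size_t m = 0; m < kModelDim; ++m) {
    for (size_t h = 0; h < kHeads; ++h) {
      for (size_t q = 0; q < kQKVDim; ++q) {
        att_weights[m * kHeads * kQKVDim + h * kQKVDim + q] =
            w[h * kModelDim * kQKVDim + m * kQKVDim + q];
      }
    }
  }

  // New path: one row of the [num_interleaved, kModelDim] matmul output,
  // i.e. a single matvec over the concatenated heads.
  std::vector<float> att_sums(kModelDim, 0.0f);
  for (size_t m = 0; m < kModelDim; ++m) {
    for (size_t k = 0; k < kHeads * kQKVDim; ++k) {
      att_sums[m] += att_out[k] * att_weights[m * kHeads * kQKVDim + k];
    }
  }

  for (size_t m = 0; m < kModelDim; ++m) {
    assert(std::abs(layer_out[m] - att_sums[m]) < 1e-5f);
  }
  return 0;
}
```

The profiler output above reflects this: Gen.Attention.SumHeads drops from ~7.2% to ~2.5% of time. Prefill presumably benefits most because it feeds many rows into the MatMul at once, while decode has a single row per step, which would explain the slight 0.95x regression there.
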

View File

@@ -126,6 +126,8 @@ TEST(OptimizeTest, GradientDescent) {
info.model, prompt, gemma.Weights(), forward, inv_timescale, pool);
CrossEntropyLossBackwardPass(info.model, prompt, gemma.Weights(), forward,
grad, backward, inv_timescale, pool);
CallForModelAndWeight<ReshapeCompressedWeights>(
info.model, info.weight, gemma.MutableWeights(), pool);
num_ok += verify(prompt) ? 1 : 0;
}
total_loss /= kBatchSize;

View File

@@ -74,10 +74,8 @@ struct Activations {
RowVectorBatch<float> pre_att_rms_out;
RowVectorBatch<float> att; // attention vector
RowVectorBatch<float> att_out; // attention output
// After linear transformation, shared by all heads
RowVectorBatch<float> att_post1;
// Accumulation of attention outputs over heads
RowVectorBatch<float> att_post2;
RowVectorBatch<float> att_sums;
// Gated FFW
RowVectorBatch<hwy::bfloat16_t> bf_pre_ffw_rms_out;
@@ -144,8 +142,7 @@ struct Activations {
pre_att_rms_out = RowVectorBatch<float>(batch_size, kModelDim);
att = RowVectorBatch<float>(batch_size, kHeads * kSeqLen);
att_out = RowVectorBatch<float>(batch_size, kHeads * kQKVDim);
att_post1 = RowVectorBatch<float>(1, kModelDim);
att_post2 = RowVectorBatch<float>(batch_size, kModelDim);
att_sums = RowVectorBatch<float>(batch_size, kModelDim);
bf_pre_ffw_rms_out = RowVectorBatch<hwy::bfloat16_t>(batch_size, kModelDim);
C1 = RowVectorBatch<float>(batch_size, kFFHiddenDim);

View File

@@ -28,7 +28,6 @@
#include <stdio.h>
#include <algorithm> // std::min
#include <memory> // std::unique_ptr
#include <string>
#include <type_traits>
#include <vector>
@@ -188,7 +187,7 @@ HWY_NOINLINE void GriffinRecurrent(
// Final linear layer.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
float* HWY_RESTRICT x = activations.griffin_x.Batch(batch_idx);
float* out_ptr = activations.att_post2.Batch(batch_idx);
float* out_ptr = activations.att_sums.Batch(batch_idx);
MatVecAdd<kModelDim, kModelDim>(
layer_weights->griffin.linear_out_w, 0, x,
layer_weights->griffin.linear_out_biases.data_scale1(),
@@ -421,39 +420,23 @@ class GemmaAttention {
});
}
// Sums encoded (`att_out`) over num_heads and head_dim (kQKVDim) into output
// (`layer_out`). Compare gemma/modules.py:
// attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
// Sums encoded (`att_out`) over num_heads (`kHeads`) and head_dim (`kQKVDim`)
// into output (`layer_out`).
HWY_NOINLINE void SumHeads(const size_t num_interleaved) {
PROFILER_ZONE("Gen.Attention.SumHeads");
for (size_t interleaved_idx = 0; interleaved_idx < num_interleaved;
++interleaved_idx) {
// TODO(szabadka) Use a single MatVecAdd like in GriffinRecurrent() after
// rearranging the weights.
float* HWY_RESTRICT att_out = activations_.att_out.Batch(interleaved_idx);
float* HWY_RESTRICT layer_out =
activations_.att_post2.Batch(interleaved_idx);
// Head 0 (and potentially biases) -> layer_out.
// attn_vec_einsum_w has shape [kHeads, kQKVDim, kModelDim].
constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
const float* bias =
kAdd ? layer_weights_.attention_output_biases.data_scale1() : nullptr;
MatVecT<kAdd, kModelDim, kQKVDim>(
layer_weights_.attn_vec_einsum_w, 0, att_out, bias,
activations_.even_odd.All(), layer_out, pool_);
// Head 1 and following are added to layer_out.
for (size_t head = 1; head < kHeads; ++head) {
// NOTE: this is a single kModelDim temp output. If parallelized or
// using MatMul, add per-thread storage.
float* HWY_RESTRICT head_out = activations_.att_post1.All();
// TODO: requires MatMul support for offsets.
MatVec<kModelDim, kQKVDim>(
layer_weights_.attn_vec_einsum_w, head * kModelDim * kQKVDim,
att_out + head * kQKVDim, activations_.even_odd.All(), head_out,
pool_);
AddFrom(head_out, layer_out, kModelDim);
}
}
constexpr bool kAdd = TConfig::kSoftmaxAttnOutputBiases;
const float* bias =
kAdd ? layer_weights_.attention_output_biases.data_scale1() : nullptr;
// att_weights and att_out are concatenated heads, each of length kQKVDim.
// Thus the [num_interleaved, kModelDim] matmul output is the sum over
// heads. Compare gemma/modules.py:
// attn_output = self.attn_vec_einsum('BTNH,NHD->BTD', encoded)
MatMul_4x4<kAdd>(
num_interleaved, MakeMat(activations_.att_out.All(), kHeads * kQKVDim),
MakeMat(layer_weights_.att_weights.data(), kHeads * kQKVDim),
layer_weights_.attn_vec_einsum_w.scale(), bias,
MakeMat(activations_.att_sums.All(), kModelDim), pool_);
}
public:
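
The comment in the new SumHeads notes that the [num_interleaved, kModelDim] matmul output is already the sum over heads. The identity behind that, in ad-hoc notation of mine (W_h is head h's [kModelDim, kQKVDim] block of attn_vec_einsum_w, x_h its kQKVDim slice of att_out):

```
\text{att\_sums}
  = \sum_{h=0}^{\text{kHeads}-1} W_h \, x_h
  = \bigl[\, W_0 \;\; W_1 \;\cdots\; W_{\text{kHeads}-1} \,\bigr]
    \begin{bmatrix} x_0 \\ x_1 \\ \vdots \\ x_{\text{kHeads}-1} \end{bmatrix}
```

The block matrix on the right has shape [kModelDim, kHeads * kQKVDim] (this is att_weights) and the stacked vector is the concatenated att_out, so kHeads MatVec+AddFrom calls collapse into one matmul row per token; restricted to one token this is the einsum 'BTNH,NHD->BTD' from gemma/modules.py.
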
@@ -634,9 +617,9 @@ HWY_NOINLINE void TransformerLayer(
activations, layer_weights, div_seq_len, kv_caches, pool);
PostNorm<TConfig>(num_interleaved, layer_weights->post_attention_norm_scale,
activations.att_post2.All());
activations.att_sums.All());
ResidualConnection<TConfig>(num_interleaved, activations.att_post2.All(),
ResidualConnection<TConfig>(num_interleaved, activations.att_sums.All(),
activations.x.All(), layer_weights,
/*is_attention=*/true);

View File

@@ -158,6 +158,7 @@ class Gemma {
const ModelInfo& Info() const { return info_; }
const GemmaTokenizer& Tokenizer() const { return tokenizer_; }
const ByteStorageT& Weights() const { return weights_u8_; }
ByteStorageT& MutableWeights() { return weights_u8_; }
void Generate(const RuntimeConfig& runtime_config, const PromptTokens& prompt,
size_t start_pos, KVCache& kv_cache, TimingInfo& timing_info);

View File

@@ -72,6 +72,7 @@ struct LoadCompressedWeightsT {
}
HWY_ASSERT(scale_pos == TConfig::kNumTensorScales);
}
c_weights->Reshape();
return c_weights_u8;
}
};

View File

@@ -24,6 +24,7 @@
#include "hwy/aligned_allocator.h"
#include "hwy/base.h"
#include "hwy/contrib/thread_pool/thread_pool.h"
#include "hwy/profiler.h"
namespace gcpp {
@@ -93,6 +94,33 @@ struct CompressedLayer {
ArrayT<float, kFFBiases ? 2 * kFFHiddenDim : 0> ffw_gating_biases;
ArrayT<float, kFFBiases ? kModelDim : 0> ffw_output_biases;
// Reshaped attention; not loaded from disk via ForEachTensor.
ArrayT<Weight, kModelDim * kHeads * kQKVDim> att_weights;
// Initializes att_weights from attn_vec_einsum_w, hence this must be called
// after loading weights via ForEachTensor.
// TODO: update compression/convert_weights to bake this in.
void Reshape() {
PROFILER_ZONE("Startup.Reshape");
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kQKVDim = TConfig::kQKVDim;
// Would have to implement a CompressTraits::Copy for NUQ.
static_assert(!hwy::IsSame<Weight, NuqStream>());
// Reshape [kHeads, kModelDim, kQKVDim] to [kModelDim, kHeads * kQKVDim].
for (size_t m = 0; m < kModelDim; ++m) {
Weight* HWY_RESTRICT out_row = att_weights.data() + m * kHeads * kQKVDim;
for (size_t h = 0; h < kHeads; ++h) {
hwy::CopyBytes(
attn_vec_einsum_w.data() + h * kModelDim * kQKVDim + m * kQKVDim,
out_row + h * kQKVDim, kQKVDim * sizeof(Weight));
}
}
}
};
// Array instead of single large allocation for parallel mem init. Split out
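
Reshape() above is a pure relayout: for each (model row m, head h), the kQKVDim weights stay contiguous in both layouts, which is why a single CopyBytes of kQKVDim * sizeof(Weight) per pair suffices. A small index-mapping sketch with toy sizes; the constants and helpers are illustrative, not repo code:

```
#include <cassert>
#include <cstddef>

constexpr size_t kHeads = 4, kModelDim = 16, kQKVDim = 8;  // toy sizes

// src layout: [kHeads, kModelDim, kQKVDim] (attn_vec_einsum_w as stored).
constexpr size_t SrcIndex(size_t h, size_t m, size_t q) {
  return h * kModelDim * kQKVDim + m * kQKVDim + q;
}
// dst layout: [kModelDim, kHeads * kQKVDim] (att_weights).
constexpr size_t DstIndex(size_t h, size_t m, size_t q) {
  return m * kHeads * kQKVDim + h * kQKVDim + q;
}

int main() {
  for (size_t m = 0; m < kModelDim; ++m) {
    for (size_t h = 0; h < kHeads; ++h) {
      // The kQKVDim elements of each (h, m) pair are contiguous in both
      // layouts, so Reshape() can move them with one CopyBytes.
      for (size_t q = 0; q + 1 < kQKVDim; ++q) {
        assert(SrcIndex(h, m, q) + 1 == SrcIndex(h, m, q + 1));
        assert(DstIndex(h, m, q) + 1 == DstIndex(h, m, q + 1));
      }
    }
  }
  return 0;
}
```

Per the static_assert and TODO in the diff, NUQ is excluded because a raw byte copy is not defined for its stream format; it would need a CompressTraits::Copy.
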
@@ -135,6 +163,13 @@ struct CompressedWeights {
explicit CompressedWeights(hwy::ThreadPool& pool) : c_layer_ptrs(pool) {}
// Called by weights.cc after ForEachTensor.
void Reshape() {
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
GetLayer(layer)->Reshape();
}
}
void ZeroInit() {
hwy::ZeroBytes(&embedder_input_embedding, sizeof(embedder_input_embedding));
hwy::ZeroBytes(&final_norm_scale, sizeof(final_norm_scale));
@@ -174,6 +209,15 @@ struct ZeroInitCompressedWeights {
}
};
template <typename TConfig>
struct ReshapeCompressedWeights {
void operator()(ByteStorageT& weights_u8, hwy::ThreadPool& pool) const {
CompressedWeights<TConfig>& weights =
*reinterpret_cast<CompressedWeights<TConfig>*>(weights_u8.get());
weights.Reshape();
}
};
// TODO: also add RandInitCompressedWeights
template <class TConfig>