Code cleanup

- Simplify template arg list, enable deduction (sketch after this list)
- add missing hn:: on "Lanes"
- 1.0f suffix
- move RMSNormBatched into ops.h
- static constexpr -> constexpr
- concrete type instead of LayerT, WeightArrayT
- inline GetWeights
- remove if (runtime_config.verbosity >= 2) guards around the timing stats
- merge AllocatePrefill and AllocateDecode
- remove bf_ffw_hidden
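
Why the reordering helps: with the opaque LayerT / WeightArrayT parameters replaced by the concrete CompressedLayer<TConfig> / CompressedWeights<TConfig> types, every remaining template parameter is deducible from the function arguments, so call sites such as Attention<kBatchSize>(...) shrink to Attention(...). A minimal standalone sketch of that pattern, using illustrative stand-in names (ToyConfig, Foo) rather than the real gemma.cpp types:

#include <array>
#include <cstddef>
#include <cstdio>

// Stand-in for Activations<TConfig, kBatchSize>; illustrative only.
template <class TConfig, std::size_t kBatchSize>
struct Activations {
  std::array<float, kBatchSize * TConfig::kModelDim> x;
};

// Hypothetical config, not one of the real TConfig structs.
struct ToyConfig {
  static constexpr std::size_t kModelDim = 8;
};

// Parameter order mirrors the new <class TConfig, size_t kBatchSize> list;
// both arguments are deduced from the activations parameter, so callers no
// longer spell out <kBatchSize>.
template <class TConfig, std::size_t kBatchSize>
void Foo(Activations<TConfig, kBatchSize>& activations) {
  // Inside a function body a plain constexpr local is enough; static adds
  // nothing here (the "static constexpr -> constexpr" bullet above).
  constexpr std::size_t kModelDim = TConfig::kModelDim;
  std::printf("batch=%zu model_dim=%zu x0=%f\n", kBatchSize, kModelDim,
              static_cast<double>(activations.x[0]));
}

int main() {
  Activations<ToyConfig, 4> activations{};
  Foo(activations);  // template arguments deduced: prints batch=4 model_dim=8
  return 0;
}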

PiperOrigin-RevId: 644931277
Jan Wassenberg 2024-06-20 01:09:39 -07:00 committed by Copybara-Service
parent 658fb3e506
commit 48ebba8b7a
2 changed files with 95 additions and 112 deletions

Changed file 1 of 2

@@ -78,6 +78,7 @@ struct Activations {
static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1);
std::array<float, kBatchSize * kModelDim> x; // input
std::array<float, kBatchSize * kModelDim> pre_att_rms_out;
std::array<float, kBatchSize * kHeads * kQStride> q; // query vector
std::array<float, kBatchSize * kHeads * TConfig::kSeqLen>
@@ -87,17 +88,13 @@ struct Activations {
att_post1; // attention output after linear transformation, per head
std::array<float, kBatchSize * kModelDim>
att_post2; // accumulation of attention outputs over heads
std::array<hwy::bfloat16_t, kBatchSize * kModelDim> bf_pre_ffw_rms_out;
std::array<float, kBatchSize * TConfig::kFFHiddenDim * 2> ffw_hidden;
// For FFW MatMul.
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1;
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C1; // MatMul output
std::array<float, kBatchSize * TConfig::kFFHiddenDim> C2;
// bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved.
// std::array<hwy::bfloat16_t, kBatchSize * 2 * TConfig::kFFHiddenDim>
// bf_ffw_hidden;
std::array<float, kBatchSize * kModelDim> ffw_out;
std::array<float, kBatchSize * TConfig::kVocabSize> logits;
// For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into
@@ -234,19 +231,19 @@ namespace gcpp {
namespace HWY_NAMESPACE {
namespace {
template <size_t kBatchSize, typename LayerT, class TConfig>
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void GriffinRecurrent(
size_t batch_start, size_t num_tokens, size_t layer,
Activations<TConfig, kBatchSize>& activations, const LayerT* layer_weights,
KVCache& kv_cache, hwy::ThreadPool& pool) {
Activations<TConfig, kBatchSize>& activations,
const CompressedLayer<TConfig>* layer_weights, KVCache& kv_cache,
hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Griffin");
namespace hn = hwy::HWY_NAMESPACE;
using D = hn::ScalableTag<float>;
HWY_DASSERT(num_tokens <= kBatchSize);
static constexpr size_t kModelDim =
Activations<TConfig, kBatchSize>::kModelDim;
static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
static constexpr size_t kHeads = TConfig::kHeads;
constexpr size_t kModelDim = Activations<TConfig, kBatchSize>::kModelDim;
constexpr size_t kConv1dWidth = TConfig::kConv1dWidth;
constexpr size_t kHeads = TConfig::kHeads;
// X / Y linear layers.
for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) {
@@ -268,7 +265,7 @@ HWY_NOINLINE void GriffinRecurrent(
const size_t pos = batch_start + batch_idx;
float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset;
HWY_FULL(float) df;
HWY_DASSERT(kModelDim % Lanes(df) == 0);
HWY_DASSERT(kModelDim % hn::Lanes(df) == 0);
const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1);
// cache[i] = input at time t-i.
@@ -279,7 +276,7 @@ HWY_NOINLINE void GriffinRecurrent(
kv_cache.conv1d_cache.get() + layer_offset +
((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim;
}
for (size_t i = 0; i < kModelDim; i += Lanes(df)) {
for (size_t i = 0; i < kModelDim; i += hn::Lanes(df)) {
auto xv = hn::Load(df, x + i);
auto accum0 =
hn::Load(df, layer_weights->griffin.conv_biases.data() + i);
@@ -332,15 +329,15 @@ HWY_NOINLINE void GriffinRecurrent(
fn_mul);
// RNN scan
HWY_FULL(float) df;
HWY_DASSERT(kHeadDim % Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += Lanes(df)) {
HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0);
for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) {
auto log_a = hn::Load(df, a + head_offset + i);
auto gated_x = hn::Load(df, x + head_offset + i);
auto rnn = hn::Load(df, rnn_state + head_offset + i);
auto a = hn::Exp(df, log_a);
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0)));
auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f)));
if (pos == 0) {
x_multiplier = hn::Set(df, 1.0);
x_multiplier = hn::Set(df, 1.0f);
}
auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn));
hn::Store(new_x, df, rnn_state + head_offset + i);
@@ -365,11 +362,11 @@ HWY_NOINLINE void GriffinRecurrent(
}
}
template <size_t kBatchSize, typename LayerT, class TConfig>
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
Activations<TConfig, kBatchSize>& activations,
const LayerT* layer_weights, KVCache& kv_cache,
hwy::ThreadPool& pool) {
const CompressedLayer<TConfig>* layer_weights,
KVCache& kv_cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Attention");
HWY_DASSERT(num_tokens <= kBatchSize);
using TActivations = Activations<TConfig, kBatchSize>;
@@ -429,7 +426,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
static_assert((kHeads % kKVHeads) == 0,
"query heads must be a multiple of key-value heads");
static constexpr size_t kGroupHeads = kHeads / kKVHeads;
constexpr size_t kGroupHeads = kHeads / kKVHeads;
pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR {
const size_t head = task % kHeads;
const size_t batch_idx = task / kHeads;
@@ -494,13 +491,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer,
}
}
template <size_t kBatchSize, typename LayerT, typename TConfig>
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
size_t num_tokens, const LayerT* layer_weights,
size_t num_tokens,
const CompressedLayer<TConfig>* layer_weights,
hwy::ThreadPool& pool) {
HWY_DASSERT(num_tokens <= kBatchSize);
static constexpr size_t kModelDim = TConfig::kModelDim;
static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim;
float* HWY_RESTRICT even_odd = activations.even_odd.data();
// TODO: MatMul does not yet support adding another matrix to the result.
@@ -570,42 +568,11 @@ HWY_NOINLINE void FFW(Activations<TConfig, kBatchSize>& activations,
}
}
// The below "batched" versions are just simple loops for now.
template <size_t kBatchSize, typename WeightT, typename OutT>
static void RMSNormBatched(size_t num_tokens, const float* activations,
const WeightT* weights, OutT* out,
const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNorm(activations + token_idx * model_dim, weights,
out + token_idx * model_dim, model_dim);
}
}
template <size_t kBatchSize, typename WeightT, typename InOutT>
static void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
InOutT* inout, const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
}
}
template <size_t kBatchSize>
static void AddFromBatched(size_t num_tokens, const float* other, float* x,
const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
model_dim);
}
}
template <size_t kBatchSize, typename WeightArrayT, class TConfig>
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
const WeightArrayT& weights,
const CompressedWeights<TConfig>& weights,
Activations<TConfig, kBatchSize>& activations) {
static constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kModelDim = TConfig::kModelDim;
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
EmbeddingScaling<TConfig>();
HWY_DASSERT(token >= 0);
@@ -621,13 +588,13 @@ HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos,
};
}
template <size_t kBatchSize, typename LayerWeightArrayT, class TConfig>
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void TransformerLayer(
size_t num_tokens, size_t pos, size_t layer,
const LayerWeightArrayT* layer_weights,
const CompressedLayer<TConfig>* layer_weights,
Activations<TConfig, kBatchSize>& activations, KVCache& kv_cache,
hwy::ThreadPool& pool) {
static constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kModelDim = TConfig::kModelDim;
auto type = TConfig::kLayerConfig[layer];
size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
@@ -635,11 +602,11 @@ HWY_NOINLINE void TransformerLayer(
layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data(), kModelDim);
if (type == LayerAttentionType::kGemma) {
Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
layer_weights, kv_cache, pool);
Attention(pos, num_tokens, layer_of_type, activations, layer_weights,
kv_cache, pool);
} else {
GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
layer_weights, kv_cache, pool);
GriffinRecurrent(pos, num_tokens, layer_of_type, activations, layer_weights,
kv_cache, pool);
}
if (TConfig::kPostNormScale) {
RMSNormInplaceBatched<kBatchSize>(
@@ -651,7 +618,7 @@ HWY_NOINLINE void TransformerLayer(
RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim);
FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
FFW(activations, num_tokens, layer_weights, pool);
if (TConfig::kPostNormScale) {
RMSNormInplaceBatched<kBatchSize>(num_tokens,
layer_weights->post_ffw_norm_scale.data(),
@@ -661,9 +628,9 @@ HWY_NOINLINE void TransformerLayer(
activations.x.data(), kModelDim);
}
template <size_t kBatchSize, typename WeightArrayT, typename TConfig>
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights,
const CompressedWeights<TConfig>& weights,
Activations<TConfig, kBatchSize>& activations,
KVCache& kv_cache, hwy::ThreadPool& pool) {
PROFILER_ZONE("Gen.Prefill\\Att\\FFW");
@@ -685,9 +652,9 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
// Compute the transformer for a batch of input tokens. During generation,
// we usually have num_tokens == 1 (and also kBatchSize == 1).
template <size_t kBatchSize, typename WeightArrayT, class TConfig>
template <class TConfig, size_t kBatchSize>
HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights,
const CompressedWeights<TConfig>& weights,
Activations<TConfig, kBatchSize>& activations,
KVCache& kv_cache, hwy::ThreadPool& pool,
const LayersOutputFunc& layers_output) {
@@ -698,17 +665,18 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos,
layers_output(pos + token_idx, "Tokens", &token_f, 1);
}
}
static constexpr size_t kModelDim = TConfig::kModelDim;
constexpr size_t kModelDim = TConfig::kModelDim;
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
EmbedToken(tokens[token_idx], token_idx, pos, weights, activations);
}
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
const auto* layer_weights = weights.GetLayer(layer);
const CompressedLayer<TConfig>* layer_weights = weights.GetLayer(layer);
TransformerLayer(num_tokens, pos, layer, layer_weights, activations,
kv_cache, pool);
if (layers_output) {
std::string block_name = "blocks." + std::to_string(layer);
const std::string block_name = "blocks." + std::to_string(layer);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
layers_output(pos + token_idx, block_name,
activations.x.data() + token_idx * kModelDim, kModelDim);
@@ -754,11 +722,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens,
fprintf(stderr, "%zu\n", prompt_size);
}
}
}
template <class TConfig>
const CompressedWeights<TConfig>& GetWeights(const ByteStorageT& weights_u8) {
return *reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
HWY_ASSERT(prompt_size > 0);
}
template <class TConfig, size_t kBatchSize>
@@ -776,12 +741,13 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
const RuntimeConfig& runtime_config,
const std::vector<int>& prompt, size_t pos, KVCache& kv_cache,
hwy::ThreadPool& pool, TimingInfo& timing_info) {
const CompressedWeights<TConfig>& weights = GetWeights<TConfig>(weights_u8);
const CompressedWeights<TConfig>& weights =
*reinterpret_cast<const CompressedWeights<TConfig>*>(weights_u8.get());
auto& prefill_activations =
GetActivations<TConfig, kPrefillBatchSize>(prefill_u8);
auto& activations = GetActivations<TConfig, 1>(decode_u8);
static constexpr size_t kVocabSize = TConfig::kVocabSize;
constexpr size_t kVocabSize = TConfig::kVocabSize;
size_t prompt_size = prompt.size();
size_t max_tokens = runtime_config.max_tokens;
size_t max_generated_tokens = runtime_config.max_generated_tokens;
@@ -791,7 +757,6 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
max_tokens);
return;
}
HWY_ASSERT(prompt_size > 0);
// If no sample_func is provided, we use top-k sampling.
const SampleFunc sample_token =
@@ -825,8 +790,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
HWY_DASSERT(batch_size <= kPrefillBatchSize);
HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
const int* batch_tokens = prompt.data() + pos_offset;
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
prefill_activations, kv_cache, pool);
Prefill(batch_tokens, batch_size, pos, weights, prefill_activations,
kv_cache, pool);
for (size_t idx = 0; idx < batch_size; ++idx) {
if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return;
}
@@ -834,11 +799,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
pos_offset += batch_size;
}
if (runtime_config.verbosity >= 2) {
const double prefill_end = hwy::platform::Now();
timing_info.prefill_tok_sec =
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
}
timing_info.prefill_tok_sec =
static_cast<double>(pos_offset) / (hwy::platform::Now() - prefill_start);
// Start generation.
const double gen_start = hwy::platform::Now();
@@ -851,9 +813,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
for (size_t generate_pos = 0;
pos < max_tokens && generate_pos < max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) {
Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights,
activations, kv_cache, pool,
runtime_config.layers_output);
Transformer(&token, kDecodeBatchSize, pos, weights, activations, kv_cache,
pool, runtime_config.layers_output);
float token_logit = 0.0f;
// The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill
@@ -885,11 +846,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
break;
}
}
if (runtime_config.verbosity >= 2) {
const double gen_end = hwy::platform::Now();
timing_info.gen_tok_sec =
static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
}
timing_info.gen_tok_sec = static_cast<double>(pos_offset - pos_gen_start) /
(hwy::platform::Now() - gen_start);
}
} // namespace HWY_NAMESPACE
@@ -901,18 +859,13 @@ namespace gcpp {
namespace {
template <typename TConfig>
struct AllocatePrefill {
ByteStorageT operator()() const {
return AllocateSizeof<Activations<TConfig, kPrefillBatchSize>>();
struct AllocateState {
void operator()(ByteStorageT& prefill, ByteStorageT& decode) const {
prefill = AllocateSizeof<Activations<TConfig, kPrefillBatchSize>>();
decode = AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
}
};
template <typename TConfig>
struct AllocateDecode {
ByteStorageT operator()() const {
return AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
}
};
} // namespace
Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
@@ -922,8 +875,8 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type,
model_type_(model_type),
weight_type_(weight_type) {
weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool);
prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
decode_u8_);
}
Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
@@ -935,8 +888,8 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type,
HWY_ASSERT(weight_type == Type::kF32);
weights_u8_ = CallForModel<float, AllocateCompressedWeights>(
model_type, pool);
prefill_u8_ = CallForModelAndWeight<AllocatePrefill>(model_type, weight_type);
decode_u8_ = CallForModelAndWeight<AllocateDecode>(model_type, weight_type);
CallForModelAndWeight<AllocateState>(model_type, weight_type, prefill_u8_,
decode_u8_);
}
Gemma::~Gemma() {

Changed file 2 of 2

@@ -1629,6 +1629,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
HWY_ATTR { return hn::Add(x, other); });
}
// Simple loops unless/until batch sizes are large enough to parallelize.
template <size_t kBatchSize, typename WeightT, typename OutT>
void RMSNormBatched(size_t num_tokens, const float* activations,
const WeightT* weights, OutT* out, const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNorm(activations + token_idx * model_dim, weights,
out + token_idx * model_dim, model_dim);
}
}
template <size_t kBatchSize, typename WeightT, typename InOutT>
void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
InOutT* inout, const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
}
}
template <size_t kBatchSize>
void AddFromBatched(size_t num_tokens, const float* other, float* x,
const size_t model_dim) {
HWY_DASSERT(num_tokens <= kBatchSize);
for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
model_dim);
}
}
static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
float* HWY_RESTRICT x, const size_t size,
const size_t max_pos) {