From 48ebba8b7af3be404791117a72d07f06238b6acf Mon Sep 17 00:00:00 2001 From: Jan Wassenberg Date: Thu, 20 Jun 2024 01:09:39 -0700 Subject: [PATCH] Code cleanup - Simplify template arg list, enable deduction - missing hn:: on " Lanes" - 1.0f suffix - move RMSNormBatched into ops.h - static constexpr -> constexpr - concrete type instead of LayerT, WeightArrayT - inline GetWeights - remove if (runtime_config.verbosity - merge AllocatePrefill and AllocateDecode - remove bf_ffw_hidden PiperOrigin-RevId: 644931277 --- gemma/gemma.cc | 177 ++++++++++++++++++------------------------------- gemma/ops.h | 30 +++++++++ 2 files changed, 95 insertions(+), 112 deletions(-) diff --git a/gemma/gemma.cc b/gemma/gemma.cc index 8ac9dce..4f4e8ca 100644 --- a/gemma/gemma.cc +++ b/gemma/gemma.cc @@ -78,6 +78,7 @@ struct Activations { static constexpr size_t kQStride = kQKVDim * (kIsMHA ? 3 : 1); std::array x; // input + std::array pre_att_rms_out; std::array q; // query vector std::array @@ -87,17 +88,13 @@ struct Activations { att_post1; // attention output after linear transformation, per head std::array att_post2; // accumulation of attention outputs over heads + std::array bf_pre_ffw_rms_out; std::array ffw_hidden; - - // For FFW MatMul. - std::array C1; + std::array C1; // MatMul output std::array C2; - - // bf_ version can't be used until GeluMulToBF16 issue in FFW() is resolved. - // std::array - // bf_ffw_hidden; std::array ffw_out; + std::array logits; // For bf16/f32 vectors * bf16 matrix: faster to unpack once beforehand, into @@ -234,19 +231,19 @@ namespace gcpp { namespace HWY_NAMESPACE { namespace { -template +template HWY_NOINLINE void GriffinRecurrent( size_t batch_start, size_t num_tokens, size_t layer, - Activations& activations, const LayerT* layer_weights, - KVCache& kv_cache, hwy::ThreadPool& pool) { + Activations& activations, + const CompressedLayer* layer_weights, KVCache& kv_cache, + hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Griffin"); namespace hn = hwy::HWY_NAMESPACE; using D = hn::ScalableTag; HWY_DASSERT(num_tokens <= kBatchSize); - static constexpr size_t kModelDim = - Activations::kModelDim; - static constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; - static constexpr size_t kHeads = TConfig::kHeads; + constexpr size_t kModelDim = Activations::kModelDim; + constexpr size_t kConv1dWidth = TConfig::kConv1dWidth; + constexpr size_t kHeads = TConfig::kHeads; // X / Y linear layers. for (size_t batch_idx = 0; batch_idx < num_tokens; ++batch_idx) { @@ -268,7 +265,7 @@ HWY_NOINLINE void GriffinRecurrent( const size_t pos = batch_start + batch_idx; float* HWY_RESTRICT x = activations.griffin_x.data() + batch_offset; HWY_FULL(float) df; - HWY_DASSERT(kModelDim % Lanes(df) == 0); + HWY_DASSERT(kModelDim % hn::Lanes(df) == 0); const size_t layer_offset = layer * kModelDim * (kConv1dWidth - 1); // cache[i] = input at time t-i. 
@@ -279,7 +276,7 @@ HWY_NOINLINE void GriffinRecurrent( kv_cache.conv1d_cache.get() + layer_offset + ((pos + kConv1dWidth - 1 - i) % (kConv1dWidth - 1)) * kModelDim; } - for (size_t i = 0; i < kModelDim; i += Lanes(df)) { + for (size_t i = 0; i < kModelDim; i += hn::Lanes(df)) { auto xv = hn::Load(df, x + i); auto accum0 = hn::Load(df, layer_weights->griffin.conv_biases.data() + i); @@ -332,15 +329,15 @@ HWY_NOINLINE void GriffinRecurrent( fn_mul); // RNN scan HWY_FULL(float) df; - HWY_DASSERT(kHeadDim % Lanes(df) == 0); - for (size_t i = 0; i < kHeadDim; i += Lanes(df)) { + HWY_DASSERT(kHeadDim % hn::Lanes(df) == 0); + for (size_t i = 0; i < kHeadDim; i += hn::Lanes(df)) { auto log_a = hn::Load(df, a + head_offset + i); auto gated_x = hn::Load(df, x + head_offset + i); auto rnn = hn::Load(df, rnn_state + head_offset + i); auto a = hn::Exp(df, log_a); - auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0))); + auto x_multiplier = hn::Sqrt(hn::NegMulAdd(a, a, hn::Set(df, 1.0f))); if (pos == 0) { - x_multiplier = hn::Set(df, 1.0); + x_multiplier = hn::Set(df, 1.0f); } auto new_x = hn::MulAdd(x_multiplier, gated_x, hn::Mul(a, rnn)); hn::Store(new_x, df, rnn_state + head_offset + i); @@ -365,11 +362,11 @@ HWY_NOINLINE void GriffinRecurrent( } } -template +template HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, Activations& activations, - const LayerT* layer_weights, KVCache& kv_cache, - hwy::ThreadPool& pool) { + const CompressedLayer* layer_weights, + KVCache& kv_cache, hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Attention"); HWY_DASSERT(num_tokens <= kBatchSize); using TActivations = Activations; @@ -429,7 +426,7 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, static_assert((kHeads % kKVHeads) == 0, "query heads must be a multiple of key-value heads"); - static constexpr size_t kGroupHeads = kHeads / kKVHeads; + constexpr size_t kGroupHeads = kHeads / kKVHeads; pool.Run(0, kHeads * num_tokens, [&](uint64_t task, size_t thread) HWY_ATTR { const size_t head = task % kHeads; const size_t batch_idx = task / kHeads; @@ -494,13 +491,14 @@ HWY_NOINLINE void Attention(size_t batch_start, size_t num_tokens, size_t layer, } } -template +template HWY_NOINLINE void FFW(Activations& activations, - size_t num_tokens, const LayerT* layer_weights, + size_t num_tokens, + const CompressedLayer* layer_weights, hwy::ThreadPool& pool) { HWY_DASSERT(num_tokens <= kBatchSize); - static constexpr size_t kModelDim = TConfig::kModelDim; - static constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; + constexpr size_t kModelDim = TConfig::kModelDim; + constexpr size_t kFFHiddenDim = TConfig::kFFHiddenDim; float* HWY_RESTRICT even_odd = activations.even_odd.data(); // TODO: MatMul does not yet support adding another matrix to the result. @@ -570,42 +568,11 @@ HWY_NOINLINE void FFW(Activations& activations, } } -// The below "batched" versions are just simple loops for now. 
-template -static void RMSNormBatched(size_t num_tokens, const float* activations, - const WeightT* weights, OutT* out, - const size_t model_dim) { - HWY_DASSERT(num_tokens <= kBatchSize); - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - RMSNorm(activations + token_idx * model_dim, weights, - out + token_idx * model_dim, model_dim); - } -} - -template -static void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights, - InOutT* inout, const size_t model_dim) { - HWY_DASSERT(num_tokens <= kBatchSize); - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - RMSNormInplace(weights, inout + token_idx * model_dim, model_dim); - } -} - -template -static void AddFromBatched(size_t num_tokens, const float* other, float* x, - const size_t model_dim) { - HWY_DASSERT(num_tokens <= kBatchSize); - for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { - AddFrom(other + token_idx * model_dim, x + token_idx * model_dim, - model_dim); - } -} - -template +template HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos, - const WeightArrayT& weights, + const CompressedWeights& weights, Activations& activations) { - static constexpr size_t kModelDim = TConfig::kModelDim; + constexpr size_t kModelDim = TConfig::kModelDim; GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = EmbeddingScaling(); HWY_DASSERT(token >= 0); @@ -621,13 +588,13 @@ HWY_NOINLINE void EmbedToken(int token, size_t token_idx, size_t pos, }; } -template +template HWY_NOINLINE void TransformerLayer( size_t num_tokens, size_t pos, size_t layer, - const LayerWeightArrayT* layer_weights, + const CompressedLayer* layer_weights, Activations& activations, KVCache& kv_cache, hwy::ThreadPool& pool) { - static constexpr size_t kModelDim = TConfig::kModelDim; + constexpr size_t kModelDim = TConfig::kModelDim; auto type = TConfig::kLayerConfig[layer]; size_t layer_of_type = NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer); @@ -635,11 +602,11 @@ HWY_NOINLINE void TransformerLayer( layer_weights->pre_attention_norm_scale.data(), activations.pre_att_rms_out.data(), kModelDim); if (type == LayerAttentionType::kGemma) { - Attention(pos, num_tokens, layer_of_type, activations, - layer_weights, kv_cache, pool); + Attention(pos, num_tokens, layer_of_type, activations, layer_weights, + kv_cache, pool); } else { - GriffinRecurrent(pos, num_tokens, layer_of_type, activations, - layer_weights, kv_cache, pool); + GriffinRecurrent(pos, num_tokens, layer_of_type, activations, layer_weights, + kv_cache, pool); } if (TConfig::kPostNormScale) { RMSNormInplaceBatched( @@ -651,7 +618,7 @@ HWY_NOINLINE void TransformerLayer( RMSNormBatched(num_tokens, activations.x.data(), layer_weights->pre_ffw_norm_scale.data(), activations.bf_pre_ffw_rms_out.data(), kModelDim); - FFW(activations, num_tokens, layer_weights, pool); + FFW(activations, num_tokens, layer_weights, pool); if (TConfig::kPostNormScale) { RMSNormInplaceBatched(num_tokens, layer_weights->post_ffw_norm_scale.data(), @@ -661,9 +628,9 @@ HWY_NOINLINE void TransformerLayer( activations.x.data(), kModelDim); } -template +template HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, - const WeightArrayT& weights, + const CompressedWeights& weights, Activations& activations, KVCache& kv_cache, hwy::ThreadPool& pool) { PROFILER_ZONE("Gen.Prefill\\Att\\FFW"); @@ -685,9 +652,9 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos, // Compute the transformer for a batch of input tokens. 
During generation, // we usually have num_tokens == 1 (and also kBatchSize == 1). -template +template HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos, - const WeightArrayT& weights, + const CompressedWeights& weights, Activations& activations, KVCache& kv_cache, hwy::ThreadPool& pool, const LayersOutputFunc& layers_output) { @@ -698,17 +665,18 @@ HWY_NOINLINE void Transformer(const int* tokens, size_t num_tokens, size_t pos, layers_output(pos + token_idx, "Tokens", &token_f, 1); } } - static constexpr size_t kModelDim = TConfig::kModelDim; + constexpr size_t kModelDim = TConfig::kModelDim; for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { EmbedToken(tokens[token_idx], token_idx, pos, weights, activations); } for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { - const auto* layer_weights = weights.GetLayer(layer); + const CompressedLayer* layer_weights = weights.GetLayer(layer); TransformerLayer(num_tokens, pos, layer, layer_weights, activations, kv_cache, pool); + if (layers_output) { - std::string block_name = "blocks." + std::to_string(layer); + const std::string block_name = "blocks." + std::to_string(layer); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) { layers_output(pos + token_idx, block_name, activations.x.data() + token_idx * kModelDim, kModelDim); @@ -754,11 +722,8 @@ void RangeChecks(size_t& max_tokens, size_t& max_generated_tokens, fprintf(stderr, "%zu\n", prompt_size); } } -} -template -const CompressedWeights& GetWeights(const ByteStorageT& weights_u8) { - return *reinterpret_cast*>(weights_u8.get()); + HWY_ASSERT(prompt_size > 0); } template @@ -776,12 +741,13 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, const RuntimeConfig& runtime_config, const std::vector& prompt, size_t pos, KVCache& kv_cache, hwy::ThreadPool& pool, TimingInfo& timing_info) { - const CompressedWeights& weights = GetWeights(weights_u8); + const CompressedWeights& weights = + *reinterpret_cast*>(weights_u8.get()); auto& prefill_activations = GetActivations(prefill_u8); auto& activations = GetActivations(decode_u8); - static constexpr size_t kVocabSize = TConfig::kVocabSize; + constexpr size_t kVocabSize = TConfig::kVocabSize; size_t prompt_size = prompt.size(); size_t max_tokens = runtime_config.max_tokens; size_t max_generated_tokens = runtime_config.max_generated_tokens; @@ -791,7 +757,6 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, max_tokens); return; } - HWY_ASSERT(prompt_size > 0); // If no sample_func is provided, we use top-k sampling. 
const SampleFunc sample_token = @@ -825,8 +790,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, HWY_DASSERT(batch_size <= kPrefillBatchSize); HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1); const int* batch_tokens = prompt.data() + pos_offset; - Prefill(batch_tokens, batch_size, pos, weights, - prefill_activations, kv_cache, pool); + Prefill(batch_tokens, batch_size, pos, weights, prefill_activations, + kv_cache, pool); for (size_t idx = 0; idx < batch_size; ++idx) { if (!runtime_config.stream_token(batch_tokens[idx], 0.0f)) return; } @@ -834,11 +799,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, pos_offset += batch_size; } - if (runtime_config.verbosity >= 2) { - const double prefill_end = hwy::platform::Now(); - timing_info.prefill_tok_sec = - static_cast(pos_offset) / (prefill_end - prefill_start); - } + timing_info.prefill_tok_sec = + static_cast(pos_offset) / (hwy::platform::Now() - prefill_start); // Start generation. const double gen_start = hwy::platform::Now(); @@ -851,9 +813,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, for (size_t generate_pos = 0; pos < max_tokens && generate_pos < max_generated_tokens; ++pos, ++pos_offset, ++generate_pos) { - Transformer(&token, kDecodeBatchSize, pos, weights, - activations, kv_cache, pool, - runtime_config.layers_output); + Transformer(&token, kDecodeBatchSize, pos, weights, activations, kv_cache, + pool, runtime_config.layers_output); float token_logit = 0.0f; // The condition below is always true if we are doing Prefill above. // We keep it here for clarity so that the code is correct even if Prefill @@ -885,11 +846,8 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8, break; } } - if (runtime_config.verbosity >= 2) { - const double gen_end = hwy::platform::Now(); - timing_info.gen_tok_sec = - static_cast(pos_offset - pos_gen_start) / (gen_end - gen_start); - } + timing_info.gen_tok_sec = static_cast(pos_offset - pos_gen_start) / + (hwy::platform::Now() - gen_start); } } // namespace HWY_NAMESPACE @@ -901,18 +859,13 @@ namespace gcpp { namespace { template -struct AllocatePrefill { - ByteStorageT operator()() const { - return AllocateSizeof>(); +struct AllocateState { + void operator()(ByteStorageT& prefill, ByteStorageT& decode) const { + prefill = AllocateSizeof>(); + decode = AllocateSizeof>(); } }; -template -struct AllocateDecode { - ByteStorageT operator()() const { - return AllocateSizeof>(); - } -}; } // namespace Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, @@ -922,8 +875,8 @@ Gemma::Gemma(const Path& tokenizer_path, const Path& weights, Model model_type, model_type_(model_type), weight_type_(weight_type) { weights_u8_ = LoadCompressedWeights(weights, model_type, weight_type, pool); - prefill_u8_ = CallForModelAndWeight(model_type, weight_type); - decode_u8_ = CallForModelAndWeight(model_type, weight_type); + CallForModelAndWeight(model_type, weight_type, prefill_u8_, + decode_u8_); } Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type, @@ -935,8 +888,8 @@ Gemma::Gemma(GemmaTokenizer&& tokenizer, Model model_type, Type weight_type, HWY_ASSERT(weight_type == Type::kF32); weights_u8_ = CallForModel( model_type, pool); - prefill_u8_ = CallForModelAndWeight(model_type, weight_type); - decode_u8_ = CallForModelAndWeight(model_type, weight_type); + CallForModelAndWeight(model_type, weight_type, prefill_u8_, + decode_u8_); } 
 Gemma::~Gemma() {
diff --git a/gemma/ops.h b/gemma/ops.h
index 18b638b..0396883 100644
--- a/gemma/ops.h
+++ b/gemma/ops.h
@@ -1629,6 +1629,36 @@ static HWY_NOINLINE HWY_MAYBE_UNUSED void AddFrom(
                      HWY_ATTR { return hn::Add(x, other); });
 }
 
+// Simple loops unless/until batch sizes are large enough to parallelize.
+template <size_t kBatchSize, typename WeightT, typename OutT>
+void RMSNormBatched(size_t num_tokens, const float* activations,
+                    const WeightT* weights, OutT* out, const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    RMSNorm(activations + token_idx * model_dim, weights,
+            out + token_idx * model_dim, model_dim);
+  }
+}
+
+template <size_t kBatchSize, typename WeightT, typename InOutT>
+void RMSNormInplaceBatched(size_t num_tokens, const WeightT* weights,
+                           InOutT* inout, const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    RMSNormInplace(weights, inout + token_idx * model_dim, model_dim);
+  }
+}
+
+template <size_t kBatchSize>
+void AddFromBatched(size_t num_tokens, const float* other, float* x,
+                    const size_t model_dim) {
+  HWY_DASSERT(num_tokens <= kBatchSize);
+  for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
+    AddFrom(other + token_idx * model_dim, x + token_idx * model_dim,
+            model_dim);
+  }
+}
+
 static HWY_NOINLINE void MulBy(const float* HWY_RESTRICT other,
                                float* HWY_RESTRICT x, const size_t size,
                                const size_t max_pos) {
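
Note on the "Simplify template arg list, enable deduction" item: the cleanup follows the usual C++ pattern of replacing opaque LayerT / WeightArrayT template parameters with the concrete CompressedLayer<TConfig>* / CompressedWeights<TConfig>& parameter types seen in the gemma.cc hunks above, so TConfig is deduced from the argument instead of being spelled out at every call site. Below is a minimal standalone sketch of that pattern using hypothetical stand-in names (Config, CompressedLayer, ProcessOld/ProcessNew), not the actual gemma.cc declarations.

// Sketch only: hypothetical types illustrating deduction from a concrete
// parameter type; not the real gemma.cc code.
#include <cstdio>

template <int kModelDim>
struct Config {
  static constexpr int kDim = kModelDim;
};

template <class TConfig>
struct CompressedLayer {
  float scale = 1.0f;
};

// Before: LayerT is an extra template parameter, so callers must list every
// template argument explicitly.
template <class TConfig, class LayerT>
void ProcessOld(const LayerT* layer) {
  std::printf("dim=%d scale=%.1f\n", TConfig::kDim, layer->scale);
}

// After: the concrete CompressedLayer<TConfig>* parameter lets the compiler
// deduce TConfig from the argument.
template <class TConfig>
void ProcessNew(const CompressedLayer<TConfig>* layer) {
  std::printf("dim=%d scale=%.1f\n", TConfig::kDim, layer->scale);
}

int main() {
  CompressedLayer<Config<2048>> layer;
  ProcessOld<Config<2048>, CompressedLayer<Config<2048>>>(&layer);  // verbose
  ProcessNew(&layer);  // TConfig deduced; same result
  return 0;
}

The same deduction is presumably what lets the Attention, GriffinRecurrent, and FFW call sites in the gemma.cc hunks drop their explicit template argument lists.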