Extends Transformer() to prepare for batched processing.

PiperOrigin-RevId: 642603025
This commit is contained in:
Daniel Keysers 2024-06-12 07:00:33 -07:00 committed by Copybara-Service
parent 2a0e6ee976
commit 1ac9857014
2 changed files with 84 additions and 64 deletions

View File

@ -17,6 +17,7 @@
// Compiles this file for multiple architectures via "foreach_target.h", to // Compiles this file for multiple architectures via "foreach_target.h", to
// which we pass the filename via macro 'argument'. // which we pass the filename via macro 'argument'.
#include <cstdio>
#undef HWY_TARGET_INCLUDE #undef HWY_TARGET_INCLUDE
#define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT #define HWY_TARGET_INCLUDE "gemma/gemma.cc" // NOLINT
#include "hwy/foreach_target.h" // IWYU pragma: keep #include "hwy/foreach_target.h" // IWYU pragma: keep
@ -592,16 +593,15 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
pool.Run( pool.Run(
0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR { 0, num_tokens, [&](const uint64_t token_idx, size_t /*thread*/) HWY_ATTR {
const int token = tokens[token_idx]; const int token = tokens[token_idx];
HWY_ASSERT(token >= 0); HWY_DASSERT(token >= 0);
HWY_ASSERT(token < TConfig::kVocabSize); HWY_DASSERT(token < TConfig::kVocabSize);
Decompress(weights.embedder_input_embedding, token * kModelDim, Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.data() + token_idx * kModelDim, kModelDim); activations.x.data() + token_idx * kModelDim, kModelDim);
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim, MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
kModelDim); kModelDim);
if constexpr (TConfig::kAbsolutePE) { if constexpr (TConfig::kAbsolutePE) {
AddAbsolutePositionalEmbeddings( AddAbsolutePositionalEmbeddings(
activations.x.data() + token_idx * kModelDim, TConfig::kModelDim, activations.x.data() + token_idx * kModelDim, kModelDim, pos);
pos);
}; };
}); });
@ -646,72 +646,92 @@ HWY_NOINLINE void Prefill(const int* tokens, size_t num_tokens, size_t pos,
activations.x.data(), kModelDim); activations.x.data(), kModelDim);
} }
// n = 1 specialization // Compute the transformer for a batch of input tokens. During generation,
template <typename WeightArrayT, class TConfig> // we usually have num_tokens == 1 (and also kBatchSize == 1).
HWY_NOINLINE void Transformer(int token, size_t pos, template <size_t kBatchSize, typename WeightArrayT, class TConfig>
HWY_NOINLINE void Transformer(const int *tokens, size_t num_tokens, size_t pos,
const WeightArrayT& weights, const WeightArrayT& weights,
Activations<TConfig, 1>& activations, Activations<TConfig, kBatchSize>& activations,
KVCache& kv_cache, hwy::ThreadPool& pool, KVCache& kv_cache, hwy::ThreadPool& pool,
LayersOutputT* layers_output) { LayersOutputT* layers_output) {
HWY_ASSERT(num_tokens <= kBatchSize);
if (layers_output != nullptr) { if (layers_output != nullptr) {
float token_f = token; for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos, "Tokens", &token_f, 1); float token_f = tokens[token_idx];
(*layers_output)(pos + token_idx, "Tokens", &token_f, 1);
}
} }
static constexpr size_t kModelDim = TConfig::kModelDim; static constexpr size_t kModelDim = TConfig::kModelDim;
Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.data(), kModelDim);
GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling = GEMMA_CONSTEXPR_EMBSCALING const float kEmbScaling =
EmbeddingScaling<TConfig>(); EmbeddingScaling<TConfig>();
MulByConst(kEmbScaling, activations.x.data(), kModelDim); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
if constexpr (TConfig::kAbsolutePE) { const int token = tokens[token_idx];
AddAbsolutePositionalEmbeddings(activations.x.data(), TConfig::kModelDim, HWY_DASSERT(token >= 0);
pos); HWY_DASSERT(token < TConfig::kVocabSize);
}; Decompress(weights.embedder_input_embedding, token * kModelDim,
activations.x.data() + token_idx * kModelDim, kModelDim);
MulByConst(kEmbScaling, activations.x.data() + token_idx * kModelDim,
kModelDim);
if constexpr (TConfig::kAbsolutePE) {
AddAbsolutePositionalEmbeddings(
activations.x.data() + token_idx * kModelDim, kModelDim,
pos + token_idx);
};
}
for (size_t layer = 0; layer < TConfig::kLayers; ++layer) { for (size_t layer = 0; layer < TConfig::kLayers; ++layer) {
auto type = TConfig::kLayerConfig[layer]; auto type = TConfig::kLayerConfig[layer];
const auto* layer_weights = weights.GetLayer(layer); const auto* layer_weights = weights.GetLayer(layer);
size_t layer_of_type = size_t layer_of_type =
NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer); NumLayersOfTypeBefore(TConfig::kLayerConfig, type, layer);
RMSNormBatched<1>(1, activations.x.data(), RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
layer_weights->pre_attention_norm_scale.data(), layer_weights->pre_attention_norm_scale.data(),
activations.pre_att_rms_out.data(), kModelDim); activations.pre_att_rms_out.data(), kModelDim);
if (type == LayerAttentionType::kGemma) { if (type == LayerAttentionType::kGemma) {
Attention<1>(pos, 1, layer_of_type, activations, layer_weights, kv_cache, Attention<kBatchSize>(pos, num_tokens, layer_of_type, activations,
pool); layer_weights, kv_cache, pool);
} else { } else {
GriffinRecurrent<1>(pos, 1, layer_of_type, activations, layer_weights, GriffinRecurrent<kBatchSize>(pos, num_tokens, layer_of_type, activations,
kv_cache, pool); layer_weights, kv_cache, pool);
} }
if (TConfig::kPostNormScale) { if (TConfig::kPostNormScale) {
RMSNormInplaceBatched<1>(1, RMSNormInplaceBatched<kBatchSize>(
layer_weights->post_attention_norm_scale.data(), num_tokens, layer_weights->post_attention_norm_scale.data(),
activations.att_post2.data(), kModelDim); activations.att_post2.data(), kModelDim);
} }
AddFromBatched<1>(1, activations.att_post2.data(), activations.x.data(), AddFromBatched<kBatchSize>(num_tokens, activations.att_post2.data(),
kModelDim); activations.x.data(), kModelDim);
RMSNormBatched<1>(1, activations.x.data(), RMSNormBatched<kBatchSize>(num_tokens, activations.x.data(),
layer_weights->pre_ffw_norm_scale.data(), layer_weights->pre_ffw_norm_scale.data(),
activations.bf_pre_ffw_rms_out.data(), kModelDim); activations.bf_pre_ffw_rms_out.data(),
FFW<1>(activations, /* num_tokens = */ 1, layer_weights, pool); kModelDim);
FFW<kBatchSize>(activations, num_tokens, layer_weights, pool);
if (TConfig::kPostNormScale) { if (TConfig::kPostNormScale) {
RMSNormInplaceBatched<1>(1, layer_weights->post_ffw_norm_scale.data(), RMSNormInplaceBatched<kBatchSize>(
activations.ffw_out.data(), kModelDim); num_tokens, layer_weights->post_ffw_norm_scale.data(),
activations.ffw_out.data(), kModelDim);
} }
AddFromBatched<1>(1, activations.ffw_out.data(), activations.x.data(), AddFromBatched<kBatchSize>(num_tokens, activations.ffw_out.data(),
kModelDim); activations.x.data(), kModelDim);
if (layers_output != nullptr) { if (layers_output != nullptr) {
std::string block_name = "blocks." + std::to_string(layer); std::string block_name = "blocks." + std::to_string(layer);
(*layers_output)(pos, block_name, activations.x.data(), kModelDim); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos + token_idx, block_name,
activations.x.data() + token_idx * kModelDim,
kModelDim);
}
} }
} }
// Placeholder for internal test4, do not remove // Placeholder for internal test4, do not remove
RMSNormInplaceBatched<1>(1, weights.final_norm_scale.data(), RMSNormInplaceBatched<kBatchSize>(num_tokens, weights.final_norm_scale.data(),
activations.x.data(), kModelDim); activations.x.data(), kModelDim);
if (layers_output != nullptr) { if (layers_output != nullptr) {
(*layers_output)(pos, "final_norm", activations.x.data(), kModelDim); for (size_t token_idx = 0; token_idx < num_tokens; ++token_idx) {
(*layers_output)(pos + token_idx, "final_norm",
activations.x.data() + token_idx * kModelDim, kModelDim);
}
} }
} }
@ -781,6 +801,7 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
} }
HWY_ASSERT(prompt_size > 0); HWY_ASSERT(prompt_size > 0);
// If no sample_func is provided, we use top-k sampling.
const SampleFunc sample_token = const SampleFunc sample_token =
runtime_config.sample_func runtime_config.sample_func
? runtime_config.sample_func ? runtime_config.sample_func
@ -827,36 +848,35 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
static_cast<double>(pos_offset) / (prefill_end - prefill_start); static_cast<double>(pos_offset) / (prefill_end - prefill_start);
} }
// Start generation.
const double gen_start = hwy::platform::Now(); const double gen_start = hwy::platform::Now();
HWY_DASSERT(pos_offset == prompt_size - 1); HWY_DASSERT(pos_offset == prompt_size - 1);
size_t pos_gen_start = pos_offset; size_t pos_gen_start = pos_offset;
int token = prompt.at(pos_offset); int token = prompt.at(pos_offset);
runtime_config.stream_token(token, 0); // The loop below is not yet prepared for batch size > 1.
HWY_ASSERT(kDecodeBatchSize == 1);
if (!runtime_config.stream_token(token, 0.0f)) return;
for (size_t generate_pos = 0; for (size_t generate_pos = 0;
pos < max_tokens && generate_pos < max_generated_tokens; pos < max_tokens && generate_pos < max_generated_tokens;
++pos, ++pos_offset, ++generate_pos) { ++pos, ++pos_offset, ++generate_pos) {
const bool is_generating_phase = pos_offset >= prompt_size - 1; Transformer<kDecodeBatchSize>(&token, kDecodeBatchSize, pos, weights,
Transformer(token, pos, weights, activations, kv_cache, pool, activations, kv_cache, pool, layers_output);
layers_output); float token_logit = 0.0f;
float* final_activation = activations.x.data();
// The condition below is always true if we are doing Prefill above. // The condition below is always true if we are doing Prefill above.
// We keep it here for clarity so that the code is correct even if Prefill // We keep it here for clarity so that the code is correct even if Prefill
// is disabled. // is disabled.
const bool is_generating_phase = pos_offset >= prompt_size - 1;
if (is_generating_phase) { if (is_generating_phase) {
PROFILER_ZONE("Gen.Embedding"); PROFILER_ZONE("Gen.Embedding");
// Generation phase // Compute logits from last layer activations.
MatVec<kVocabSize, TConfig::kModelDim>( MatVec<kVocabSize, TConfig::kModelDim>(
weights.embedder_input_embedding, 0, final_activation, weights.embedder_input_embedding, 0, activations.x.data(),
activations.even_odd.data(), activations.logits.data(), pool); activations.even_odd.data(), activations.logits.data(), pool);
LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize); LogitsSoftCap(30.0f, activations.logits.data(), kVocabSize);
// Barrier: must have all logits so we can subtract max. // Barrier: must have all logits so we can subtract max.
Softmax(activations.logits.data(), kVocabSize); Softmax(activations.logits.data(), kVocabSize);
token = sample_token(activations.logits.data(), kVocabSize); token = sample_token(activations.logits.data(), kVocabSize);
if (!runtime_config.stream_token(token, activations.logits[token])) { token_logit = activations.logits[token];
token = runtime_config.eos_id;
}
if (generate_pos == 0) { if (generate_pos == 0) {
timing_info.time_to_first_token = hwy::platform::Now() - gen_start; timing_info.time_to_first_token = hwy::platform::Now() - gen_start;
} }
@ -864,20 +884,19 @@ void GenerateT(const ByteStorageT& weights_u8, const ByteStorageT& prefill_u8,
// We would take this branch if we were not doing Prefill but would // We would take this branch if we were not doing Prefill but would
// process the tokens of the prompt one at a time. // process the tokens of the prompt one at a time.
token = prompt.at(pos_offset + 1); token = prompt.at(pos_offset + 1);
if (!runtime_config.stream_token(token, 0)) { }
token = runtime_config.eos_id; if (!runtime_config.stream_token(token, token_logit)) {
} token = runtime_config.eos_id;
} }
if (token == runtime_config.eos_id) { if (token == runtime_config.eos_id) {
if (runtime_config.verbosity >= 2) {
const double gen_end = hwy::platform::Now();
timing_info.gen_tok_sec =
static_cast<double>(pos_offset - pos_gen_start) /
(gen_end - gen_start);
}
break; break;
} }
} }
if (runtime_config.verbosity >= 2) {
const double gen_end = hwy::platform::Now();
timing_info.gen_tok_sec =
static_cast<double>(pos_offset - pos_gen_start) / (gen_end - gen_start);
}
} }
} // namespace HWY_NAMESPACE } // namespace HWY_NAMESPACE
@ -898,7 +917,7 @@ struct AllocatePrefill {
template <typename TConfig> template <typename TConfig>
struct AllocateDecode { struct AllocateDecode {
ByteStorageT operator()() const { ByteStorageT operator()() const {
return AllocateSizeof<Activations<TConfig, 1>>(); return AllocateSizeof<Activations<TConfig, kDecodeBatchSize>>();
} }
}; };
} // namespace } // namespace

View File

@ -31,6 +31,7 @@
namespace gcpp { namespace gcpp {
constexpr size_t kPrefillBatchSize = 16; constexpr size_t kPrefillBatchSize = 16;
constexpr size_t kDecodeBatchSize = 1;
constexpr bool kSystemPrompt = false; constexpr bool kSystemPrompt = false;
struct KVCache { struct KVCache {