Simplify prefill early-exit (originally Merge #156)

PiperOrigin-RevId: 627788524
This commit is contained in:
Paul Chang 2024-04-24 11:11:04 -07:00 committed by Copybara-Service
parent b27d8d6b92
commit 75eca87039
1 changed file with 67 additions and 81 deletions

View File

@@ -1057,96 +1057,82 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal. // always equal.
size_t pos_offset = 0; // offset relative to pos size_t pos_offset = 0; // offset relative to pos
const double prefill_start = hwy::platform::Now();
auto prefill_phase = [&]() HWY_ATTR { // Prefill stops before prompt_size - 1 since the last prompt token is the
bool keep_on = true; // first input token for generation.
const double prefill_start = hwy::platform::Now(); while (pos_offset < prompt_size - 1) {
const size_t batch_size =
// Prefill stops before prompt_size - 1 since the last prompt token is the std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
// first input token for generation. HWY_DASSERT(batch_size <= kPrefillBatchSize);
while (pos_offset < prompt_size - 1 && keep_on) { HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
const size_t batch_size = const int* batch_tokens = prompt.data() + pos_offset;
std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset); Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
HWY_DASSERT(batch_size <= kPrefillBatchSize); prefill_activations, kv_cache, pool, inner_pool);
HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1); for (size_t idx = 0; idx < batch_size; ++idx) {
const int* batch_tokens = prompt.data() + pos_offset; if (!stream_token(batch_tokens[idx], 0.0f)) return;
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
prefill_activations, kv_cache, pool, inner_pool);
for (size_t idx = 0; idx < batch_size; ++idx) {
keep_on = stream_token(batch_tokens[idx], 0.0f);
if(!keep_on) {
break;
}
}
pos += batch_size;
pos_offset += batch_size;
} }
pos += batch_size;
pos_offset += batch_size;
}
if (verbosity >= 2) { if (verbosity >= 2) {
// in the future this output should not occur in GenerateImpl but instead // in the future this output should not occur in GenerateImpl but instead
// should be available as observable state for frontend code to handle I/O. // should be available as observable state for frontend code to handle I/O.
const double prefill_end = hwy::platform::Now(); const double prefill_end = hwy::platform::Now();
const double prefill_tok_sec = const double prefill_tok_sec =
static_cast<double>(pos_offset) / (prefill_end - prefill_start); static_cast<double>(pos_offset) / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]"; std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
} }
return keep_on; const double gen_start = hwy::platform::Now();
};
auto transform_phase = [&]() HWY_ATTR { HWY_DASSERT(pos_offset == prompt_size - 1);
const double gen_start = hwy::platform::Now(); size_t pos_gen_start = pos_offset;
int token = prompt.at(pos_offset);
HWY_DASSERT(pos_offset == prompt_size - 1); stream_token(token, 0);
for (size_t generate_pos = 0;
size_t pos_gen_start = pos_offset; pos < max_tokens && generate_pos < max_generated_tokens;
int token = prompt.at(pos_offset); ++pos, ++pos_offset, ++generate_pos) {
stream_token(token, 0); const bool is_generating_phase = pos_offset >= prompt_size - 1;
for (size_t generate_pos = 0; Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
pos < max_tokens && generate_pos < max_generated_tokens; layers_output);
++pos, ++pos_offset, ++generate_pos) { float* final_activation = activations.x.data();
const bool is_generating_phase = pos_offset >= prompt_size - 1; // The condition below is always true if we are doing Prefill above.
Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool, // We keep it here for clarity so that the code is correct even if Prefill
layers_output); // is disabled.
float* final_activation = activations.x.data(); if (is_generating_phase) {
// The condition below is always true if we are doing Prefill above. PROFILER_ZONE("Gen.Embedding");
// We keep it here for clarity so that the code is correct even if Prefill // Generation phase
// is disabled. MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
if (is_generating_phase) { 0, final_activation,
PROFILER_ZONE("Gen.Embedding"); activations.logits.data(), pool);
// Generation phase // Barrier: must have all logits so we can subtract max.
MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding, Softmax(activations.logits.data(), kVocabSize);
0, final_activation, token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
activations.logits.data(), pool); gen, temperature, accept_token);
// Barrier: must have all logits so we can subtract max. if (!stream_token(token, activations.logits[token])) {
Softmax(activations.logits.data(), kVocabSize); token = EOS_ID;
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
gen, temperature, accept_token);
if (!stream_token(token, activations.logits[token])) {
token = EOS_ID;
}
} else {
// We would take this branch if we were not doing Prefill but would
// process the tokens of the prompt one at a time.
token = prompt.at(pos_offset + 1);
stream_token(token, 0);
} }
if (token == EOS_ID) { } else {
if (verbosity >= 2) { // We would take this branch if we were not doing Prefill but would
const double gen_end = hwy::platform::Now(); // process the tokens of the prompt one at a time.
const double gen_tok_sec = token = prompt.at(pos_offset + 1);
static_cast<double>(pos_offset - pos_gen_start) / if (!stream_token(token, 0)) {
(gen_end - gen_start); token = EOS_ID;
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
}
break;
} }
} }
}; if (token == EOS_ID) {
if (verbosity >= 2) {
if(prefill_phase()) { const double gen_end = hwy::platform::Now();
transform_phase(); const double gen_tok_sec =
static_cast<double>(pos_offset - pos_gen_start) /
(gen_end - gen_start);
std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
}
break;
}
} }
} }