Simplify prefill early-exit (originally Merge #156)

PiperOrigin-RevId: 627788524
This commit is contained in:
Paul Chang 2024-04-24 11:11:04 -07:00 committed by Copybara-Service
parent b27d8d6b92
commit 75eca87039
1 changed files with 67 additions and 81 deletions

View File

@ -1057,14 +1057,11 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
// In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
// always equal.
size_t pos_offset = 0; // offset relative to pos
auto prefill_phase = [&]() HWY_ATTR {
bool keep_on = true;
const double prefill_start = hwy::platform::Now();
// Prefill stops before prompt_size - 1 since the last prompt token is the
// first input token for generation.
while (pos_offset < prompt_size - 1 && keep_on) {
while (pos_offset < prompt_size - 1) {
const size_t batch_size =
std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
HWY_DASSERT(batch_size <= kPrefillBatchSize);
@ -1073,10 +1070,7 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
prefill_activations, kv_cache, pool, inner_pool);
for (size_t idx = 0; idx < batch_size; ++idx) {
keep_on = stream_token(batch_tokens[idx], 0.0f);
if(!keep_on) {
break;
}
if (!stream_token(batch_tokens[idx], 0.0f)) return;
}
pos += batch_size;
pos_offset += batch_size;
@ -1091,11 +1085,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
}
return keep_on;
};
auto transform_phase = [&]() HWY_ATTR {
const double gen_start = hwy::platform::Now();
HWY_DASSERT(pos_offset == prompt_size - 1);
@ -1130,7 +1119,9 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
// We would take this branch if we were not doing Prefill but would
// process the tokens of the prompt one at a time.
token = prompt.at(pos_offset + 1);
stream_token(token, 0);
if (!stream_token(token, 0)) {
token = EOS_ID;
}
}
if (token == EOS_ID) {
if (verbosity >= 2) {
@ -1143,11 +1134,6 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
break;
}
}
};
if(prefill_phase()) {
transform_phase();
}
}
#define TOKEN(token_id) TokenString(gemma, token_id).c_str()