This commit is contained in:
Charles Chan 2024-04-22 09:44:19 +00:00 committed by GitHub
commit ba02c73bf0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed file with 79 additions and 66 deletions

View File

@@ -1033,9 +1033,11 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
size_t pos_offset = 0; // offset relative to pos size_t pos_offset = 0; // offset relative to pos
const double prefill_start = hwy::platform::Now(); const double prefill_start = hwy::platform::Now();
auto prefill_phase = [&]() HWY_ATTR {
bool keep_on = true;
// Prefill stops before prompt_size - 1 since the last prompt token is the // Prefill stops before prompt_size - 1 since the last prompt token is the
// first input token for generation. // first input token for generation.
while (pos_offset < prompt_size - 1) { while (pos_offset < prompt_size - 1 && keep_on) {
const size_t batch_size = const size_t batch_size =
std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset); std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
HWY_DASSERT(batch_size <= kPrefillBatchSize); HWY_DASSERT(batch_size <= kPrefillBatchSize);
@@ -1044,7 +1046,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights, Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
prefill_activations, kv_cache, pool, inner_pool); prefill_activations, kv_cache, pool, inner_pool);
for (size_t idx = 0; idx < batch_size; ++idx) { for (size_t idx = 0; idx < batch_size; ++idx) {
stream_token(batch_tokens[idx], 0.0f); keep_on = stream_token(batch_tokens[idx], 0.0f);
if(!keep_on) {
break;
}
} }
pos += batch_size; pos += batch_size;
pos_offset += batch_size; pos_offset += batch_size;
@@ -1058,7 +1063,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
static_cast<double>(pos_offset) / (prefill_end - prefill_start); static_cast<double>(pos_offset) / (prefill_end - prefill_start);
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]"; std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
} }
return keep_on;
};
auto transform_phase = [&]() HWY_ATTR {
const double gen_start = hwy::platform::Now(); const double gen_start = hwy::platform::Now();
HWY_DASSERT(pos_offset == prompt_size - 1); HWY_DASSERT(pos_offset == prompt_size - 1);
@ -1104,6 +1112,11 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
break; break;
} }
} }
};
if(prefill_phase()) {
transform_phase();
}
} }
template <class TConfig> template <class TConfig>