mirror of https://github.com/google/gemma.cpp.git
Merge 62f69fe837 into 83dd08ac87
This commit is contained in:
commit
ba02c73bf0
17
gemma.cc
17
gemma.cc
|
|
@@ -1033,9 +1033,11 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
size_t pos_offset = 0; // offset relative to pos
|
size_t pos_offset = 0; // offset relative to pos
|
||||||
const double prefill_start = hwy::platform::Now();
|
const double prefill_start = hwy::platform::Now();
|
||||||
|
|
||||||
|
auto prefill_phase = [&]() HWY_ATTR {
|
||||||
|
bool keep_on = true;
|
||||||
// Prefill stops before prompt_size - 1 since the last prompt token is the
|
// Prefill stops before prompt_size - 1 since the last prompt token is the
|
||||||
// first input token for generation.
|
// first input token for generation.
|
||||||
while (pos_offset < prompt_size - 1) {
|
while (pos_offset < prompt_size - 1 && keep_on) {
|
||||||
const size_t batch_size =
|
const size_t batch_size =
|
||||||
std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
|
std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
|
||||||
HWY_DASSERT(batch_size <= kPrefillBatchSize);
|
HWY_DASSERT(batch_size <= kPrefillBatchSize);
|
||||||
|
|
@@ -1044,7 +1046,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
|
Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
|
||||||
prefill_activations, kv_cache, pool, inner_pool);
|
prefill_activations, kv_cache, pool, inner_pool);
|
||||||
for (size_t idx = 0; idx < batch_size; ++idx) {
|
for (size_t idx = 0; idx < batch_size; ++idx) {
|
||||||
stream_token(batch_tokens[idx], 0.0f);
|
keep_on = stream_token(batch_tokens[idx], 0.0f);
|
||||||
|
if(!keep_on) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
pos += batch_size;
|
pos += batch_size;
|
||||||
pos_offset += batch_size;
|
pos_offset += batch_size;
|
||||||
|
|
@@ -1058,7 +1063,10 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
static_cast<double>(pos_offset) / (prefill_end - prefill_start);
|
||||||
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
|
std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
|
||||||
}
|
}
|
||||||
|
return keep_on;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto transform_phase = [&]() HWY_ATTR {
|
||||||
const double gen_start = hwy::platform::Now();
|
const double gen_start = hwy::platform::Now();
|
||||||
|
|
||||||
HWY_DASSERT(pos_offset == prompt_size - 1);
|
HWY_DASSERT(pos_offset == prompt_size - 1);
|
||||||
|
|
@@ -1104,6 +1112,11 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if(prefill_phase()) {
|
||||||
|
transform_phase();
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <class TConfig>
|
template <class TConfig>
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue