mirror of https://github.com/google/gemma.cpp.git
commit b27d8d6b92
gemma/gemma.cc | 152
@@ -1057,80 +1057,96 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
   // In single-turn (non-chat) usage, pos and pos_offset start at 0 and are
   // always equal.
   size_t pos_offset = 0;  // offset relative to pos
-  const double prefill_start = hwy::platform::Now();
 
-  // Prefill stops before prompt_size - 1 since the last prompt token is the
-  // first input token for generation.
-  while (pos_offset < prompt_size - 1) {
-    const size_t batch_size =
-        std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
-    HWY_DASSERT(batch_size <= kPrefillBatchSize);
-    HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
-    const int* batch_tokens = prompt.data() + pos_offset;
-    Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
-                               prefill_activations, kv_cache, pool, inner_pool);
-    for (size_t idx = 0; idx < batch_size; ++idx) {
-      stream_token(batch_tokens[idx], 0.0f);
-    }
-    pos += batch_size;
-    pos_offset += batch_size;
-  }
-
-  if (verbosity >= 2) {
-    // in the future this output should not occur in GenerateImpl but instead
-    // should be available as observable state for frontend code to handle I/O.
-    const double prefill_end = hwy::platform::Now();
-    const double prefill_tok_sec =
-        static_cast<double>(pos_offset) / (prefill_end - prefill_start);
-    std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
-  }
-
-  const double gen_start = hwy::platform::Now();
-
-  HWY_DASSERT(pos_offset == prompt_size - 1);
-
-  size_t pos_gen_start = pos_offset;
-  int token = prompt.at(pos_offset);
-  stream_token(token, 0);
-  for (size_t generate_pos = 0;
-       pos < max_tokens && generate_pos < max_generated_tokens;
-       ++pos, ++pos_offset, ++generate_pos) {
-    const bool is_generating_phase = pos_offset >= prompt_size - 1;
-    Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
-                layers_output);
-    float* final_activation = activations.x.data();
-    // The condition below is always true if we are doing Prefill above.
-    // We keep it here for clarity so that the code is correct even if Prefill
-    // is disabled.
-    if (is_generating_phase) {
-      PROFILER_ZONE("Gen.Embedding");
-      // Generation phase
-      MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
-                                             0, final_activation,
-                                             activations.logits.data(), pool);
-      // Barrier: must have all logits so we can subtract max.
-      Softmax(activations.logits.data(), kVocabSize);
-      token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
-                                         gen, temperature, accept_token);
-      if (!stream_token(token, activations.logits[token])) {
-        token = EOS_ID;
-      }
-    } else {
-      // We would take this branch if we were not doing Prefill but would
-      // process the tokens of the prompt one at a time.
-      token = prompt.at(pos_offset + 1);
-      stream_token(token, 0);
-    }
-    if (token == EOS_ID) {
-      if (verbosity >= 2) {
-        const double gen_end = hwy::platform::Now();
-        const double gen_tok_sec =
-            static_cast<double>(pos_offset - pos_gen_start) /
-            (gen_end - gen_start);
-        std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
-      }
-      break;
-    }
-  }
+  auto prefill_phase = [&]() HWY_ATTR {
+    bool keep_on = true;
+    const double prefill_start = hwy::platform::Now();
+    // Prefill stops before prompt_size - 1 since the last prompt token is the
+    // first input token for generation.
+    while (pos_offset < prompt_size - 1 && keep_on) {
+      const size_t batch_size =
+          std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
+      HWY_DASSERT(batch_size <= kPrefillBatchSize);
+      HWY_DASSERT(pos_offset + batch_size <= prompt_size - 1);
+      const int* batch_tokens = prompt.data() + pos_offset;
+      Prefill<kPrefillBatchSize>(batch_tokens, batch_size, pos, weights,
+                                 prefill_activations, kv_cache, pool, inner_pool);
+      for (size_t idx = 0; idx < batch_size; ++idx) {
+        keep_on = stream_token(batch_tokens[idx], 0.0f);
+        if(!keep_on) {
+          break;
+        }
+      }
+      pos += batch_size;
+      pos_offset += batch_size;
+    }
+    if (verbosity >= 2) {
+      // in the future this output should not occur in GenerateImpl but instead
+      // should be available as observable state for frontend code to handle I/O.
+      const double prefill_end = hwy::platform::Now();
+      const double prefill_tok_sec =
+          static_cast<double>(pos_offset) / (prefill_end - prefill_start);
+      std::cout << "\n[ Prefill tokens / sec = " << prefill_tok_sec << " ]";
+    }
+    return keep_on;
+  };
+
+  auto transform_phase = [&]() HWY_ATTR {
+    const double gen_start = hwy::platform::Now();
+    HWY_DASSERT(pos_offset == prompt_size - 1);
+    size_t pos_gen_start = pos_offset;
+    int token = prompt.at(pos_offset);
+    stream_token(token, 0);
+    for (size_t generate_pos = 0;
+         pos < max_tokens && generate_pos < max_generated_tokens;
+         ++pos, ++pos_offset, ++generate_pos) {
+      const bool is_generating_phase = pos_offset >= prompt_size - 1;
+      Transformer(token, pos, weights, activations, kv_cache, pool, inner_pool,
+                  layers_output);
+      float* final_activation = activations.x.data();
+      // The condition below is always true if we are doing Prefill above.
+      // We keep it here for clarity so that the code is correct even if Prefill
+      // is disabled.
+      if (is_generating_phase) {
+        PROFILER_ZONE("Gen.Embedding");
+        // Generation phase
+        MatVec<kVocabSize, TConfig::kModelDim>(weights.embedder_input_embedding,
+                                               0, final_activation,
+                                               activations.logits.data(), pool);
+        // Barrier: must have all logits so we can subtract max.
+        Softmax(activations.logits.data(), kVocabSize);
+        token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
+                                           gen, temperature, accept_token);
+        if (!stream_token(token, activations.logits[token])) {
+          token = EOS_ID;
+        }
+      } else {
+        // We would take this branch if we were not doing Prefill but would
+        // process the tokens of the prompt one at a time.
+        token = prompt.at(pos_offset + 1);
+        stream_token(token, 0);
+      }
+      if (token == EOS_ID) {
+        if (verbosity >= 2) {
+          const double gen_end = hwy::platform::Now();
+          const double gen_tok_sec =
+              static_cast<double>(pos_offset - pos_gen_start) /
+              (gen_end - gen_start);
+          std::cout << "\n[ Generation tokens / sec = " << gen_tok_sec << " ]\n";
+        }
+        break;
+      }
+    }
+  };
+
+  if(prefill_phase()) {
+    transform_phase();
+  }
 }
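
The substance of the change: the boolean returned by stream_token, previously honored only during generation, now also cancels the prefill loop. prefill_phase() reports whether streaming was cancelled, and transform_phase() (the token-by-token generation loop) only runs if it was not. A minimal caller-side sketch of a callback that exercises this, assuming StreamFunc matches the library's std::function<bool(int, float)> alias; MakeBudgetedStream and max_stream are hypothetical names introduced here for illustration:

#include <cstddef>
#include <functional>
#include <memory>

// Assumed shape of gemma.cpp's stream callback: receives the token id and a
// probability/score, returns false to request that decoding stop.
using StreamFunc = std::function<bool(int, float)>;

// Hypothetical helper: stream at most max_stream tokens, then return false.
// With this change, returning false during prefill short-circuits
// prefill_phase(), so transform_phase() is never entered.
StreamFunc MakeBudgetedStream(std::size_t max_stream) {
  auto count = std::make_shared<std::size_t>(0);
  return [count, max_stream](int token, float prob) -> bool {
    (void)token;  // a real caller would detokenize and display the token here
    (void)prob;
    return ++*count < max_stream;
  };
}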
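For intuition about the batch arithmetic inside prefill_phase(), here is a standalone sketch of the same loop bounds with made-up sizes (kPrefillBatchSize = 32 and prompt_size = 70 are illustrative values only; the real constants come from the model config):

#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  constexpr std::size_t kPrefillBatchSize = 32;  // assumed, for illustration
  constexpr std::size_t prompt_size = 70;        // hypothetical prompt length
  std::size_t pos_offset = 0;
  // Same bounds as GenerateImpl: prefill stops at prompt_size - 1 so the
  // last prompt token is left over to seed generation.
  while (pos_offset < prompt_size - 1) {
    const std::size_t batch_size =
        std::min(kPrefillBatchSize, prompt_size - 1 - pos_offset);
    std::printf("prefill [%zu, %zu)\n", pos_offset, pos_offset + batch_size);
    pos_offset += batch_size;
  }
  // Prints [0, 32), [32, 64), [64, 69); token 69 is the first generation input.
  return 0;
}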