mirror of https://github.com/google/gemma.cpp.git
Merge pull request #127 from szabadka:gemma3
PiperOrigin-RevId: 621815677
This commit is contained in:
commit
08948f13ac
17
gemma.cc
17
gemma.cc
|
|
@ -662,12 +662,16 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
|
|
||||||
size_t pos_gen_start = pos_offset;
|
size_t pos_gen_start = pos_offset;
|
||||||
int token = prompt.at(pos_offset);
|
int token = prompt.at(pos_offset);
|
||||||
|
stream_token(token, 0);
|
||||||
for (size_t generate_pos = 0;
|
for (size_t generate_pos = 0;
|
||||||
pos < max_tokens && generate_pos < max_generated_tokens;
|
pos < max_tokens && generate_pos < max_generated_tokens;
|
||||||
++pos, ++pos_offset, ++generate_pos) {
|
++pos, ++pos_offset, ++generate_pos) {
|
||||||
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
|
Transformer(token, pos, c_weights, activations, kv_cache, pool, inner_pool);
|
||||||
float* final_activation = activations.x.data();
|
float* final_activation = activations.x.data();
|
||||||
if (pos_offset >= prompt_size) {
|
// The condition below is always true if we are doing Prefill above.
|
||||||
|
// We keep it here for clarity so that the code is correct even if Prefill
|
||||||
|
// is disabled.
|
||||||
|
if (pos_offset >= prompt_size - 1) {
|
||||||
PROFILER_ZONE("Gen.Embedding");
|
PROFILER_ZONE("Gen.Embedding");
|
||||||
// Generation phase
|
// Generation phase
|
||||||
MatVec<kVocabSize, TConfig::kModelDim>(
|
MatVec<kVocabSize, TConfig::kModelDim>(
|
||||||
|
|
@ -677,9 +681,14 @@ void GenerateImpl(GemmaImpl<TConfig>& gemma, size_t max_tokens,
|
||||||
Softmax(activations.logits.data(), kVocabSize);
|
Softmax(activations.logits.data(), kVocabSize);
|
||||||
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
|
token = SampleTopK<TConfig::kTopK>(activations.logits.data(), kVocabSize,
|
||||||
gen, temperature, accept_token);
|
gen, temperature, accept_token);
|
||||||
}
|
if (!stream_token(token, activations.logits[token])) {
|
||||||
if (!stream_token(token, activations.logits[token])) {
|
token = EOS_ID;
|
||||||
token = EOS_ID;
|
}
|
||||||
|
} else {
|
||||||
|
// We would take this branch if we were not doing Prefill but would
|
||||||
|
// process the tokens of the prompt one at a time.
|
||||||
|
token = prompt.at(pos_offset + 1);
|
||||||
|
stream_token(token, 0);
|
||||||
}
|
}
|
||||||
if (token == EOS_ID) {
|
if (token == EOS_ID) {
|
||||||
if (verbosity >= 2) {
|
if (verbosity >= 2) {
|
||||||
|
|
|
||||||
3
run.cc
3
run.cc
|
|
@ -116,7 +116,8 @@ void ReplGemma(gcpp::Gemma& model, gcpp::KVCache& kv_cache,
|
||||||
verbosity](int token, float) {
|
verbosity](int token, float) {
|
||||||
++abs_pos;
|
++abs_pos;
|
||||||
++current_pos;
|
++current_pos;
|
||||||
if (current_pos < prompt_size) {
|
// Use <= because current_pos has already been incremented above.
|
||||||
|
if (current_pos <= prompt_size) {
|
||||||
std::cerr << "." << std::flush;
|
std::cerr << "." << std::flush;
|
||||||
} else if (token == gcpp::EOS_ID) {
|
} else if (token == gcpp::EOS_ID) {
|
||||||
if (!args.multiturn) {
|
if (!args.multiturn) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue