mirror of https://github.com/google/gemma.cpp.git

commit ce807a31a1
parent a60b564b88
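Judging from the two hunks below, this is a mechanical cleanup: the first hunk drops one line from gemma.cc's include block, and the second wraps the token-generation loop in GenerateT in a new scope, re-indenting its body and rewrapping the two lines (a comment and a MaybeLogitsSoftCap call) that would otherwise exceed the column limit.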
@@ -28,7 +28,6 @@
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
 #include "gemma/weights.h"
 // Placeholder for internal test4, do not remove
 #include "paligemma/image.h"
 #include "util/allocator.h"
 #include "util/basics.h"
@@ -1312,42 +1311,45 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
              0.0f);
   }

-  const size_t vocab_size = model.Config().vocab_size;
-  const double gen_start = hwy::platform::Now();
-  for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-    // Decode generates one token per query and increments queries_mutable_pos.
-    Transformer(QueriesToken(gen_tokens.data(), num_queries),
-                queries_mutable_pos, queries_prefix_end, weights, activations,
-                div_seq_len, kv_caches, runtime_config.layers_output,
-                runtime_config.activations_observer);
-    // queries_pos are incremented by Transformer.
+  {
+    const size_t vocab_size = model.Config().vocab_size;
+    const double gen_start = hwy::platform::Now();
+    for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
+      // Decode generates one token per query and increments
+      // queries_mutable_pos.
+      Transformer(QueriesToken(gen_tokens.data(), num_queries),
+                  queries_mutable_pos, queries_prefix_end, weights, activations,
+                  div_seq_len, kv_caches, runtime_config.layers_output,
+                  runtime_config.activations_observer);
+      // queries_pos are incremented by Transformer.

-    bool all_queries_eos = true;
-    {
-      PROFILER_ZONE("Gen.EmbeddingMatmul");
-      // Compute logits from last layer activations.
-      MatMul(ConstMatFromBatch(num_queries, activations.x),
-             ConstMatFromWeights(weights.embedder_input_embedding),
-             /*add=*/nullptr, *activations.env,
-             RowPtrFromBatch(activations.logits));
-    }
-    PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
-      const TokenAndProb tp = sample_token(logits, vocab_size);
-      timing_info.NotifyGenerated(prefill_start, gen_start);
+      bool all_queries_eos = true;
+      {
+        PROFILER_ZONE("Gen.EmbeddingMatmul");
+        // Compute logits from last layer activations.
+        MatMul(ConstMatFromBatch(num_queries, activations.x),
+               ConstMatFromWeights(weights.embedder_input_embedding),
+               /*add=*/nullptr, *activations.env,
+               RowPtrFromBatch(activations.logits));
+      }
+      PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+      for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+        float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
+        MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
+                           vocab_size);
+        const TokenAndProb tp = sample_token(logits, vocab_size);
+        timing_info.NotifyGenerated(prefill_start, gen_start);

-      const bool is_eos =
-          token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], tp.token, tp.prob);
-      all_queries_eos &= is_eos;
-      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
-    }
-    if (all_queries_eos) break;
-  } // foreach token to generate
-
-  timing_info.NotifyGenerateDone(gen_start);
+        const bool is_eos =
+            token_streamer(query_idx_start + query_idx,
+                           queries_mutable_pos[query_idx], tp.token, tp.prob);
+        all_queries_eos &= is_eos;
+        gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
+      }
+      if (all_queries_eos) break;
+    } // foreach token to generate
+    timing_info.NotifyGenerateDone(gen_start);
+  }
 }

 template <typename T>
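For orientation, the loop being re-indented implements a standard batched decode step: each iteration runs the transformer once per query, multiplies the final activations by the embedding matrix to get logits, optionally soft-caps them, samples a token, streams it to the caller, and exits early once every query has produced EOS. Below is a minimal single-query sketch of that control flow, not gemma.cpp's API: transformer_step, soft_cap, and sample are hypothetical stand-ins, and the toy model, EOS id, and cap value are assumptions; the soft cap follows Gemma's cap * tanh(logit / cap) form.

// Minimal control-flow sketch of the decode loop above; all names here
// (transformer_step, soft_cap, sample) are illustrative stand-ins.
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

using Logits = std::vector<float>;

// Stand-in for Transformer plus the embedding MatMul: one decode step
// mapping the previous token to next-token logits. Toy model: prefers
// the next token id in the vocabulary.
static Logits transformer_step(int prev_token, size_t vocab_size) {
  Logits logits(vocab_size, 0.0f);
  logits[(prev_token + 1) % vocab_size] = 1.0f;
  return logits;
}

// Soft-capping in the spirit of MaybeLogitsSoftCap: cap * tanh(logit / cap),
// a no-op when the cap is unset (<= 0).
static void soft_cap(float cap, Logits& logits) {
  if (cap <= 0.0f) return;
  for (float& l : logits) l = cap * std::tanh(l / cap);
}

// Stand-in for sample_token: greedy argmax for simplicity.
static int sample(const Logits& logits) {
  int best = 0;
  for (size_t i = 1; i < logits.size(); ++i) {
    if (logits[i] > logits[best]) best = static_cast<int>(i);
  }
  return best;
}

int main() {
  const size_t vocab_size = 8;
  const int eos_id = 5;           // assumed EOS id for the toy vocab
  const float final_cap = 30.0f;  // assumed cap; real configs set this
  const size_t max_generated_tokens = 16;

  int token = 0;  // last token of the (toy) prompt
  for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
    Logits logits = transformer_step(token, vocab_size);
    soft_cap(final_cap, logits);
    token = sample(logits);
    std::printf("generated token %d\n", token);  // "stream" the token
    if (token == eos_id) break;  // all_queries_eos for a single query
  }
  return 0;
}

In the real code the same loop runs over num_queries at once, which is why it tracks all_queries_eos and substitutes runtime_config.eos_id for queries that have already finished.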