mirror of https://github.com/google/gemma.cpp.git
Factor out DecodeStepT from GenerateT into a separate function.
This will be useful for adding sampling functionality like beam decoding, parallel sampling, and CoT decoding (as described in the [Chain-of-Thought Reasoning Without Prompting paper](https://arxiv.org/abs/2402.10200)).

PiperOrigin-RevId: 725151530
This commit is contained in:
parent b0fe9a43e6
commit 9b3e7ea8a2
@@ -27,6 +27,7 @@
 #include "gemma/common.h"
 #include "gemma/configs.h"
 #include "gemma/gemma.h"
+#include "gemma/kv_cache.h"
 #include "gemma/weights.h"
 #include "paligemma/image.h"
 #include "util/allocator.h"
@@ -1217,6 +1218,54 @@ HWY_INLINE SampleFunc ChooseSampleFunc(const RuntimeConfig& runtime_config) {
   };
 }
 
+template <typename T>
+// Runs one decode step for all the queries in the batch. Returns true if all
+// queries are at <end_of_sentence>.
+bool DecodeStepT(const ModelWeightsPtrs<T>& weights,
+                 const RuntimeConfig& runtime_config,
+                 const QueriesPromptTokens& queries_prompt,
+                 const size_t query_idx_start, const KVCaches& kv_caches,
+                 const QueriesPos& queries_prefix_end,
+                 const hwy::Divisor div_seq_len, const size_t vocab_size,
+                 const SampleFunc& sample_token, double prefill_start,
+                 double gen_start, Activations& activations,
+                 TokenStreamer& token_streamer, std::vector<int>& gen_tokens,
+                 TimingInfo& timing_info,
+                 const QueriesMutablePos& queries_mutable_pos) {
+  const size_t num_queries = queries_prompt.size();
+  // Decode generates one token per query and increments
+  // queries_mutable_pos.
+  Transformer(QueriesToken(gen_tokens.data(), num_queries), queries_mutable_pos,
+              queries_prefix_end, weights, activations, div_seq_len, kv_caches,
+              runtime_config.layers_output,
+              runtime_config.activations_observer);
+  // queries_pos are incremented by Transformer.
+
+  bool all_queries_eos = true;
+  {
+    PROFILER_ZONE("Gen.EmbeddingMatmul");
+    // Compute logits from last layer activations.
+    MatMul(ConstMatFromBatch(num_queries, activations.x),
+           ConstMatFromWeights(weights.embedder_input_embedding),
+           /*add=*/nullptr, *activations.env,
+           RowPtrFromBatch(activations.logits));
+  }
+  PROFILER_ZONE("Gen.Softcap+Sample+Stream");
+  for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
+    float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
+    MaybeLogitsSoftCap(weights.weights_config.final_cap, logits, vocab_size);
+    const TokenAndProb tp = sample_token(logits, vocab_size);
+    timing_info.NotifyGenerated(prefill_start, gen_start);
+
+    const bool is_eos =
+        token_streamer(query_idx_start + query_idx,
+                       queries_mutable_pos[query_idx], tp.token, tp.prob);
+    all_queries_eos &= is_eos;
+    gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
+  }
+  return all_queries_eos;
+}
+
 // Generates one continuation for each query in `queries_prompt`, which is one
 // qbatch whose size is at most the `batch_size` passed to
 // `activations.Allocate`.
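The new `DecodeStepT` takes the sampling policy as a `SampleFunc` argument and invokes it once per query as `sample_token(logits, vocab_size)`, reading `tp.token` and `tp.prob` from the result. Below is a minimal, self-contained sketch of that contract using stand-in types; `TokenAndProb` and `SampleFunc` exist in gemma.cpp, but the definitions and the `GreedySample` helper here are illustrative assumptions, not the library's code.

```cpp
#include <cmath>
#include <cstddef>
#include <functional>

// Stand-in for gemma.cpp's TokenAndProb: a sampled token id plus its
// probability, matching the tp.token / tp.prob reads in DecodeStepT.
struct TokenAndProb {
  int token;
  float prob;
};

// Stand-in for gemma.cpp's SampleFunc: callable as (logits, vocab_size).
using SampleFunc = std::function<TokenAndProb(const float*, size_t)>;

// Hypothetical greedy sampler: picks the argmax logit and reports its
// softmax probability.
SampleFunc GreedySample() {
  return [](const float* logits, size_t vocab_size) -> TokenAndProb {
    size_t best = 0;
    for (size_t i = 1; i < vocab_size; ++i) {
      if (logits[i] > logits[best]) best = i;
    }
    // Softmax denominator, shifted by the max logit for numerical stability;
    // the argmax token's probability is then exp(0) / sum = 1 / sum.
    double sum = 0.0;
    for (size_t i = 0; i < vocab_size; ++i) {
      sum += std::exp(static_cast<double>(logits[i] - logits[best]));
    }
    return TokenAndProb{static_cast<int>(best), static_cast<float>(1.0 / sum)};
  };
}
```

Because the decode step is now separate from the generation loop, alternative policies (beam scoring, per-branch samplers for parallel sampling) can be swapped in at this single call site.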
@@ -1310,37 +1359,11 @@ void GenerateT(const ModelWeightsStorage& model, Activations& activations,
   const size_t vocab_size = model.Config().vocab_size;
   const double gen_start = hwy::platform::Now();
   for (size_t gen = 0; gen < max_generated_tokens; ++gen) {
-    // Decode generates one token per query and increments
-    // queries_mutable_pos.
-    Transformer(QueriesToken(gen_tokens.data(), num_queries),
-                queries_mutable_pos, queries_prefix_end, weights, activations,
-                div_seq_len, kv_caches, runtime_config.layers_output,
-                runtime_config.activations_observer);
-    // queries_pos are incremented by Transformer.
-
-    bool all_queries_eos = true;
-    {
-      PROFILER_ZONE("Gen.EmbeddingMatmul");
-      // Compute logits from last layer activations.
-      MatMul(ConstMatFromBatch(num_queries, activations.x),
-             ConstMatFromWeights(weights.embedder_input_embedding),
-             /*add=*/nullptr, *activations.env,
-             RowPtrFromBatch(activations.logits));
-    }
-    PROFILER_ZONE("Gen.Softcap+Sample+Stream");
-    for (size_t query_idx = 0; query_idx < num_queries; ++query_idx) {
-      float* HWY_RESTRICT logits = activations.logits.Batch(query_idx);
-      MaybeLogitsSoftCap(weights.weights_config.final_cap, logits,
-                         vocab_size);
-      const TokenAndProb tp = sample_token(logits, vocab_size);
-      timing_info.NotifyGenerated(prefill_start, gen_start);
-
-      const bool is_eos =
-          token_streamer(query_idx_start + query_idx,
-                         queries_mutable_pos[query_idx], tp.token, tp.prob);
-      all_queries_eos &= is_eos;
-      gen_tokens[query_idx] = is_eos ? runtime_config.eos_id : tp.token;
-    }
+    bool all_queries_eos = DecodeStepT<T>(
+        weights, runtime_config, queries_prompt, query_idx_start, kv_caches,
+        queries_prefix_end, div_seq_len, vocab_size, sample_token,
+        prefill_start, gen_start, activations, token_streamer, gen_tokens,
+        timing_info, queries_mutable_pos);
     if (all_queries_eos) break;
   }  // foreach token to generate
   timing_info.NotifyGenerateDone(gen_start);
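For CoT decoding as described in the linked paper, the key operation at the first decode step is branching on the top-k candidate tokens rather than taking only the argmax; each branch is then decoded greedily and ranked by the probability margin between its top-1 and top-2 tokens. A self-contained sketch of the branching step follows; it uses no gemma.cpp APIs, and `TopKFirstTokens` is a hypothetical helper operating on a raw logits vector.

```cpp
#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>

// Returns the indices of the k highest logits, best first. Each returned
// token id would seed one decoding branch; requires k <= logits.size().
std::vector<int> TopKFirstTokens(const std::vector<float>& logits, size_t k) {
  std::vector<int> idx(logits.size());
  std::iota(idx.begin(), idx.end(), 0);
  std::partial_sort(idx.begin(), idx.begin() + k, idx.end(),
                    [&](int a, int b) { return logits[a] > logits[b]; });
  idx.resize(k);
  return idx;
}
```

With `DecodeStepT` factored out, a driver could run one step to obtain first-token logits, seed k branches from `TopKFirstTokens`, and then advance each branch through repeated `DecodeStepT` calls with its own token buffer and KV cache.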