llama : reserve graphs with samplers

This commit is contained in:
Georgi Gerganov 2025-11-29 23:57:25 +02:00
parent 467746e3ad
commit 1760bd69b3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 12 additions and 3 deletions

View File

@ -1903,6 +1903,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
// set one output token per sequence in order to activate all backend samplers
std::vector<llama_seq_id> seq_ids(n_seqs);
for (uint32_t i = 0; i < n_seqs; ++i) {
seq_ids[i] = i;
ubatch.n_seq_id[i] = 1;
ubatch.seq_id[i] = &seq_ids[i];
ubatch.output[i] = true;
}
auto * res = gf_res_reserve.get(); auto * res = gf_res_reserve.get();
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);

View File

@ -2063,18 +2063,18 @@ void llm_graph_context::build_sampling() const {
logit_row_idx++; logit_row_idx++;
} }
} }
if (seq_to_logit_row.empty()) { if (seq_to_logit_row.empty()) {
return; return;
} }
// res->t_logits will contain logits for all tokens that specied that want // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
// logits calculated (logits=1 or output=1)
ggml_tensor * logits_t = res->t_logits; ggml_tensor * logits_t = res->t_logits;
GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor"); GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
const int64_t n_vocab = logits_t->ne[0]; const int64_t n_vocab = logits_t->ne[0];
std::unordered_map<llama_seq_id, llama_sampler*> active_samplers; std::unordered_map<llama_seq_id, llama_sampler *> active_samplers;
for (const auto & [seq_id, sampler] : samplers) { for (const auto & [seq_id, sampler] : samplers) {
// Only process samplers for sequences that are in the current batch // Only process samplers for sequences that are in the current batch