llama : reserve graphs with samplers

This commit is contained in:
Georgi Gerganov 2025-11-29 23:57:25 +02:00
parent 467746e3ad
commit 1760bd69b3
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
2 changed files with 12 additions and 3 deletions

View File

@ -1903,6 +1903,15 @@ ggml_cgraph * llama_context::graph_reserve(uint32_t n_tokens, uint32_t n_seqs, u
llama_batch_allocr balloc(model.hparams.n_pos_per_embd());
llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs);
// set one output token per sequence in order to activate all backend samplers
std::vector<llama_seq_id> seq_ids(n_seqs);
for (uint32_t i = 0; i < n_seqs; ++i) {
seq_ids[i] = i;
ubatch.n_seq_id[i] = 1;
ubatch.seq_id[i] = &seq_ids[i];
ubatch.output[i] = true;
}
auto * res = gf_res_reserve.get();
const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT);

View File

@ -2063,12 +2063,12 @@ void llm_graph_context::build_sampling() const {
logit_row_idx++;
}
}
if (seq_to_logit_row.empty()) {
return;
}
// res->t_logits will contain logits for all tokens that specied that want
// logits calculated (logits=1 or output=1)
// res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
ggml_tensor * logits_t = res->t_logits;
GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");