llama : enable static graph for multiple sampling outputs per sequence

This commit makes the computation graph static when backend samplers
process multiple outputs per sequence.

Previously, only active samplers — those with outputs in the current
batch — were added to the graph. This could cause graph reallocations if
different samplers became active or inactive across batches, even when
the total number of outputs remained constant.
This commit is contained in:
Daniel Bevenius 2026-02-26 11:16:23 +01:00
parent 1e8c02aa95
commit 765998f2d7
No known key found for this signature in database
2 changed files with 20 additions and 14 deletions

View File

@ -1331,7 +1331,6 @@ static void copy_tensor_async_ints(
} }
const std::vector<uint32_t> & rows = it->second; const std::vector<uint32_t> & rows = it->second;
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
const uint32_t row = rows[i]; const uint32_t row = rows[i];
@ -1364,7 +1363,6 @@ static void copy_tensor_async_floats(
} }
const std::vector<uint32_t> & rows = it->second; const std::vector<uint32_t> & rows = it->second;
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
const uint32_t row = rows[i]; const uint32_t row = rows[i];
@ -1401,7 +1399,6 @@ static void copy_tensor_async_candidates(
} }
const std::vector<uint32_t> & rows = it->second; const std::vector<uint32_t> & rows = it->second;
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
for (size_t i = 0; i < tensors.size(); ++i) { for (size_t i = 0; i < tensors.size(); ++i) {
const uint32_t row = rows[i]; const uint32_t row = rows[i];

View File

@ -2605,17 +2605,18 @@ void llm_graph_context::build_sampling() const {
for (const auto & [seq_id, sampler] : samplers) { for (const auto & [seq_id, sampler] : samplers) {
const auto row_it = seq_to_logit_rows.find(seq_id); const auto row_it = seq_to_logit_rows.find(seq_id);
const bool sampler_is_active = row_it != seq_to_logit_rows.end();
// row_it is now a sequence id to list of row ids // Always build samplers for all possible outputs even if the sampler is
static const std::vector<int32_t> default_row = {0}; // not active (the sampler's sequence id is not in the current ubatch).
const std::vector<int32_t> & logit_rows = row_it != seq_to_logit_rows.end() ? row_it->second : default_row; for (uint32_t i = 0; i < max_outputs; ++i) {
for (const int32_t row_idx : logit_rows) { const bool real_output = sampler_is_active && i < row_it->second.size();
// inactive samplers always work on the first row const int32_t row_idx = real_output ? row_it->second[i] : 0;
const int i_out = row_it != seq_to_logit_rows.end() ? 1 : 0; const int i_out = real_output ? 1 : 0;
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
ggml_format_name(logits_seq, "logits_seq_%d", seq_id); ggml_format_name(logits_seq, "logits_seq_%d_%d", seq_id, i);
struct llama_sampler_data data = { struct llama_sampler_data data = {
/*.logits =*/ logits_seq, /*.logits =*/ logits_seq,
@ -2628,25 +2629,33 @@ void llm_graph_context::build_sampling() const {
sampler->iface->backend_apply(sampler, ctx0, gf, &data); sampler->iface->backend_apply(sampler, ctx0, gf, &data);
if (data.sampled != nullptr) { if (data.sampled != nullptr) {
res->t_sampled[seq_id].push_back(data.sampled); if (real_output) {
res->t_sampled[seq_id].push_back(data.sampled);
}
outs[1] = data.sampled; outs[1] = data.sampled;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
} }
if (data.probs != nullptr) { if (data.probs != nullptr) {
res->t_sampled_probs[seq_id].push_back(data.probs); if (real_output) {
res->t_sampled_probs[seq_id].push_back(data.probs);
}
outs[1] = data.probs; outs[1] = data.probs;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
} }
if (data.logits != nullptr) { if (data.logits != nullptr) {
res->t_sampled_logits[seq_id].push_back(data.logits); if (real_output) {
res->t_sampled_logits[seq_id].push_back(data.logits);
}
outs[1] = data.logits; outs[1] = data.logits;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
} }
if (data.candidates != nullptr) { if (data.candidates != nullptr) {
res->t_candidates[seq_id].push_back(data.candidates); if (real_output) {
res->t_candidates[seq_id].push_back(data.candidates);
}
outs[1] = data.candidates; outs[1] = data.candidates;
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out); ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
} }