llama : enable static graph for multiple sampling outputs per sequence
This commit makes the computation graph static when backend samplers process multiple outputs per sequence. Previously, only active samplers — those with outputs in the current batch — were added to the graph. This could cause graph reallocations if different samplers became active or inactive across batches, even when the number of outputs remained constant.
This commit is contained in:
parent
1e8c02aa95
commit
765998f2d7
|
|
@ -1331,7 +1331,6 @@ static void copy_tensor_async_ints(
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<uint32_t> & rows = it->second;
|
const std::vector<uint32_t> & rows = it->second;
|
||||||
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
|
|
||||||
|
|
||||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||||
const uint32_t row = rows[i];
|
const uint32_t row = rows[i];
|
||||||
|
|
@ -1364,7 +1363,6 @@ static void copy_tensor_async_floats(
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<uint32_t> & rows = it->second;
|
const std::vector<uint32_t> & rows = it->second;
|
||||||
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
|
|
||||||
|
|
||||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||||
const uint32_t row = rows[i];
|
const uint32_t row = rows[i];
|
||||||
|
|
@ -1401,7 +1399,6 @@ static void copy_tensor_async_candidates(
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::vector<uint32_t> & rows = it->second;
|
const std::vector<uint32_t> & rows = it->second;
|
||||||
GGML_ASSERT(tensors.size() == rows.size() && "number of tensors must match number of output rows");
|
|
||||||
|
|
||||||
for (size_t i = 0; i < tensors.size(); ++i) {
|
for (size_t i = 0; i < tensors.size(); ++i) {
|
||||||
const uint32_t row = rows[i];
|
const uint32_t row = rows[i];
|
||||||
|
|
|
||||||
|
|
@ -2605,17 +2605,18 @@ void llm_graph_context::build_sampling() const {
|
||||||
|
|
||||||
for (const auto & [seq_id, sampler] : samplers) {
|
for (const auto & [seq_id, sampler] : samplers) {
|
||||||
const auto row_it = seq_to_logit_rows.find(seq_id);
|
const auto row_it = seq_to_logit_rows.find(seq_id);
|
||||||
|
const bool sampler_is_active = row_it != seq_to_logit_rows.end();
|
||||||
|
|
||||||
// row_it is now a sequence id to list of row ids
|
// Always build samplers for all possible outputs even if the sampler is
|
||||||
static const std::vector<int32_t> default_row = {0};
|
// not active (the sampler's sequence id is not in the current ubatch).
|
||||||
const std::vector<int32_t> & logit_rows = row_it != seq_to_logit_rows.end() ? row_it->second : default_row;
|
for (uint32_t i = 0; i < max_outputs; ++i) {
|
||||||
for (const int32_t row_idx : logit_rows) {
|
const bool real_output = sampler_is_active && i < row_it->second.size();
|
||||||
|
|
||||||
// inactive samplers always work on the first row
|
const int32_t row_idx = real_output ? row_it->second[i] : 0;
|
||||||
const int i_out = row_it != seq_to_logit_rows.end() ? 1 : 0;
|
const int i_out = real_output ? 1 : 0;
|
||||||
|
|
||||||
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
|
||||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
ggml_format_name(logits_seq, "logits_seq_%d_%d", seq_id, i);
|
||||||
|
|
||||||
struct llama_sampler_data data = {
|
struct llama_sampler_data data = {
|
||||||
/*.logits =*/ logits_seq,
|
/*.logits =*/ logits_seq,
|
||||||
|
|
@ -2628,25 +2629,33 @@ void llm_graph_context::build_sampling() const {
|
||||||
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
sampler->iface->backend_apply(sampler, ctx0, gf, &data);
|
||||||
|
|
||||||
if (data.sampled != nullptr) {
|
if (data.sampled != nullptr) {
|
||||||
res->t_sampled[seq_id].push_back(data.sampled);
|
if (real_output) {
|
||||||
|
res->t_sampled[seq_id].push_back(data.sampled);
|
||||||
|
}
|
||||||
outs[1] = data.sampled;
|
outs[1] = data.sampled;
|
||||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.probs != nullptr) {
|
if (data.probs != nullptr) {
|
||||||
res->t_sampled_probs[seq_id].push_back(data.probs);
|
if (real_output) {
|
||||||
|
res->t_sampled_probs[seq_id].push_back(data.probs);
|
||||||
|
}
|
||||||
outs[1] = data.probs;
|
outs[1] = data.probs;
|
||||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.logits != nullptr) {
|
if (data.logits != nullptr) {
|
||||||
res->t_sampled_logits[seq_id].push_back(data.logits);
|
if (real_output) {
|
||||||
|
res->t_sampled_logits[seq_id].push_back(data.logits);
|
||||||
|
}
|
||||||
outs[1] = data.logits;
|
outs[1] = data.logits;
|
||||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (data.candidates != nullptr) {
|
if (data.candidates != nullptr) {
|
||||||
res->t_candidates[seq_id].push_back(data.candidates);
|
if (real_output) {
|
||||||
|
res->t_candidates[seq_id].push_back(data.candidates);
|
||||||
|
}
|
||||||
outs[1] = data.candidates;
|
outs[1] = data.candidates;
|
||||||
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
ggml_build_forward_select(gf, outs.data(), outs.size(), i_out);
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue