graph : make the compute graph constant with respect to active samplers

commit c02654eb7d (parent 0ecee8be37)
Author: Georgi Gerganov
Date: 2025-12-10 15:54:33 +02:00
GPG Key ID: 449E073F9DC10735 (no known key found for this signature in database)
3 changed files with 48 additions and 36 deletions
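
In the previous code, build_sampling() added sampling nodes only for the samplers whose sequences had output rows in the current ubatch, so the graph topology changed whenever the set of active sequences changed. After this commit every configured sampler always contributes its nodes, and a dummy logits row keeps every per-sequence row view valid regardless of batch contents. Below is a minimal standalone sketch of that trick, assuming only the public ggml API (ggml_pad, ggml_view_1d); the vocab size, row count, and row index are invented for illustration and none of this is llama.cpp source:

    // sketch: pad the logits with one extra zero row so a row view always exists
    #include "ggml.h"

    int main(void) {
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        const int64_t n_vocab = 32000; // illustration value
        const int64_t n_rows  = 2;     // rows for tokens that requested output

        struct ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_vocab, n_rows);

        // pad dim 1 by one row of zeros: the padded tensor always has a usable row 0
        struct ggml_tensor * padded = ggml_pad(ctx, logits, 0, 1, 0, 0);

        // every sampler gets a 1d view of some row; inactive samplers can all
        // point at row 0, so the node set is independent of sampler activity
        const int32_t row_idx = 0; // an active sampler would use its real row index
        struct ggml_tensor * row = ggml_view_1d(ctx, padded, padded->ne[0], row_idx*padded->nb[1]);
        (void) row;

        ggml_free(ctx);
        return 0;
    }

Because the node set no longer depends on batch contents, the backend scheduler can keep reusing the previously allocated graph instead of reallocating it per ubatch, which is the "minimize graph reallocations" goal stated in the new comments.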

View File

@@ -1241,7 +1241,10 @@ static void copy_tensor_async_ints(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
         const uint32_t row = it->second;
         GGML_ASSERT(row < sampled_size);
@@ -1265,7 +1268,10 @@ static void copy_tensor_async_floats(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
         const uint32_t row = it->second;
         GGML_ASSERT(row < counts.size());
@@ -1293,7 +1299,10 @@ static void copy_tensor_async_candidates(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }
+
         const uint32_t row = it->second;
         GGML_ASSERT(row < counts.size());
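
With the graph now built for every configured sampler rather than only the active ones, a sequence can legitimately be absent from seq_to_row when it produced no output rows in the current ubatch, so all three async copy helpers skip such entries instead of asserting.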

View File

@@ -12,6 +12,7 @@
 #include <cassert>
 #include <cmath>
 #include <cstring>
+#include <unordered_set>
 
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
@@ -466,8 +467,22 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-    for (const auto & [seq_id, sampler] : samplers) {
+    // set the inputs only for the active samplers in the current ubatch
+    std::unordered_set<llama_seq_id> active_samplers;
+    for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
+        if (ubatch->output[i]) {
+            llama_seq_id seq_id = ubatch->seq_id[i][0];
+            active_samplers.insert(seq_id);
+        }
+    }
+
+    for (auto seq_id : active_samplers) {
+        if (samplers.find(seq_id) == samplers.end()) {
+            continue;
+        }
+
+        auto & sampler = samplers[seq_id];
+
         if (sampler->iface->backend_set_input) {
             sampler->iface->backend_set_input(sampler);
         }
@@ -475,11 +490,10 @@ void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
 }
 
 bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
-    if (params.samplers.empty()) {
-        return true;
+    if (samplers.size() != params.samplers.size()) {
+        return false;
     }
 
-    // TODO: this check is incorrect - it has to check against the last set of samplers that were used in the previous graph
     for (const auto & [seq_id, sampler] : params.samplers) {
         if (samplers[seq_id] != sampler) {
             return false;
@@ -1830,8 +1844,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
+        ggml_set_name(inp->self_kq_mask, "self_kq_mask");
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+        ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
     }
 
     {
@@ -1844,8 +1860,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
+        ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
     }
 
     return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
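
The added ggml_set_name calls give the iSWA attention masks explicit, stable names; presumably this keeps tensor naming consistent across graph builds and makes the now-static graph easier to inspect and debug.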
@@ -2084,6 +2102,9 @@ void llm_graph_context::build_sampling() const {
         return;
     }
 
+    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
+    res->add_input(std::move(inp_sampling));
+
     std::unordered_map<llama_seq_id, int32_t> seq_to_logit_row;
     int32_t logit_row_idx = 0;
@@ -2095,30 +2116,21 @@ void llm_graph_context::build_sampling() const {
         }
     }
 
-    if (seq_to_logit_row.empty()) {
-        return;
-    }
-
     // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
-    ggml_tensor * logits_t = res->t_logits;
+    GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
 
-    const int64_t n_vocab = logits_t->ne[0];
-
-    std::unordered_map<llama_seq_id, llama_sampler *> active_samplers;
+    // add a dummy row of logits
+    // this trick makes the graph static, regardless of which samplers are activated
+    // this is important in order to minimize graph reallocations
+    ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
 
     for (const auto & [seq_id, sampler] : samplers) {
-        // Only process samplers for sequences that are in the current batch
-        auto it = seq_to_logit_row.find(seq_id);
-        if (it == seq_to_logit_row.end()) {
-            continue;
-        }
-        active_samplers[seq_id] = sampler;
+        const auto it = seq_to_logit_row.find(seq_id);
 
-        const int32_t row_idx = it->second;
+        // inactive samplers always work on the first row
+        const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
 
-        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]);
+        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
         ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
 
         struct llama_sampler_data data = {
@@ -2163,9 +2175,6 @@ void llm_graph_context::build_sampling() const {
             }
         }
     */
-    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(n_vocab, false, active_samplers);
-    res->add_input(std::move(inp_sampling));
-
 }
 
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
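
For reference, the new activity rule in llm_graph_input_sampling::set_input() can be restated compactly: a sampler is active exactly when its sequence owns at least one output token in the current ubatch. The sketch below repeats that logic outside llama.cpp; ubatch_view is a made-up stand-in for llama_ubatch that keeps only the fields the diff touches:

    #include <cstdint>
    #include <unordered_set>

    using llama_seq_id = int32_t;

    // simplified stand-in for llama_ubatch (illustration only)
    struct ubatch_view {
        uint32_t        n_tokens;
        const int8_t *  output; // output[i] != 0 => token i wants logits/sampling
        llama_seq_id ** seq_id; // seq_id[i][0] is the primary sequence of token i
    };

    // collect the sequences whose samplers must receive inputs this ubatch
    std::unordered_set<llama_seq_id> collect_active_samplers(const ubatch_view & ub) {
        std::unordered_set<llama_seq_id> active;
        for (uint32_t i = 0; i < ub.n_tokens; i++) {
            if (ub.output[i]) {
                active.insert(ub.seq_id[i][0]);
            }
        }
        return active;
    }

Note that set_input() still tolerates active sequences without a registered sampler (the samplers.find() check), mirroring the skip-instead-of-assert change in the async copy helpers of the first file.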

View File

@@ -385,19 +385,13 @@ public:
 class llm_graph_input_sampling : public llm_graph_input_i {
 public:
-    llm_graph_input_sampling(int32_t n_vocab, bool sorted,
-            std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
-        n_vocab(n_vocab), sorted_value(sorted), samplers(std::move(samplers)) { }
+    llm_graph_input_sampling(std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
+        samplers(std::move(samplers)) { }
     virtual ~llm_graph_input_sampling() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
     bool can_reuse(const llm_graph_params & params) override;
 
-    int32_t n_vocab;
-    bool sorted_value;
-
-    ggml_tensor * size   = nullptr; // I32 [1]
-    ggml_tensor * sorted = nullptr; // I32 [1]
-
     std::unordered_map<llama_seq_id, llama_sampler *> samplers;
 };
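
The header change matches the new build flow: the constructor drops n_vocab and sorted, the per-graph size/sorted input tensors are gone, and the input object carries only the sequence-to-sampler map, with per-ubatch activity resolved inside set_input().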