From c02654eb7d1fc2afe0dfdc17434bb2e3ec8efed4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Wed, 10 Dec 2025 15:54:33 +0200 Subject: [PATCH] graph : make the compute graph constant with respect to active samplers --- src/llama-context.cpp | 15 ++++++++--- src/llama-graph.cpp | 59 +++++++++++++++++++++++++------------------ src/llama-graph.h | 10 ++------ 3 files changed, 48 insertions(+), 36 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 73432e5d04..a4d332b114 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -1241,7 +1241,10 @@ static void copy_tensor_async_ints( for (const auto & [seq_id, tensor] : tensor_map) { auto it = seq_to_row.find(seq_id); - GGML_ASSERT(it != seq_to_row.end()); + if (it == seq_to_row.end()) { + continue; + } + const uint32_t row = it->second; GGML_ASSERT(row < sampled_size); @@ -1265,7 +1268,10 @@ static void copy_tensor_async_floats( for (const auto & [seq_id, tensor] : tensor_map) { auto it = seq_to_row.find(seq_id); - GGML_ASSERT(it != seq_to_row.end()); + if (it == seq_to_row.end()) { + continue; + } + const uint32_t row = it->second; GGML_ASSERT(row < counts.size()); @@ -1293,7 +1299,10 @@ static void copy_tensor_async_candidates( for (const auto & [seq_id, tensor] : tensor_map) { auto it = seq_to_row.find(seq_id); - GGML_ASSERT(it != seq_to_row.end()); + if (it == seq_to_row.end()) { + continue; + } + const uint32_t row = it->second; GGML_ASSERT(row < counts.size()); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index b1967f6395..0261b65bc4 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -12,6 +12,7 @@ #include #include #include +#include void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { @@ -466,8 +467,22 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) { } void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { - GGML_UNUSED(ubatch); - for (const auto & [seq_id, sampler] : samplers) { + // set the inputs only for the active samplers in the current ubatch + std::unordered_set active_samplers; + for (uint32_t i = 0; i < ubatch->n_tokens; i++) { + if (ubatch->output[i]) { + llama_seq_id seq_id = ubatch->seq_id[i][0]; + active_samplers.insert(seq_id); + } + } + + for (auto seq_id : active_samplers) { + if (samplers.find(seq_id) == samplers.end()) { + continue; + } + + auto & sampler = samplers[seq_id]; + if (sampler->iface->backend_set_input) { sampler->iface->backend_set_input(sampler); } @@ -475,11 +490,10 @@ void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { } bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) { - if (params.samplers.empty()) { - return true; + if (samplers.size() != params.samplers.size()) { + return false; } - // TODO: this check is incorrect - it has to check against the last set of samplers that were used in the previous graph for (const auto & [seq_id, sampler] : params.samplers) { if (samplers[seq_id] != sampler) { return false; @@ -1830,8 +1844,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask); + ggml_set_name(inp->self_kq_mask, "self_kq_mask"); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -1844,8 +1860,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); + ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); @@ -2084,6 +2102,9 @@ void llm_graph_context::build_sampling() const { return; } + auto inp_sampling = std::make_unique(samplers); + res->add_input(std::move(inp_sampling)); + std::unordered_map seq_to_logit_row; int32_t logit_row_idx = 0; @@ -2095,30 +2116,21 @@ void llm_graph_context::build_sampling() const { } } - if (seq_to_logit_row.empty()) { - return; - } - // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1) - ggml_tensor * logits_t = res->t_logits; GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor"); - const int64_t n_vocab = logits_t->ne[0]; - - std::unordered_map active_samplers; + // add a dummy row of logits + // this trick makes the graph static, regardless of which samplers are activated + // this is important in order to minimize graph reallocations + ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); for (const auto & [seq_id, sampler] : samplers) { - // Only process samplers for sequences that are in the current batch - auto it = seq_to_logit_row.find(seq_id); - if (it == seq_to_logit_row.end()) { - continue; - } + const auto it = seq_to_logit_row.find(seq_id); - active_samplers[seq_id] = sampler; + // inactive samplers alawys work on the first row + const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; - const int32_t row_idx = it->second; - - ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]); + ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); ggml_format_name(logits_seq, "logits_seq_%d", seq_id); struct llama_sampler_data data = { @@ -2163,9 +2175,6 @@ void llm_graph_context::build_sampling() const { } } */ - - auto inp_sampling = std::make_unique(n_vocab, false, active_samplers); - res->add_input(std::move(inp_sampling)); } int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { diff --git a/src/llama-graph.h b/src/llama-graph.h index 006cae3c84..490d4fb00c 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -385,19 +385,13 @@ public: class llm_graph_input_sampling : public llm_graph_input_i { public: - llm_graph_input_sampling(int32_t n_vocab, bool sorted, - std::unordered_map samplers) : - n_vocab(n_vocab), sorted_value(sorted), samplers(std::move(samplers)) { } + llm_graph_input_sampling(std::unordered_map samplers) : + samplers(std::move(samplers)) { } virtual ~llm_graph_input_sampling() = default; void set_input(const llama_ubatch * ubatch) override; bool can_reuse(const llm_graph_params & params) override; - int32_t n_vocab; - bool sorted_value; - ggml_tensor * size = nullptr; // I32 [1] - ggml_tensor * sorted = nullptr; // I32 [1] - std::unordered_map samplers; };