graph : make the compute graph constant with respect to active samplers

Author: Georgi Gerganov
Date:   2025-12-10 15:54:33 +02:00
parent 0ecee8be37
commit c02654eb7d
3 changed files with 48 additions and 36 deletions
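
The gist of the change: build_sampling() now always builds a sampler chain for every registered sampler, instead of only for the sequences that produced output tokens in the current ubatch. Inactive samplers are pointed at a logits row that is guaranteed to exist, so the graph topology stays identical from ubatch to ubatch, which minimizes graph reallocations. A minimal sketch of the trick, under assumed names (build_all_sampler_views and the fixed sequence count are illustrative, not part of the commit):

    #include <unordered_map>

    #include "ggml.h"
    #include "llama.h"

    // sketch only: pad the logits with one dummy row and give every sampler a
    // row view, so the graph is the same no matter which samplers are active
    static void build_all_sampler_views(
            ggml_context * ctx,
            ggml_tensor  * t_logits, // F32 [n_vocab, n_outputs], n_outputs may be 0
            const std::unordered_map<llama_seq_id, int32_t> & seq_to_logit_row) {
        // one extra zero row guarantees that row 0 exists even when no token
        // in the ubatch requested logits
        ggml_tensor * logits = ggml_pad(ctx, t_logits, 0, 1, 0, 0);

        for (llama_seq_id seq_id = 0; seq_id < 4; ++seq_id) { // say, 4 sequences
            const auto it  = seq_to_logit_row.find(seq_id);
            const int  row = it != seq_to_logit_row.end() ? it->second : 0;

            // inactive samplers read row 0; their results are skipped on readback
            ggml_tensor * logits_seq = ggml_view_1d(ctx, logits, logits->ne[0], row*logits->nb[1]);
            ggml_format_name(logits_seq, "logits_seq_%d", seq_id);

            // ... attach the sampler chain for seq_id to logits_seq ...
        }
    }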

View File

@@ -1241,7 +1241,10 @@ static void copy_tensor_async_ints(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }

         const uint32_t row = it->second;
         GGML_ASSERT(row < sampled_size);
@@ -1265,7 +1268,10 @@ static void copy_tensor_async_floats(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }

         const uint32_t row = it->second;
         GGML_ASSERT(row < counts.size());
@@ -1293,7 +1299,10 @@ static void copy_tensor_async_candidates(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }

         const uint32_t row = it->second;
         GGML_ASSERT(row < counts.size());
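
These three helpers are the readback side of the change below: because the graph now always contains every registered sampler, tensor_map can hold entries for samplers that had no output row in the current ubatch, and the hard assert becomes a skip. A self-contained restatement of the pattern (copy_active_rows and the simplified types are assumptions, not the actual signatures):

    #include <cstdint>
    #include <unordered_map>

    #include "ggml.h"
    #include "llama.h"

    // skip samplers that are present in the graph but were inactive this ubatch
    static void copy_active_rows(
            const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
            const std::unordered_map<llama_seq_id, uint32_t> &      seq_to_row,
            uint32_t n_rows) {
        for (const auto & [seq_id, tensor] : tensor_map) {
            const auto it = seq_to_row.find(seq_id);
            if (it == seq_to_row.end()) {
                continue; // no output row for this sequence - nothing to copy back
            }

            const uint32_t row = it->second;
            GGML_ASSERT(row < n_rows);

            // ... issue the asynchronous copy for (tensor, row) here ...
            GGML_UNUSED(tensor);
        }
    }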

View File

@@ -12,6 +12,7 @@
 #include <cassert>
 #include <cmath>
 #include <cstring>
+#include <unordered_set>

 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
@@ -466,8 +467,22 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 }

 void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-    for (const auto & [seq_id, sampler] : samplers) {
+    // set the inputs only for the active samplers in the current ubatch
+    std::unordered_set<llama_seq_id> active_samplers;
+    for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
+        if (ubatch->output[i]) {
+            llama_seq_id seq_id = ubatch->seq_id[i][0];
+            active_samplers.insert(seq_id);
+        }
+    }
+
+    for (auto seq_id : active_samplers) {
+        if (samplers.find(seq_id) == samplers.end()) {
+            continue;
+        }
+
+        auto & sampler = samplers[seq_id];
+
         if (sampler->iface->backend_set_input) {
             sampler->iface->backend_set_input(sampler);
         }
@@ -475,11 +490,10 @@ void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
 }

 bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
-    if (params.samplers.empty()) {
-        return true;
+    if (samplers.size() != params.samplers.size()) {
+        return false;
     }

-    // TODO: this check is incorrect - it has to check against the last set of samplers that were used in the previous graph
     for (const auto & [seq_id, sampler] : params.samplers) {
         if (samplers[seq_id] != sampler) {
             return false;
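
In other words, the early return true for an empty sampler map (previously flagged by a TODO as incorrect) is gone: a cached graph is reusable only when the requested sampler map matches the one it was built with, element for element. A standalone restatement of the rule (can_reuse_samplers is a hypothetical helper):

    #include <unordered_map>

    #include "llama.h"

    // the graph can be reused only if the sampler map is unchanged: same size
    // and the same llama_sampler pointer for every sequence
    static bool can_reuse_samplers(
            std::unordered_map<llama_seq_id, llama_sampler *> &       cur,  // samplers the graph was built with
            const std::unordered_map<llama_seq_id, llama_sampler *> & req) { // samplers requested now
        if (cur.size() != req.size()) {
            return false;
        }

        for (const auto & [seq_id, sampler] : req) {
            // note: operator[] inserts nullptr for a seq_id cur has never
            // seen, so a brand-new sequence also fails the comparison
            if (cur[seq_id] != sampler) {
                return false;
            }
        }

        return true;
    }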
@@ -1830,8 +1844,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
+        ggml_set_name(inp->self_kq_mask, "self_kq_mask");

         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+        ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
     }

     {
@@ -1844,8 +1860,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
         inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
+        ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");

         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
     }

     return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
@@ -2084,6 +2102,9 @@ void llm_graph_context::build_sampling() const {
         return;
     }

+    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
+    res->add_input(std::move(inp_sampling));
+
     std::unordered_map<llama_seq_id, int32_t> seq_to_logit_row;
     int32_t logit_row_idx = 0;
@@ -2095,30 +2116,21 @@ void llm_graph_context::build_sampling() const {
         }
     }

-    if (seq_to_logit_row.empty()) {
-        return;
-    }
-
     // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
-    ggml_tensor * logits_t = res->t_logits;
     GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");

-    const int64_t n_vocab = logits_t->ne[0];
-
-    std::unordered_map<llama_seq_id, llama_sampler *> active_samplers;
+    // add a dummy row of logits
+    // this trick makes the graph static, regardless of which samplers are activated
+    // this is important in order to minimize graph reallocations
+    ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);

     for (const auto & [seq_id, sampler] : samplers) {
-        // Only process samplers for sequences that are in the current batch
-        auto it = seq_to_logit_row.find(seq_id);
-        if (it == seq_to_logit_row.end()) {
-            continue;
-        }
-
-        active_samplers[seq_id] = sampler;
-
-        const int32_t row_idx = it->second;
-        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]);
+        const auto it = seq_to_logit_row.find(seq_id);
+
+        // inactive samplers always work on the first row
+        const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
+
+        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
         ggml_format_name(logits_seq, "logits_seq_%d", seq_id);

         struct llama_sampler_data data = {
@@ -2163,9 +2175,6 @@ void llm_graph_context::build_sampling() const {
         }
     }
     */
-
-    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(n_vocab, false, active_samplers);
-    res->add_input(std::move(inp_sampling));
 }

 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {

View File

@@ -385,19 +385,13 @@ public:
 class llm_graph_input_sampling : public llm_graph_input_i {
 public:
-    llm_graph_input_sampling(int32_t n_vocab, bool sorted,
-            std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
-        n_vocab(n_vocab), sorted_value(sorted), samplers(std::move(samplers)) { }
+    llm_graph_input_sampling(std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
+        samplers(std::move(samplers)) { }

     virtual ~llm_graph_input_sampling() = default;

     void set_input(const llama_ubatch * ubatch) override;

     bool can_reuse(const llm_graph_params & params) override;

-    int32_t n_vocab;
-    bool sorted_value;
-
-    ggml_tensor * size   = nullptr; // I32 [1]
-    ggml_tensor * sorted = nullptr; // I32 [1]
-
     std::unordered_map<llama_seq_id, llama_sampler *> samplers;
 };