From 0da7e7dcccfac8a75bf3f65ac54cb4ea6b200c56 Mon Sep 17 00:00:00 2001 From: Daniel Bevenius Date: Wed, 19 Nov 2025 06:59:03 +0100 Subject: [PATCH] sampling : remove version from sampler chain This commit removes the version field from the sampler chain and instead uses the sampler pointer itself for change detection. --- include/llama.h | 1 - src/llama-graph.cpp | 2 +- src/llama-graph.h | 10 +--------- src/llama-sampling.cpp | 7 ------- src/llama-sampling.h | 3 --- 5 files changed, 2 insertions(+), 21 deletions(-) diff --git a/include/llama.h b/include/llama.h index cbf23c7bcf..2bc41f36d8 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1235,7 +1235,6 @@ extern "C" { // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i); - LLAMA_API uint64_t llama_sampler_chain_get_version(const struct llama_sampler * chain); // available samplers: diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 561e629869..fbf4475aa4 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -477,7 +477,7 @@ bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) { } for (const auto & [seq_id, sampler] : params.samplers) { - if (sampler_versions[seq_id] != llama_sampler_chain_get_version(sampler)) { + if (samplers[seq_id] != sampler) { return false; } } diff --git a/src/llama-graph.h b/src/llama-graph.h index 552c3e724f..f7508046e4 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -387,13 +387,7 @@ class llm_graph_input_sampling : public llm_graph_input_i { public: llm_graph_input_sampling(int32_t n_vocab, bool sorted, std::unordered_map samplers) : - n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) { - - sampler_versions.reserve(samplers.size()); - for (const auto & [seq_id, sampler] : samplers) { - sampler_versions[seq_id] = llama_sampler_chain_get_version(sampler); - } - } + 
n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) { } virtual ~llm_graph_input_sampling() = default; void set_input(const llama_ubatch * ubatch) override; @@ -404,8 +398,6 @@ public: ggml_tensor * size = nullptr; // I32 [1] ggml_tensor * sorted = nullptr; // I32 [1] - // Track sampler chain version for reuse - std::unordered_map sampler_versions; std::unordered_map samplers; }; diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index d210b826c7..456e050201 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -639,7 +639,6 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { auto * p = (llama_sampler_chain *) chain->ctx; p->samplers.push_back(smpl); - p->version++; } struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { @@ -661,7 +660,6 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, auto * result = p->samplers[i]; p->samplers.erase(p->samplers.begin() + i); - p->version++; return result; } @@ -672,11 +670,6 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) { return p->samplers.size(); } -uint64_t llama_sampler_chain_get_version(const struct llama_sampler * chain) { - const auto * p = (const llama_sampler_chain *) chain->ctx; - return p->version; -} - // // samplers // diff --git a/src/llama-sampling.h b/src/llama-sampling.h index d92311f58a..759dd7dcb7 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -21,9 +21,6 @@ struct llama_sampler_chain { mutable int64_t t_sample_us; mutable int32_t n_sample; - - // simple version tracking for GPU sampling graph can_reuse - uint64_t version = 0; }; struct llama_sampler * llama_sampler_init_dry_testing(