sampling : remove version from sampler chain
This commit removes the version field from the sampler chain and instead uses the sampler pointer itself for change detection.
This commit is contained in:
parent
26be108be8
commit
0da7e7dccc
|
|
@ -1235,7 +1235,6 @@ extern "C" {
|
||||||
|
|
||||||
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
|
// after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed
|
||||||
LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i);
|
LLAMA_API struct llama_sampler * llama_sampler_chain_remove( struct llama_sampler * chain, int32_t i);
|
||||||
LLAMA_API uint64_t llama_sampler_chain_get_version(const struct llama_sampler * chain);
|
|
||||||
|
|
||||||
// available samplers:
|
// available samplers:
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -477,7 +477,7 @@ bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
for (const auto & [seq_id, sampler] : params.samplers) {
|
for (const auto & [seq_id, sampler] : params.samplers) {
|
||||||
if (sampler_versions[seq_id] != llama_sampler_chain_get_version(sampler)) {
|
if (samplers[seq_id] != sampler) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -387,13 +387,7 @@ class llm_graph_input_sampling : public llm_graph_input_i {
|
||||||
public:
|
public:
|
||||||
llm_graph_input_sampling(int32_t n_vocab, bool sorted,
|
llm_graph_input_sampling(int32_t n_vocab, bool sorted,
|
||||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers) :
|
std::unordered_map<llama_seq_id, llama_sampler*> samplers) :
|
||||||
n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) {
|
n_vocab(n_vocab), sorted_value(sorted), samplers(samplers) { }
|
||||||
|
|
||||||
sampler_versions.reserve(samplers.size());
|
|
||||||
for (const auto & [seq_id, sampler] : samplers) {
|
|
||||||
sampler_versions[seq_id] = llama_sampler_chain_get_version(sampler);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
virtual ~llm_graph_input_sampling() = default;
|
virtual ~llm_graph_input_sampling() = default;
|
||||||
|
|
||||||
void set_input(const llama_ubatch * ubatch) override;
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
|
|
@ -404,8 +398,6 @@ public:
|
||||||
ggml_tensor * size = nullptr; // I32 [1]
|
ggml_tensor * size = nullptr; // I32 [1]
|
||||||
ggml_tensor * sorted = nullptr; // I32 [1]
|
ggml_tensor * sorted = nullptr; // I32 [1]
|
||||||
|
|
||||||
// Track sampler chain version for reuse
|
|
||||||
std::unordered_map<llama_seq_id, uint64_t> sampler_versions;
|
|
||||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -639,7 +639,6 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
|
||||||
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
||||||
auto * p = (llama_sampler_chain *) chain->ctx;
|
auto * p = (llama_sampler_chain *) chain->ctx;
|
||||||
p->samplers.push_back(smpl);
|
p->samplers.push_back(smpl);
|
||||||
p->version++;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
||||||
|
|
@ -661,7 +660,6 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
|
||||||
|
|
||||||
auto * result = p->samplers[i];
|
auto * result = p->samplers[i];
|
||||||
p->samplers.erase(p->samplers.begin() + i);
|
p->samplers.erase(p->samplers.begin() + i);
|
||||||
p->version++;
|
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
@ -672,11 +670,6 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) {
|
||||||
return p->samplers.size();
|
return p->samplers.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
uint64_t llama_sampler_chain_get_version(const struct llama_sampler * chain) {
|
|
||||||
const auto * p = (const llama_sampler_chain *) chain->ctx;
|
|
||||||
return p->version;
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// samplers
|
// samplers
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -21,9 +21,6 @@ struct llama_sampler_chain {
|
||||||
mutable int64_t t_sample_us;
|
mutable int64_t t_sample_us;
|
||||||
|
|
||||||
mutable int32_t n_sample;
|
mutable int32_t n_sample;
|
||||||
|
|
||||||
// simple version tracking for GPU sampling graph can_reuse
|
|
||||||
uint64_t version = 0;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_init_dry_testing(
|
struct llama_sampler * llama_sampler_init_dry_testing(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue