common : add common_speculative_is_compat() (#19270)

* llama : add llama_memory_can_rm_suffix()

* Revert "llama : add llama_memory_can_rm_suffix()"

This reverts commit d30e59b62a.

* spec : check if the target context is compatible for spec decoding
This commit is contained in:
Georgi Gerganov 2026-02-06 16:47:22 +02:00 committed by GitHub
parent 06bf3796f4
commit dfde5993ea
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 46 additions and 1 deletions

View File

@ -805,6 +805,42 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
return it->second;
}
// Probe whether the target context can be used for speculative decoding.
//
// The context memory is cleared, two placeholder tokens are decoded, and then
// the tail of sequence 0 (from position 1 onwards) is removed. The context is
// reported as compatible only if both the decode and the removal succeed.
//
// Returns false right away (without touching the context) when it exposes no
// memory object. In every other case the memory is cleared again and the
// context is synchronized before returning.
bool common_speculative_is_compat(llama_context * ctx_tgt) {
    auto * mem = llama_get_memory(ctx_tgt);
    if (!mem) {
        return false;
    }

    // start from a clean state so the probe does not depend on prior contents
    llama_memory_clear(mem, true);

    // two dummy tokens (id 0) are enough to later attempt a partial removal
    std::vector<llama_token> probe = { 0, 0 };

    bool compat = false;

    const int status = llama_decode(ctx_tgt, llama_batch_get_one(probe.data(), probe.size()));
    if (status != 0) {
        LOG_ERR("%s: llama_decode() failed: %d\n", __func__, status);
    } else if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
        // dropping everything after the first token was rejected by the memory
        LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
    } else {
        compat = true;
    }

    // common cleanup: leave the context empty and idle regardless of the outcome
    llama_memory_clear(mem, true);
    llama_synchronize(ctx_tgt);

    return compat;
}
// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(

View File

@ -14,6 +14,10 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
// check if the llama_context is compatible for speculative decoding
// returns false when the context has no memory, when a 2-token test decode fails,
// or when the memory does not support partial sequence removal; true otherwise
// note: clears the memory of the context
bool common_speculative_is_compat(llama_context * ctx_tgt);
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);

View File

@ -740,6 +740,11 @@ private:
slots.clear();
const bool can_spec = common_speculative_is_compat(ctx);
if (!can_spec) {
SRV_WRN("%s", "speculative decoding not supported by this context\n");
}
// initialize slots
for (int i = 0; i < params_base.n_parallel; i++) {
server_slot slot;
@ -752,7 +757,7 @@ private:
slot.prompt.tokens.has_mtmd = mctx != nullptr;
// try speculative decoding
{
if (can_spec) {
slot.spec = common_speculative_init(params_base.speculative, slot.ctx);
if (slot.spec) {
if (mctx) {