diff --git a/common/speculative.cpp b/common/speculative.cpp index c99b19dbfd..84d2556ceb 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -805,6 +805,42 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } +bool common_speculative_is_compat(llama_context * ctx_tgt) { + auto * mem = llama_get_memory(ctx_tgt); + if (mem == nullptr) { + return false; + } + + bool res = true; + + llama_memory_clear(mem, true); + + // eval 2 tokens to check if the context is compatible + std::vector tmp; + tmp.push_back(0); + tmp.push_back(0); + + int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size())); + if (ret != 0) { + LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); + res = false; + goto done; + } + + // try to remove the last tokens + if (!llama_memory_seq_rm(mem, 0, 1, -1)) { + LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); + res = false; + goto done; + } + +done: + llama_memory_clear(mem, true); + llama_synchronize(ctx_tgt); + + return res; +} + // initialization of the speculative decoding system // common_speculative * common_speculative_init( diff --git a/common/speculative.h b/common/speculative.h index 76fe6bb7bc..876cde3d18 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -14,6 +14,10 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); +// check if the llama_context is compatible for speculative decoding +// note: clears the memory of the context +bool common_speculative_is_compat(llama_context * ctx_tgt); + common_speculative * common_speculative_init( common_params_speculative & params, llama_context * ctx_tgt); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 7f9c3c566b..b71d496eeb 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -740,6 +740,11 @@ private: slots.clear(); + const bool can_spec = common_speculative_is_compat(ctx); + if (!can_spec) { + SRV_WRN("%s", "speculative decoding not supported by this context\n"); + } + // initialize slots for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; @@ -752,7 +757,7 @@ private: slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - { + if (can_spec) { slot.spec = common_speculative_init(params_base.speculative, slot.ctx); if (slot.spec) { if (mctx) {