diff --git a/common/arg.cpp b/common/arg.cpp index 44dbf84e61..538d2a4b0a 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -3538,20 +3538,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.ngram_min_hits = value; } ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(common_arg( - {"--spec-use-checkpoints"}, "[on|off|auto]", - string_format("use checkpoints to rewind token history in recurrent models ('on', 'off', or 'auto', default: %s)", - params.speculative.use_checkpoints ? "on" : "off"), - [](common_params & params, const std::string & value) { - if (is_truthy(value) || is_autoy(value)) { - params.speculative.use_checkpoints = true; - } else if (is_falsey(value)) { - params.speculative.use_checkpoints = false; - } else { - throw std::invalid_argument("invalid value for --spec-use-checkpoints"); - } - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-ctkd", "--cache-type-k-draft"}, "TYPE", string_format( diff --git a/common/common.h b/common/common.h index 8619394097..70203bc8f8 100644 --- a/common/common.h +++ b/common/common.h @@ -326,7 +326,6 @@ struct common_params_speculative { uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed bool use_checkpoints = false; // use checkpoints to rewind in token history of recurrent models - std::shared_ptr ngram_mod; std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT diff --git a/common/speculative.cpp b/common/speculative.cpp index a863690a64..124c9c9120 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -910,13 +910,13 @@ enum common_speculative_type common_speculative_type_from_name(const std::string return it->second; } -bool common_speculative_is_compat(llama_context * ctx_tgt) { +common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt) { auto * mem = llama_get_memory(ctx_tgt); if (mem == nullptr) { - return false; + return COMMON_SPECULATIVE_COMPAT_TYPE_NO; } - bool res = true; + common_speculative_compat_type res = COMMON_SPECULATIVE_COMPAT_TYPE_FULL; llama_memory_clear(mem, true); @@ -928,14 +928,14 @@ bool common_speculative_is_compat(llama_context * ctx_tgt) { int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size())); if (ret != 0) { LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret); - res = false; + res = COMMON_SPECULATIVE_COMPAT_TYPE_NO; goto done; } // try to remove the last tokens if (!llama_memory_seq_rm(mem, 0, 1, -1)) { LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__); - res = false; + res = COMMON_SPECULATIVE_COMPAT_TYPE_CKPT; goto done; } @@ -1262,7 +1262,6 @@ struct common_speculative_session::impl { if (draft.empty()) { // switch to non-draft inference LOG_DBG("%s: draft of length 0 after denied checkpoint\n", __func__); - clear_draft(); return draft; } // we use the shortened draft of previous speculation diff --git a/common/speculative.h b/common/speculative.h index 7fb9c46a04..3d40110c45 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -23,9 +23,15 @@ enum common_speculative_type common_speculative_type_from_name(const std::string // convert type to string std::string common_speculative_type_to_str(enum common_speculative_type type); +enum common_speculative_compat_type { + COMMON_SPECULATIVE_COMPAT_TYPE_NO = 0, + COMMON_SPECULATIVE_COMPAT_TYPE_FULL = 1, + COMMON_SPECULATIVE_COMPAT_TYPE_CKPT = 2, +}; + // check if the llama_context is compatible for speculative decoding // note: clears the memory of the context -bool common_speculative_is_compat(llama_context * ctx_tgt); +common_speculative_compat_type common_speculative_is_compat(llama_context * ctx_tgt); common_speculative * common_speculative_init( common_params_speculative & params, diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 45f69cb4bf..f0485f8e84 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -860,9 +860,14 @@ private: slots.clear(); - const bool can_spec = common_speculative_is_compat(ctx); - if (!can_spec) { - SRV_WRN("%s", "speculative decoding not supported by this context without checkpoints\n"); + const auto spec_type = common_speculative_is_compat(ctx); + if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_NO) { + SRV_WRN("%s", "speculative decoding not supported by this context\n"); + } + + if (spec_type == COMMON_SPECULATIVE_COMPAT_TYPE_CKPT) { + SRV_WRN("%s", "speculative decoding will use checkpoints\n"); + params_base.speculative.use_checkpoints = true; } // initialize slots @@ -881,7 +886,7 @@ private: slot.prompt.tokens.has_mtmd = mctx != nullptr; // try speculative decoding - if (can_spec || params_base.speculative.use_checkpoints) { + if (spec_type != COMMON_SPECULATIVE_COMPAT_TYPE_NO) { if (mctx) { SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); return false;