From 72d3b1898a9c81152710cc37dd1dfd26764055d9 Mon Sep 17 00:00:00 2001 From: Sascha Rogmann <59577610+srogmann@users.noreply.github.com> Date: Wed, 28 Jan 2026 18:42:42 +0100 Subject: [PATCH] =?UTF-8?q?spec=20:=20add=20self=E2=80=91speculative=20dec?= =?UTF-8?q?oding=20(no=20draft=20model=20required)=20+=20refactor=20(#1847?= =?UTF-8?q?1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * server: introduce self-speculative decoding * server: moved self-call into speculative.cpp * can_speculate() includes self-speculation Co-authored-by: Georgi Gerganov * server: can_speculate() tests self-spec * server: replace can_speculate() with slot.can_speculate() Co-authored-by: Sigbjørn Skjæret * common: use %zu format specifier for size_t in logging Co-authored-by: Sigbjørn Skjæret * server: can_speculate() requires a task instance * common: ngram map, config self-speculative decoding * common: add enum common_speculative_type * common: add vector of speculative states * common: add option --spec-draftless * server: cleanup (remove slot.batch_spec, rename) * common: moved self-spec impl to ngram-map * common: cleanup (use common_speculative_state_draft) * spec : refactor * cont : naming * spec: remove --spec-config * doc: (draftless) speculative decoding * common: print performance in spec decoding * minor : cleanup * common : better names * minor : cleanup + fix build * minor: comments * CODEOWNERS: add common/ngram-map.* (#18471) * common : rename speculative.draftless_type -> speculative.type * ngram-map : fix uninitialized values * ngram-map : take into account the input can become shorter * ngram-map : revert len check for now * arg : change `--spec-draftless` -> `--spec-type` * spec : add common_speculative_state::accept() * spec : refactor + add common_speculative_begin() * spec : fix begin() call with mtmd * spec : additional refactor + remove common_speculative_params --------- Co-authored-by: Georgi Gerganov Co-authored-by: Sigbjørn Skjæret --- CODEOWNERS | 1 + common/CMakeLists.txt | 2 + common/arg.cpp | 87 +- common/common.cpp | 9 +- common/common.h | 64 +- common/ngram-cache.cpp | 7 +- common/ngram-cache.h | 4 +- common/ngram-map.cpp | 367 ++++++ common/ngram-map.h | 105 ++ common/speculative.cpp | 1012 +++++++++++++---- common/speculative.h | 44 +- docs/speculative.md | 120 ++ examples/lookup/lookup-create.cpp | 4 +- examples/lookup/lookup-stats.cpp | 10 +- examples/lookup/lookup.cpp | 12 +- .../speculative-simple/speculative-simple.cpp | 72 +- examples/speculative/speculative.cpp | 4 +- tools/server/server-context.cpp | 146 +-- tools/server/server-task.cpp | 23 + 19 files changed, 1649 insertions(+), 444 deletions(-) create mode 100644 common/ngram-map.cpp create mode 100644 common/ngram-map.h create mode 100644 docs/speculative.md diff --git a/CODEOWNERS b/CODEOWNERS index 6086abb564..e573a3d2e6 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -18,6 +18,7 @@ /common/jinja/ @ngxson @CISC @aldehir /common/llguidance.* @ggerganov /common/log.* @ggerganov +/common/ngram-map.* @srogmann /common/peg-parser.* @aldehir /common/sampling.* @ggerganov /common/speculative.* @ggerganov diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index ae02c0bd77..3bc7bc6210 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -73,6 +73,8 @@ add_library(${TARGET} STATIC log.h ngram-cache.cpp ngram-cache.h + ngram-map.cpp + ngram-map.h peg-parser.cpp peg-parser.h preset.cpp diff --git a/common/arg.cpp b/common/arg.cpp index cd3a1b6397..a685c418bf 
100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -6,6 +6,7 @@ #include "json-schema-to-grammar.h" #include "log.h" #include "sampling.h" +#include "speculative.h" #include "preset.h" // fix problem with std::min and std::max @@ -579,14 +580,14 @@ static bool common_params_parse_ex(int argc, char ** argv, common_params_context params.mmproj = res.mmproj; } // only download mmproj if the current example is using it - for (auto & ex : mmproj_examples) { + for (const auto & ex : mmproj_examples) { if (ctx_arg.ex == ex) { common_params_handle_model(params.mmproj, params.hf_token, params.offline); break; } } - common_params_handle_model(params.speculative.model, params.hf_token, params.offline); - common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); + common_params_handle_model(params.speculative.mparams_dft, params.hf_token, params.offline); + common_params_handle_model(params.vocoder.model, params.hf_token, params.offline); } // model is required (except for server) @@ -1216,16 +1217,16 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-lcs", "--lookup-cache-static"}, "FNAME", "path to static lookup cache to use for lookup decoding (not updated by generation)", [](common_params & params, const std::string & value) { - params.lookup_cache_static = value; + params.speculative.lookup_cache_static = value; } - ).set_examples({LLAMA_EXAMPLE_LOOKUP})); + ).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-lcd", "--lookup-cache-dynamic"}, "FNAME", "path to dynamic lookup cache to use for lookup decoding (updated by generation)", [](common_params & params, const std::string & value) { - params.lookup_cache_dynamic = value; + params.speculative.lookup_cache_dynamic = value; } - ).set_examples({LLAMA_EXAMPLE_LOOKUP})); + ).set_examples({LLAMA_EXAMPLE_LOOKUP, LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-c", "--ctx-size"}, "N", string_format("size of the prompt context (default: %d, 0 = loaded from model)", params.n_ctx), @@ -2563,7 +2564,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-hfd", "-hfrd", "--hf-repo-draft"}, "/[:quant]", "Same as --hf-repo, but for the draft model (default: unused)", [](common_params & params, const std::string & value) { - params.speculative.model.hf_repo = value; + params.speculative.mparams_dft.hf_repo = value; } ).set_env("LLAMA_ARG_HFD_REPO")); add_opt(common_arg( @@ -3384,7 +3385,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"-md", "--model-draft"}, "FNAME", "draft model for speculative decoding (default: unused)", [](common_params & params, const std::string & value) { - params.speculative.model.path = value; + params.speculative.mparams_dft.path = value; } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}).set_env("LLAMA_ARG_MODEL_DRAFT")); add_opt(common_arg( @@ -3394,6 +3395,66 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.replacements.push_back({ tgt, dft }); } ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI})); + add_opt(common_arg( + {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]", + string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n", + common_speculative_type_to_str(params.speculative.type).c_str()), + [](common_params & params, const std::string & value) { + if (value == 
"none") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NONE; + } else if (value == "ngram-cache") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_CACHE; + } else if (value == "ngram-simple") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE; + } else if (value == "ngram-map-k") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K; + } else if (value == "ngram-map-k4v") { + params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V; + } else { + throw std::invalid_argument("unknown speculative decoding type without draft model"); + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-ngram-size-n"}, "N", + string_format("ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: %d)", params.speculative.ngram_size_n), + [](common_params & params, int value) { + if (value < 1 || value > 1024) { + throw std::invalid_argument("ngram size N must be between 1 and 1024 inclusive"); + } + params.speculative.ngram_size_n = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-ngram-size-m"}, "N", + string_format("ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: %d)", params.speculative.ngram_size_m), + [](common_params & params, int value) { + if (value < 1 || value > 1024) { + throw std::invalid_argument("ngram size M must be between 1 and 1024 inclusive"); + } + params.speculative.ngram_size_m = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-ngram-check-rate"}, "N", + string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate), + [](common_params & params, int value) { + if (value < 1) { + throw std::invalid_argument("ngram check rate must be at least 1"); + } + params.speculative.ngram_check_rate = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); + add_opt(common_arg( + {"--spec-ngram-min-hits"}, "N", + string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits), + [](common_params & params, int value) { + if (value < 1) { + throw std::invalid_argument("ngram min hits must be at least 1"); + } + params.speculative.ngram_min_hits = value; + } + ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"-ctkd", "--cache-type-k-draft"}, "TYPE", string_format( @@ -3620,8 +3681,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; - params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; - params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.mparams_dft.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.port = 8012; params.n_ubatch = 1024; params.n_batch = 1024; @@ -3636,8 +3697,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.model.hf_repo = "ggml-org/Qwen2.5-Coder-14B-Q8_0-GGUF"; params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; - params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; - params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; + params.speculative.mparams_dft.hf_repo = 
"ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; + params.speculative.mparams_dft.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.port = 8012; params.n_ubatch = 1024; params.n_batch = 1024; diff --git a/common/common.cpp b/common/common.cpp index 26250abb6c..3aa396127c 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1097,7 +1097,10 @@ common_init_result::common_init_result(common_params & params) : if (params.fit_params) { LOG_INF("%s: fitting params to device memory, for bugs during this step try to reproduce them with -fit off, or provide --verbose logs if the bug only occurs with -fit on\n", __func__); llama_params_fit(params.model.path.c_str(), &mparams, &cparams, - params.tensor_split, params.tensor_buft_overrides.data(), params.fit_params_target.data(), params.fit_params_min_ctx, + params.tensor_split, + params.tensor_buft_overrides.data(), + params.fit_params_target.data(), + params.fit_params_min_ctx, params.verbosity >= 4 ? GGML_LOG_LEVEL_DEBUG : GGML_LOG_LEVEL_ERROR); } @@ -1208,10 +1211,6 @@ std::vector & common_init_result::lora() { return pimpl->lora; } -void common_init_result::free_context() { - pimpl->context.reset(); -} - common_init_result_ptr common_init_from_params(common_params & params) { common_init_result_ptr res(new common_init_result(params)); diff --git a/common/common.h b/common/common.h index 21c11f457d..fd3ab8cd18 100644 --- a/common/common.h +++ b/common/common.h @@ -164,6 +164,16 @@ enum common_params_sampling_config : uint64_t { COMMON_PARAMS_SAMPLING_CONFIG_MIROSTAT_ETA = 1 << 11, }; +enum common_speculative_type { + COMMON_SPECULATIVE_TYPE_NONE, // no speculative decoding + COMMON_SPECULATIVE_TYPE_DRAFT, // draft model + COMMON_SPECULATIVE_TYPE_EAGLE3, // eagle draft model + COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, // simple self-speculative decoding + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, // self-speculative decoding with n-gram keys only + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values + COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, // self-speculative decoding with 3-level n-gram cache + COMMON_SPECULATIVE_TYPE_COUNT // number of types, unknown type +}; // sampling parameters struct common_params_sampling { @@ -243,16 +253,35 @@ struct common_params_model { }; struct common_params_speculative { - std::vector devices; // devices to use for offloading + common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding - int32_t n_ctx = 0; // draft context size - int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding - int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding - int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) - float p_split = 0.1f; // speculative decoding split probability - float p_min = 0.75f; // minimum speculative decoding probability (greedy) - std::vector> replacements; // main to speculative model replacements - std::vector tensor_buft_overrides; + // general-purpose speculative decoding parameters + + int32_t n_max = 16; // maximum number of tokens to draft during speculative decoding + int32_t n_min = 0; // minimum number of draft tokens to use for speculative decoding + float p_split = 0.1f; // speculative decoding split probability + float p_min = 0.75f; // minimum speculative decoding probability (greedy) + + // ngram-based speculative decoding + + uint16_t ngram_size_n = 12; // ngram size for lookup + uint16_t ngram_size_m = 48; // mgram 
size for speculative tokens + uint16_t ngram_check_rate = 1; // check rate for ngram lookup + uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed + + std::string lookup_cache_static; // path of static ngram cache file for lookup decoding // NOLINT + std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT + + // draft-model speculative decoding + + struct common_params_model mparams_dft; + + llama_model * model_dft = nullptr; // a llama_model that can be shared by multiple speculative contexts + + llama_context_params cparams_dft; // these are the parameters for the draft llama_context + + int32_t n_ctx = 0; // draft context size + int32_t n_gpu_layers = -1; // number of layers to store in VRAM for the draft model (-1 - use default) ggml_type cache_type_k = GGML_TYPE_F16; // KV cache data type for the K ggml_type cache_type_v = GGML_TYPE_F16; // KV cache data type for the V @@ -260,7 +289,14 @@ struct common_params_speculative { struct cpu_params cpuparams; struct cpu_params cpuparams_batch; - struct common_params_model model; + std::vector devices; // devices to use for offloading + + std::vector> replacements; // main to speculative model replacements + std::vector tensor_buft_overrides; + + bool has_dft() const { + return !mparams_dft.path.empty() || !mparams_dft.hf_repo.empty(); + } }; struct common_params_vocoder { @@ -378,8 +414,6 @@ struct common_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT std::string input_prefix = ""; // string to prefix user inputs with // NOLINT std::string input_suffix = ""; // string to suffix user inputs with // NOLINT - std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT - std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT // llama-debug specific options @@ -575,10 +609,6 @@ struct common_params { // return false from callback to abort model loading or true to continue llama_progress_callback load_progress_callback = NULL; void * load_progress_callback_user_data = NULL; - - bool has_speculative() const { - return !speculative.model.path.empty() || !speculative.model.hf_repo.empty(); - } }; // call once at the start of a program if it uses libcommon @@ -714,8 +744,6 @@ struct common_init_result { std::vector & lora(); - void free_context(); - private: struct impl; std::unique_ptr pimpl; diff --git a/common/ngram-cache.cpp b/common/ngram-cache.cpp index d1a4d84c40..dce54b3647 100644 --- a/common/ngram-cache.cpp +++ b/common/ngram-cache.cpp @@ -192,12 +192,12 @@ void common_ngram_cache_draft( break; } - LOG(" - draft candidate: token=%d\n", drafted_token); + LOG_DBG(" - draft candidate: token=%d\n", drafted_token); draft.push_back(drafted_token); } } -void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename) { +void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename) { std::ofstream file_out(filename, std::ios::binary); for (std::pair item : ngram_cache) { const common_ngram ngram = item.first; @@ -217,10 +217,9 @@ void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & fil file_out.write(reinterpret_cast(&count), sizeof(int32_t)); } } - } -common_ngram_cache common_ngram_cache_load(std::string & filename) { +common_ngram_cache common_ngram_cache_load(const 
std::string & filename) { std::ifstream hashmap_file(filename, std::ios::binary); if (!hashmap_file) { throw std::ifstream::failure("Unable to open file " + filename); diff --git a/common/ngram-cache.h b/common/ngram-cache.h index dfe012abe4..6e7cfea966 100644 --- a/common/ngram-cache.h +++ b/common/ngram-cache.h @@ -88,12 +88,12 @@ void common_ngram_cache_draft( // Save an ngram cache to a file. // ngram_cache: the ngram cache to save. // filename: the path under which to save the ngram cache. -void common_ngram_cache_save(common_ngram_cache & ngram_cache, std::string & filename); +void common_ngram_cache_save(common_ngram_cache & ngram_cache, const std::string & filename); // Load an ngram cache saved with common_ngram_cache_save. // filename: the path from which to load the ngram cache. // returns: an ngram cache containing the information saved to filename. -common_ngram_cache common_ngram_cache_load(std::string & filename); +common_ngram_cache common_ngram_cache_load(const std::string & filename); // Merge two ngram caches. // ngram_cache_target: the ngram cache to which to add the information from ngram_cache_add. diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp new file mode 100644 index 0000000000..930e7a3c10 --- /dev/null +++ b/common/ngram-map.cpp @@ -0,0 +1,367 @@ +#include "common.h" +#include "log.h" +#include "ngram-map.h" + +#include +#include +#include +#include + +// n-gram simple +// + +/** + * Perform speculative generation using the model's own token history. + * Searches for a matching pattern in the token history and returns draft tokens. + * + * @param state Current state of this implementation + * @param tokens Token history to search in + * @param sampled Last sampled token + * @return Vector of draft tokens, empty if no matching pattern is found + */ +llama_tokens common_ngram_simple_draft( + common_ngram_simple_state & state, + const llama_tokens & tokens, llama_token sampled) { + + // Simple implementation of self-speculative decoding without a draft model. + // + const size_t cur_len = tokens.size(); + // Only check every check_rate tokens to save compute + // i.e., perform check if (cur_len - idx_last_check) >= check_rate + if (state.idx_last_check + state.config.check_rate > cur_len) { + llama_tokens draft_tokens; + return draft_tokens; + } + + size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history + size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft + + // vector for tokens we want to verify. + // return empty vector if there is no match. + llama_tokens draft_tokens; + + // We need at least n_draft_min + n_draft_max + 1 tokens. + if (cur_len <= static_cast(n_draft_min + n_draft_max + 1)) { + return draft_tokens; + } + + // pattern search + llama_tokens pattern; + pattern.reserve(n_draft_min); + for (size_t j = cur_len - n_draft_min + 1; j < cur_len; ++j) { + pattern.push_back(tokens[j]); + } + pattern.push_back(sampled); // add the last token to the pattern + + // We do a search in the token history. 
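+ // The backward scan below looks for an earlier occurrence of this pattern, i.e. the last size_ngram - 1 generated tokens plus the just-sampled token; on the first match, up to size_mgram of the tokens that followed it are returned as the draft.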
+ state.idx_last_check = cur_len; + + size_t match_pos = 0; // we ignore position 0, position 0 == no match + // search backwards, but skip the current match (we are currently there) + for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) { + bool match = true; + for (size_t k = 0; k < pattern.size(); ++k) { + if (tokens[j + k] != pattern[k]) { + match = false; + break; + } + } + if (match) { + match_pos = j; + break; + } + } + if (match_pos == 0) { + return draft_tokens; + } + + const size_t copy_max = std::min( + n_draft_max, + cur_len - (match_pos + n_draft_min) + ); + if (copy_max < n_draft_min) { + return draft_tokens; + } + LOG_DBG("%s: #tokens = %zu: found matching pattern at pos %zu, length %zu, draft length %zu\n", + __func__, cur_len, + match_pos, pattern.size(), copy_max); + + draft_tokens.reserve(copy_max); + for (size_t j = 0; j < copy_max; ++j) { + draft_tokens.push_back(tokens[match_pos + n_draft_min + j]); + } + return draft_tokens; +} + + +// n-gram map +// + +// maximum number of counted values of a ngram map value. +#define COMMON_NGRAM_MAX_VALUE_COUNT 16380 + +static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length); + +void common_ngram_map_draft(common_ngram_map & map, + const llama_tokens & inp, llama_token sampled, + llama_tokens & draft) { + // reset last key and value. + map.last_draft_created = false; + map.last_draft_key_idx = 0; + map.last_draft_value_idx = 0; + + const size_t cur_len = inp.size(); + const uint16_t n = map.size_key; + const uint16_t m = map.size_value; + if (cur_len < static_cast(2 * n + m)) { + return; + } + + // Only check every check_rate tokens to save compute + // i.e., perform check if (cur_len - idx_last_check) >= check_rate + if (map.idx_last_check + map.check_rate > cur_len) { + return; + } + map.idx_last_check = cur_len; + + // search pattern, the key n-gram + std::vector key_tokens; + key_tokens.reserve(n); + for (size_t j = cur_len - n + 1; j < cur_len; ++j) { + key_tokens.push_back(inp[j]); + } + key_tokens.push_back(sampled); + + // search for the key in the map + size_t match_pos = 0; + for (size_t j = cur_len - n - m - 1; j > 0; --j) { + bool match = true; + for (size_t k = 0; k < n; ++k) { + if (inp[j + k] != key_tokens[k]) { + match = false; + break; + } + } + if (match) { + match_pos = j; + break; + } + } + if (match_pos > 0) { + LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__, + cur_len, n, m, key_tokens.size(), sampled, match_pos); + } + + if (match_pos == 0) { + return; + } + + // We have a match, now we look for the statistics of the key. + size_t key_offset = map.keys.size(); // offset in the map + // We iterate through the std::vector map->keys. + for (size_t i = 0; i < map.keys.size(); ++i) { + bool match = true; + for (size_t j = 0; j < n; ++j) { + if (inp[map.keys[i].key_idx + j] != key_tokens[j]) { + match = false; + break; + } + } + if (match) { + key_offset = i; + break; + } + } + if (key_offset == map.keys.size()) { + // We create a new key-entry, it will get offset key_offset. 
+ common_ngram_map_key new_key; + new_key.key_idx = match_pos; + new_key.stat_idx = 0; + new_key.key_num = 0; + for (int i = 0; i < COMMON_NGRAM_MAX_VALUES; ++i) { + new_key.values[i].value_num = 0; + new_key.values[i].n_accepted = m; + } + map.keys.push_back(new_key); + } + + // our key n-gram: + common_ngram_map_key & curr_key = map.keys[key_offset]; + + // update number of key hits + curr_key.key_num = (uint16_t) std::min((int) map.keys[key_offset].key_num + 1, + (int) COMMON_NGRAM_MAX_VALUE_COUNT); + + if (map.key_only) { + // simple mode: + // Fill in the draft with the m tokens following the key. + // We work with value values[0] only. + int n_draft_tokens = std::min((int) m, (int) curr_key.values[0].n_accepted); + + for (int i = 0; i < n_draft_tokens; ++i) { + draft.push_back(inp[match_pos + n + i]); + } + + LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__, + key_offset, curr_key.key_num, draft.size()); + + map.last_draft_created = false; + map.last_draft_key_idx = key_offset; + map.last_draft_value_idx = 0; // value 0 is used for simple mode + return; + } + + if (curr_key.key_num < map.min_hits) { + // not enough hits to consider this a good draft + LOG_DBG("%s: key_offset = %zu, key_num = %d, min_hits = %d, no draft\n", __func__, + key_offset, curr_key.key_num, map.min_hits); + return; + } + + // complex mode: examine the different m-grams after this key n-gram. + // + + // determine all (max COMMON_NGRAM_MAX_VALUES) m-grams after the key n-gram. + for (size_t i = curr_key.stat_idx; i <= match_pos; ++i) { + // does the key n-gram begin at index i? + bool match_key = true; + for (size_t k = 0; k < n; ++k) { + if (inp[i + k] != key_tokens[k]) { + match_key = false; + break; + } + } + if (!match_key) { + continue; + } + + // Do we have an existing value m-gram or a new one after the key at index i? + size_t idx_begin_value_key = i + n; + int idx_value = -1; + for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { + size_t idx_begin_value_v = curr_key.values[v].value_idx; + if (idx_begin_value_v == 0) { + // We found an empty value slot => we found a new value m-gram after the key n-gram. + curr_key.values[v].value_idx = idx_begin_value_key; + curr_key.values[v].value_num = 0; + curr_key.values[v].n_accepted = m; + idx_value = v; + break; + } + bool match = true; + for (size_t j = 0; j < m; ++j) { + if (inp[idx_begin_value_key + j] != inp[idx_begin_value_v + j]) { + match = false; + break; + } + } + if (match) { + // We found an existing value m-gram after the key n-gram. + idx_value = v; + break; + } + } + if (idx_value >= 0) { + // We found a value m-gram of the key n-gram. + curr_key.values[idx_value].value_num = (uint16_t) std::min((int) curr_key.values[idx_value].value_num + 1, + (int) COMMON_NGRAM_MAX_VALUE_COUNT); + } + } + // the statistics are updated up to match_pos. + curr_key.stat_idx = match_pos; + + // Do we have a value we could use for the draft? + uint16_t max_occur = 0; + int slot_max = 0; + for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { + uint16_t curr_occur = curr_key.values[v].value_num; + if (curr_occur > max_occur) { + max_occur = curr_occur; + slot_max = v; + } + } + // What is the sum of the other occurrences?
+ uint32_t sum_occur = 0; + for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { + if (v == slot_max) { + continue; + } + uint16_t curr_occur = curr_key.values[v].value_num; + sum_occur += curr_occur; + } + + LOG_INF("%s: key_offset = %zu, max_occur = %d, sum_occur = %d, slot_max = %d [%zu/%d, %zu/%d, %zu/%d, %zu/%d]\n", __func__, + key_offset, + max_occur, sum_occur, slot_max, + curr_key.values[0].value_idx, curr_key.values[0].value_num, + curr_key.values[1].value_idx, curr_key.values[1].value_num, + curr_key.values[2].value_idx, curr_key.values[2].value_num, + curr_key.values[3].value_idx, curr_key.values[3].value_num + ); + // Print the tokens of the four values (if idx != 0), use LOG_INF + for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { + if (curr_key.values[v].value_idx != 0) { + LOG_INF("%s: value[%d] = %s\n", __func__, v, common_tokens_to_str(inp, curr_key.values[v].value_idx, m).c_str()); + } + } + + if (sum_occur > 0 && max_occur < 3 * sum_occur) { + // The most frequent value is not much more frequent than the other values. + // We do not use the draft. + return; + } + + // We use the most frequent value values[slot_max] for the draft. + // Fill in the draft with the m tokens following the key. + int n_draft_tokens = std::min((int) m, (int) curr_key.values[slot_max].n_accepted); + + for (int i = 0; i < n_draft_tokens; ++i) { + draft.push_back(inp[match_pos + n + i]); + } + + LOG_INF("%s: key_offset = %zu, slot_max = %d, key_num = %d, draft.size = %zu\n", __func__, + key_offset, slot_max, + curr_key.key_num, draft.size()); + + map.last_draft_created = true; + map.last_draft_key_idx = key_offset; + map.last_draft_value_idx = slot_max; // value used for draft generation. +} + +void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) { + if (!map.last_draft_created) { + return; + } + + // find the key and its chosen value. + const size_t key_idx = map.last_draft_key_idx; + const size_t val_idx = map.last_draft_value_idx; + + // find key corresponding to key_idx. + common_ngram_map_key & curr_key = map.keys[key_idx]; + // find value corresponding to val_idx. + struct common_ngram_map_value & curr_value = curr_key.values[val_idx]; // value used for draft generation. + + // update the value statistics + LOG_INF("common_ngram_map_send_accepted: n_accepted = %d, prev value_num = %d\n", + n_accepted, curr_value.n_accepted); + curr_value.n_accepted = n_accepted; +} + +// Helper functions. +// + +// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...]. +std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) { + std::ostringstream oss; + oss << '['; + for (size_t i = 0; i < length; ++i) { + if (i > 0) { + oss << ", "; + } + oss << inp[start + i]; + } + oss << ']'; + return oss.str(); +} + diff --git a/common/ngram-map.h b/common/ngram-map.h new file mode 100644 index 0000000000..bf91883f0c --- /dev/null +++ b/common/ngram-map.h @@ -0,0 +1,105 @@ +#pragma once +// +// common/ngram-map.h: structures used to manage a map from n-grams to a list of m-grams +// +// These structures are used to do a lookup of n-grams followed by m-grams in token history. +// +// There are two algorithms implemented: +// 1. ngram_simple: lookup of n-grams followed by m-grams in token history. +// 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map. +// The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams. 
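+// +// Illustrative example (hypothetical token IDs), with size_key = 3 and size_value = 4: +// token history: ... 17 23 42 7 7 1 9 ... 17 23 42 +// The key 3-gram (17 23 42) is stored once; the 4-gram (7 7 1 9) that followed its earlier occurrence is recorded in one of up to COMMON_NGRAM_MAX_VALUES value slots together with a hit count. +// When the history again ends in (17 23 42), the key-only variant (ngram-map-k) proposes the m-gram found right after the match, while the key-and-values variant (ngram-map-k4v) proposes the most frequent value m-gram, provided it clearly dominates the other slots; the number of tokens accepted by the target model is fed back via common_ngram_map_accept() to bound later drafts from that slot.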
+// + +#include "llama.h" + +#include <vector> + +// n-gram simple +// + +// config of n-gram simple. +struct common_ngram_simple_config { + uint16_t size_ngram; // size of n-grams to lookup in self-mode + uint16_t size_mgram; // size of m-grams to draft in self-mode + uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token +}; + +// current state (and config) of n-gram simple. +struct common_ngram_simple_state { + common_ngram_simple_config config; + + size_t idx_last_check = 0; // index of last check in context history (mutable) + + common_ngram_simple_state(const common_ngram_simple_config & config) + : config(config) {} +}; + +// Searches for an n-gram in the history and checks whether a draft sequence should be generated. +// state: the ngram simple state to search in. +// tokens: the tokens generated so far. +// sampled: the token that was just sampled. +// returns: the draft tokens (empty if no matching pattern was found). +llama_tokens common_ngram_simple_draft( + common_ngram_simple_state & state, + const llama_tokens & tokens, llama_token sampled); + + +// n-gram map +// + +// maximum number of m-gram values stored for each key n-gram. +#define COMMON_NGRAM_MAX_VALUES 4 + +// statistics of an m-gram after a known n-gram +struct common_ngram_map_value { + size_t value_idx = 0; // index of value m-gram in token-history (0 if unused) + uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot) + int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused) +}; + +// statistics of an n-gram +struct common_ngram_map_key { + size_t key_idx; // index of key n-gram in token-history + size_t stat_idx; // index of last token of statistics computation (key_num, values) + + uint16_t key_num; // number of occurrences of this key n-gram in token-history + common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key +}; + +// map from n-grams to following m-grams in token-history +struct common_ngram_map { + uint16_t size_key; // size of key n-grams + uint16_t size_value; // size of value m-grams + + bool key_only; // true if only key n-grams are used, no values. + + // first draft: vector only, no map. + std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history + uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token + uint16_t min_hits; // minimum number of key hits to consider a draft + + common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys, + uint16_t check_rate, uint16_t min_hits) + : size_key(sz_key), size_value(sz_value), key_only(only_keys), + check_rate(check_rate), min_hits(min_hits) {} + + bool last_draft_created = false; // true if a draft was created at last call. + size_t last_draft_key_idx = 0; // index of last key used for draft generation. + uint16_t last_draft_value_idx = 0; // index of last value used for draft generation. + + size_t idx_last_check = 0; // index of last check in context history +}; + + +// Searches for the n-gram in the history and checks whether a draft sequence should be generated. +// map: the ngram map to search in. +// inp: the tokens generated so far. +// sampled: the token that was just sampled. +// draft: vector to store the draft tokens, initially empty.
+void common_ngram_map_draft( + common_ngram_map & map, + const llama_tokens & inp, llama_token sampled, + llama_tokens & draft); + +// Update the statistics of a value after a draft was processed. +void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted); diff --git a/common/speculative.cpp b/common/speculative.cpp index 3e83b0964c..3f314b5d57 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -1,99 +1,54 @@ #include "speculative.h" +#include "common.h" #include "ggml.h" #include "llama.h" #include "log.h" -#include "common.h" +#include "ngram-cache.h" +#include "ngram-map.h" #include "sampling.h" -#include #include +#include +#include #include #define SPEC_VOCAB_MAX_SIZE_DIFFERENCE 128 #define SPEC_VOCAB_CHECK_START_TOKEN_ID 5 -struct common_speculative { - struct llama_context * ctx_tgt; // only used for retokenizing from ctx_dft - struct llama_context * ctx_dft; - struct common_sampler * smpl; - - llama_batch batch; - llama_tokens prompt_dft; - bool vocab_dft_compatible = true; // whether retokenization is needed - std::map tgt_dft_replacements = {}; +const std::vector common_speculative_types = { + COMMON_SPECULATIVE_TYPE_NONE, + COMMON_SPECULATIVE_TYPE_DRAFT, + COMMON_SPECULATIVE_TYPE_EAGLE3, + COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, + COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, + COMMON_SPECULATIVE_TYPE_NGRAM_CACHE }; -struct common_speculative * common_speculative_init( - struct llama_context * ctx_tgt, - struct llama_context * ctx_dft) { - auto * result = new common_speculative { - /* .ctx_tgt = */ ctx_tgt, - /* .ctx_dft = */ ctx_dft, - /* .smpl = */ nullptr, - /* .batch = */ llama_batch_init(llama_n_batch(ctx_dft), 0, 1), - /* .prompt_dft = */ {}, - /* .vocab_dft_compatible = */ false, - }; +const std::map common_speculative_type_from_name_map = { + {"none", COMMON_SPECULATIVE_TYPE_NONE}, + {"draft", COMMON_SPECULATIVE_TYPE_DRAFT}, + {"eagle3", COMMON_SPECULATIVE_TYPE_EAGLE3}, + {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE}, + {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K}, + {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V}, + {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE} +}; - // TODO: optimize or pass from outside? 
-#if 0 - { - common_params_sampling params; - params.no_perf = false; +struct common_speculative_config { + common_speculative_type type; + common_params_speculative params; - params.top_k = 40; - params.top_p = 0.9; + common_speculative_config(common_speculative_type t, + const common_params_speculative & p = common_params_speculative{}) : type(t), params(p) {} +}; - params.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, - COMMON_SAMPLER_TYPE_TOP_P, - COMMON_SAMPLER_TYPE_INFILL, - }; - - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); - } -#else - { - common_params_sampling params; - params.no_perf = false; - - params.top_k = 10; - - params.samplers = { - COMMON_SAMPLER_TYPE_TOP_K, - }; - - result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); - } -#endif - - result->vocab_dft_compatible = common_speculative_are_compatible(ctx_tgt, ctx_dft); - LOG_DBG("vocab_dft_compatible = %d\n", result->vocab_dft_compatible); - - return result; -} - -void common_speculative_free(struct common_speculative * spec) { - if (spec == nullptr) { - return; - } - - common_sampler_free(spec->smpl); - - llama_batch_free(spec->batch); - - delete spec; -} - -bool common_speculative_are_compatible( - const struct llama_context * ctx_tgt, - const struct llama_context * ctx_dft) { - const struct llama_model * model_tgt = llama_get_model(ctx_tgt); - const struct llama_model * model_dft = llama_get_model(ctx_dft); - - const struct llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); - const struct llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); +static bool common_speculative_are_compatible( + const llama_model * model_tgt, + const llama_model * model_dft) { + const llama_vocab * vocab_tgt = llama_model_get_vocab(model_tgt); + const llama_vocab * vocab_dft = llama_model_get_vocab(model_dft); const bool vocab_type_tgt = llama_vocab_type(vocab_tgt); LOG_DBG("%s: vocab_type tgt: %d\n", __func__, vocab_type_tgt); @@ -134,11 +89,12 @@ bool common_speculative_are_compatible( for (int i = SPEC_VOCAB_CHECK_START_TOKEN_ID; i < std::min(n_vocab_tgt, n_vocab_dft); ++i) { const char * token_text_tgt = llama_vocab_get_text(vocab_tgt, i); const char * token_text_dft = llama_vocab_get_text(vocab_dft, i); + if (std::strcmp(token_text_tgt, token_text_dft) != 0) { LOG_DBG("%s: draft model vocab must match target model to use speculation but ", __func__); LOG_DBG("token %d content differs - target '%s', draft '%s'\n", i, - common_token_to_piece(ctx_tgt, i).c_str(), - common_token_to_piece(ctx_dft, i).c_str()); + common_token_to_piece(vocab_tgt, i).c_str(), + common_token_to_piece(vocab_dft, i).c_str()); return false; } } @@ -147,215 +103,779 @@ bool common_speculative_are_compatible( return true; } -void common_speculative_add_replacement_tgt_dft( - struct common_speculative * spec, - const char *source, const char *dest) { - spec->tgt_dft_replacements[source] = dest; -} +// state of an implementation of speculative decoding +// +// each implementation has a unique type and a state that is implementation-specific +// in a subclass of common_speculative_state +struct common_speculative_state { + const enum common_speculative_type type; -static std::string replace_to_dft( - struct common_speculative * spec, - const std::string& input) { - std::string result = input; - for (const auto & pair : spec->tgt_dft_replacements) { - size_t pos = result.find(pair.first); - while (pos != std::string::npos) { - result.replace(pos, pair.first.length(), pair.second); - pos = result.find(pair.first, pos + 
pair.second.length()); - } - } - return result; -} + size_t drafts_call_count = 0; // number of times this implementation was called. + size_t drafts_generated_count = 0; // number of times a draft or part was generated by this implementation. + size_t drafts_accepted_count = 0; // number of times a draft or part was accepted by the target model. + size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation. + size_t drafts_accepted_tokens = 0; // number of tokens accepted by the target model. -static std::string replace_to_tgt( - struct common_speculative * spec, - const std::string& input) { - std::string result = input; - for (const auto& pair : spec->tgt_dft_replacements) { - size_t pos = result.find(pair.second); - while (pos != std::string::npos) { - result.replace(pos, pair.second.length(), pair.first); - pos = result.find(pair.second, pos + pair.first.length()); - } - } - return result; -} + // TODO: track performance of most recent calls + const bool gen_perf = true; // whether to generate performance stats. + int64_t gen_duration_us = 0; // total time spent in this implementation in microseconds. -llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt_tgt_main_model, // specified in target model vocab - llama_token id_last) { - auto & batch = spec->batch; - auto & ctx_tgt = spec->ctx_tgt; - auto & ctx_dft = spec->ctx_dft; - auto & smpl = spec->smpl; - auto & prompt_dft = spec->prompt_dft; + common_speculative_state(enum common_speculative_type type) : type(type) {} - auto * mem_dft = llama_get_memory(ctx_dft); + virtual ~common_speculative_state() = default; - int reuse_i = 0; - int reuse_n = 0; + virtual void begin(const llama_tokens & prompt) = 0; - const int n_ctx = llama_n_ctx(ctx_dft) - params.n_draft; + virtual void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) = 0; - llama_tokens prompt_tgt_draft_model; - if (!spec->vocab_dft_compatible) { - std::string text; - text = common_detokenize(ctx_tgt, prompt_tgt_main_model, true); - text = replace_to_dft(spec, text); - LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); - prompt_tgt_draft_model = common_tokenize(ctx_dft, text, false, true); + virtual void accept(uint16_t n_accepted) = 0; +}; - // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation - const auto * model_tgt = llama_get_model(ctx_tgt); - const auto * vocab_tgt = llama_model_get_vocab(model_tgt); +struct common_speculative_state_draft : public common_speculative_state { + llama_context * ctx_tgt; // only used for retokenizing from ctx_dft + llama_context * ctx_dft; - int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); - GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); - text.resize(-n_chars); - llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); - text = replace_to_dft(spec, text); + common_sampler * smpl; - LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); - id_last = common_tokenize(ctx_dft, text, false, true)[0]; - } - // prompt_tgt's tokens will always be compatible with ctx_dft - const llama_tokens &prompt_tgt = - spec->vocab_dft_compatible ? 
prompt_tgt_main_model : prompt_tgt_draft_model; + llama_batch batch; + llama_tokens prompt_dft; - const int i_start = std::max(0, (int) prompt_tgt.size() - n_ctx); + bool vocab_cmpt = true; // whether retokenization is needed + std::unordered_map vocab_map; - // reuse as much as possible from the old draft context - // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt - for (int i = 0; i < (int) prompt_dft.size(); ++i) { - int cur = 0; - while (i_start + cur < (int) prompt_tgt.size() && - i + cur < (int) prompt_dft.size() && - prompt_tgt[i_start + cur] == prompt_dft[i + cur]) { - cur++; + common_speculative_state_draft( + enum common_speculative_type type, + llama_context * ctx_tgt, + llama_context * ctx_dft, + const std::vector> & replacements) + : common_speculative_state(type) + , ctx_tgt(ctx_tgt) + , ctx_dft(ctx_dft) + { + batch = llama_batch_init(llama_n_batch(ctx_dft), 0, 1); + smpl = nullptr; + + // TODO: optimize or pass from outside? + // { + // common_params_sampling params; + // params.no_perf = false; + // + // params.top_k = 40; + // params.top_p = 0.9; + // + // params.samplers = { + // COMMON_SAMPLER_TYPE_TOP_K, + // COMMON_SAMPLER_TYPE_TOP_P, + // COMMON_SAMPLER_TYPE_INFILL, + // }; + // + // result->smpl = common_sampler_init(llama_get_model(ctx_dft), params); + // } + { + common_params_sampling params; + params.no_perf = false; + params.top_k = 10; + params.samplers = { + COMMON_SAMPLER_TYPE_TOP_K, + }; + + smpl = common_sampler_init(llama_get_model(ctx_dft), params); } - if ((cur >= params.n_reuse || n_ctx >= (int) prompt_tgt.size()) && cur > reuse_n) { - reuse_i = i; - reuse_n = cur; + vocab_cmpt = common_speculative_are_compatible(llama_get_model(ctx_tgt), llama_get_model(ctx_dft)); + LOG_DBG("vocab_cmpt = %d\n", vocab_cmpt); + + if (!vocab_cmpt) { + LOG_WRN("the target and draft vocabs are not compatible - tokens will be translated between the two\n"); + + for (const auto & pair : replacements) { + vocab_map[pair.first] = pair.second; + } } } - LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); + ~common_speculative_state_draft() override { + llama_perf_context_print(ctx_dft); - llama_tokens result; - result.reserve(params.n_draft); + llama_free(ctx_dft); - if (reuse_n == 0) { - llama_memory_clear(mem_dft, false); - prompt_dft.clear(); - } else { - // this happens when a previous draft has been discarded (for example, due to being too small), but the - // target model agreed with it. 
in this case, we simply pass back the previous results to save compute - if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { - for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { - result.push_back(prompt_dft[i]); + common_sampler_free(smpl); - if (params.n_draft <= (int) result.size()) { - break; - } + llama_batch_free(batch); + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + auto * spec = this; + + auto & batch = spec->batch; + auto & ctx_tgt = spec->ctx_tgt; + auto & ctx_dft = spec->ctx_dft; + auto & smpl = spec->smpl; + auto & prompt_dft = spec->prompt_dft; + + auto * mem_dft = llama_get_memory(ctx_dft); + + int reuse_i = 0; + int reuse_n = 0; + + const int n_ctx = llama_n_ctx(ctx_dft) - params.n_max; + + llama_tokens prompt_cnv; + if (!spec->vocab_cmpt) { + std::string text; + + text = common_detokenize(ctx_tgt, prompt_tgt, true); + text = replace_to_dft(text); + + LOG_DBG("%s: main->draft detokenized string: '%s'\n", __func__, text.c_str()); + + prompt_cnv = common_tokenize(ctx_dft, text, false, true); + + // convert id_last to draft vocab. llama_detokenize is called directly to avoid an allocation + const auto * model_tgt = llama_get_model(ctx_tgt); + const auto * vocab_tgt = llama_model_get_vocab(model_tgt); + + int32_t n_chars = llama_detokenize(vocab_tgt, &id_last, 1, nullptr, 0, false, false); + GGML_ASSERT(n_chars < 0 && "failed to detokenize id_last"); + + text.resize(-n_chars); + llama_detokenize(vocab_tgt, &id_last, 1, text.data(), text.size(), false, false); + text = replace_to_dft(text); + + LOG_DBG("main->draft detokenized id_last(%d): '%s'\n", id_last, text.c_str()); + id_last = common_tokenize(ctx_dft, text, false, true)[0]; + } + + const llama_tokens & prompt_cur = spec->vocab_cmpt ? prompt_tgt : prompt_cnv; + + const int i_start = std::max(0, (int) prompt_cur.size() - n_ctx); + + // reuse as much as possible from the old draft context + // ideally, the draft context should be as big as the target context and we will always reuse the entire prompt + for (int i = 0; i < (int) prompt_dft.size(); ++i) { + int cur = 0; + while (i_start + cur < (int) prompt_cur.size() && + i + cur < (int) prompt_dft.size() && + prompt_cur[i_start + cur] == prompt_dft[i + cur]) { + cur++; } - return result; + if ((cur >= 256 || n_ctx >= (int) prompt_cur.size()) && cur > reuse_n) { + reuse_i = i; + reuse_n = cur; + } } - if (reuse_i > 0) { - llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); - llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); + LOG_DBG("%s: reuse_i = %d, reuse_n = %d, prompt = %d\n", __func__, reuse_i, reuse_n, (int) prompt_dft.size()); - prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); + result.clear(); + result.reserve(params.n_max); + + if (reuse_n == 0) { + llama_memory_clear(mem_dft, false); + prompt_dft.clear(); + } else { + // this happens when a previous draft has been discarded (for example, due to being too small), but the + // target model agreed with it. 
in this case, we simply pass back the previous results to save compute + if (reuse_i + reuse_n < (int) prompt_dft.size() && prompt_dft[reuse_i + reuse_n] == id_last) { + for (int i = reuse_i + reuse_n + 1; i < (int) prompt_dft.size(); ++i) { + result.push_back(prompt_dft[i]); + + if (params.n_max <= (int) result.size()) { + break; + } + } + + return; + } + + if (reuse_i > 0) { + llama_memory_seq_rm (mem_dft, 0, 0, reuse_i); + llama_memory_seq_add(mem_dft, 0, reuse_i, -1, -reuse_i); + + prompt_dft.erase(prompt_dft.begin(), prompt_dft.begin() + reuse_i); + } + + if (reuse_n < (int) prompt_dft.size()) { + llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); + prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); + } } - if (reuse_n < (int) prompt_dft.size()) { - llama_memory_seq_rm (mem_dft, 0, reuse_n, -1); - prompt_dft.erase(prompt_dft.begin() + reuse_n, prompt_dft.end()); - } - } - - // prepare a batch to evaluate any new tokens in the prompt - common_batch_clear(batch); - - for (size_t i = i_start + reuse_n; i < prompt_tgt.size(); ++i) { - //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_tgt[i]); - common_batch_add(batch, prompt_tgt[i], i - i_start, { 0 }, false); - - prompt_dft.push_back(prompt_tgt[i]); - } - - // we should rarely end-up here during normal decoding - if (batch.n_tokens > 0) { - //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - - llama_decode(ctx_dft, batch); - } - - const llama_pos n_past = prompt_dft.size(); - - LOG_DBG("%s: n_past = %d\n", __func__, n_past); - - common_batch_clear(batch); - common_batch_add (batch, id_last, n_past, { 0 }, true); - - prompt_dft.push_back(id_last); - - LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - - llama_decode(ctx_dft, batch); - - common_sampler_reset(smpl); - - // sample n_draft tokens from the draft model - for (int i = 0; i < params.n_draft; ++i) { + // prepare a batch to evaluate any new tokens in the prompt common_batch_clear(batch); - common_sampler_sample(smpl, ctx_dft, 0, true); + for (size_t i = i_start + reuse_n; i < prompt_cur.size(); ++i) { + //LOG_DBG("i = %d, i_start = %d, reuse_n = %d, i - i_start = %d, id = %6d\n", i, i_start, reuse_n, i - i_start, prompt_cur[i]); + common_batch_add(batch, prompt_cur[i], i - i_start, { 0 }, false); - const auto * cur_p = common_sampler_get_candidates(smpl, true); - - for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { - LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", - k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + prompt_dft.push_back(prompt_cur[i]); } - // add drafted token for each sequence - const llama_token id = cur_p->data[0].id; + // we should rarely end-up here during normal decoding + if (batch.n_tokens > 0) { + //LOG_DBG("%s: draft prompt batch: %s\n", __func__, string_from(ctx, batch).c_str()); - common_sampler_accept(smpl, id, true); - - result.push_back(id); - - if (params.n_draft <= (int) result.size()) { - break; + llama_decode(ctx_dft, batch); } - // only collect very high-confidence draft tokens - if (cur_p->data[0].p < params.p_min) { - break; - } + const llama_pos n_past = prompt_dft.size(); - common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + LOG_DBG("%s: n_past = %d\n", __func__, n_past); + + common_batch_clear(batch); + common_batch_add (batch, id_last, n_past, { 0 }, true); + + prompt_dft.push_back(id_last); + + 
LOG_DBG("%s: draft prompt: %s\n", __func__, string_from(ctx_dft, prompt_dft).c_str()); - // evaluate the drafted tokens on the draft model llama_decode(ctx_dft, batch); - prompt_dft.push_back(id); + common_sampler_reset(smpl); + + // sample n_draft tokens from the draft model + for (int i = 0; i < params.n_max; ++i) { + common_batch_clear(batch); + + common_sampler_sample(smpl, ctx_dft, 0, true); + + const auto * cur_p = common_sampler_get_candidates(smpl, true); + + for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { + LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", + k, i, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx_dft, cur_p->data[k].id).c_str()); + } + + // add drafted token for each sequence + const llama_token id = cur_p->data[0].id; + + common_sampler_accept(smpl, id, true); + + result.push_back(id); + + if (params.n_max <= (int) result.size()) { + break; + } + + // only collect very high-confidence draft tokens + if (cur_p->data[0].p < params.p_min) { + break; + } + + common_batch_add(batch, id, n_past + i + 1, { 0 }, true); + + // evaluate the drafted tokens on the draft model + llama_decode(ctx_dft, batch); + + prompt_dft.push_back(id); + } + + if (!spec->vocab_cmpt) { + std::string detokenized = common_detokenize(ctx_dft, result, true); + detokenized = replace_to_tgt(detokenized); + LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); + result = common_tokenize(ctx_tgt, detokenized, false, true); + if (result.size() > (size_t)params.n_max) { + result.resize(params.n_max); + } + } } - if (!spec->vocab_dft_compatible) { - std::string detokenized = common_detokenize(ctx_dft, result, true); - detokenized = replace_to_tgt(spec, detokenized); - LOG_DBG("draft->main detokenized string: '%s'\n", detokenized.c_str()); - result = common_tokenize(ctx_tgt, detokenized, false, true); - if (result.size() > (size_t)params.n_draft) { - result.resize(params.n_draft); + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } + + std::string replace_to_dft(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.first); + while (pos != std::string::npos) { + result.replace(pos, pair.first.length(), pair.second); + pos = result.find(pair.first, pos + pair.second.length()); + } } + + return result; + } + + std::string replace_to_tgt(const std::string & input) const { + std::string result = input; + + for (const auto & pair : this->vocab_map) { + size_t pos = result.find(pair.second); + while (pos != std::string::npos) { + result.replace(pos, pair.second.length(), pair.first); + pos = result.find(pair.second, pos + pair.first.length()); + } + } + + return result; + } +}; + +struct common_speculative_state_eagle3 : public common_speculative_state { + common_speculative_state_eagle3(enum common_speculative_type type) : common_speculative_state(type) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & draft_tokens) override { + // TODO: implement + GGML_UNUSED(params); + GGML_UNUSED(prompt_tgt); + GGML_UNUSED(id_last); + GGML_UNUSED(draft_tokens); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +// state of self-speculation (simple implementation, not ngram-map) +struct common_speculative_state_ngram_simple : 
public common_speculative_state { + common_ngram_simple_state state; + + common_speculative_state_ngram_simple( + enum common_speculative_type type, + common_ngram_simple_state state) + : common_speculative_state(type), state(state) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + result = common_ngram_simple_draft(state, prompt_tgt, id_last); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + // noop + GGML_UNUSED(n_accepted); + } +}; + +struct common_speculative_state_ngram_map_k : public common_speculative_state { + // draft ngram map for speculative decoding without draft model + common_ngram_map map; + + common_speculative_state_ngram_map_k( + enum common_speculative_type type, + common_ngram_map map) + : common_speculative_state(type), map(std::move(map)) {} + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + common_ngram_map_draft(map, prompt_tgt, id_last, result); + GGML_UNUSED(params); + } + + void accept(uint16_t n_accepted) override { + common_ngram_map_accept(map, n_accepted); + } +}; + +struct common_speculative_state_ngram_cache : public common_speculative_state { + uint16_t n_draft; + bool save_dynamic; + bool save_static; + + common_ngram_cache ngram_cache_context; + common_ngram_cache ngram_cache_dynamic; + common_ngram_cache ngram_cache_static; + + size_t cache_size = 0; // number of tokens in n-gram cache + + common_speculative_state_ngram_cache( + const enum common_speculative_type type, + const std::string & path_static, + const std::string & path_dynamic, + uint16_t n_draft, + bool save_dynamic, + bool save_static) + : common_speculative_state(type) + , n_draft(n_draft) + , save_dynamic(save_dynamic) + , save_static(save_static) + { + if (!path_static.empty()) { + try { + ngram_cache_static = common_ngram_cache_load(path_static); + } catch (...) { + LOG_ERR("failed to open static lookup cache: %s", path_static.c_str()); + GGML_ABORT("Couldn't read static lookup cache"); + } + } + + if (!path_dynamic.empty()) { + try { + ngram_cache_dynamic = common_ngram_cache_load(path_dynamic); + } catch (...) 
{ + LOG_ERR("failed to open dynamic lookup cache: %s", path_dynamic.c_str()); + GGML_ABORT("Couldn't read dynamic lookup cache"); + } + } + } + + void begin(const llama_tokens & prompt) override { + GGML_UNUSED(prompt); + } + + void draft( + const common_params_speculative & params, + const llama_tokens & prompt_tgt, + llama_token id_last, + llama_tokens & result) override { + GGML_UNUSED(params); + + if (cache_size < prompt_tgt.size() + 1) { + llama_tokens tokens_new; + tokens_new.reserve(prompt_tgt.size() + 1 - cache_size); + for (size_t j = cache_size; j < prompt_tgt.size(); ++j) { + tokens_new.push_back(prompt_tgt[j]); + } + tokens_new.push_back(id_last); // add the last token + + // Update context ngram cache with new prompt_tgt: + common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + tokens_new, tokens_new.size(), false); + cache_size = prompt_tgt.size() + 1; + } + + llama_tokens inp; + inp.reserve(prompt_tgt.size() + 1); + for (size_t j = 0; j < prompt_tgt.size(); ++j) { + inp.push_back(prompt_tgt[j]); + } + inp.push_back(id_last); + + result.push_back(id_last); + + common_ngram_cache_draft(inp, result, n_draft, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, + ngram_cache_context, + ngram_cache_dynamic, + ngram_cache_static); + + if (result.size() > 0) { + // delete first token in result (which is the id_last token) + result.erase(result.begin()); + } + } + + void accept(uint16_t n_accepted) override { + // TODO: noop + GGML_UNUSED(n_accepted); + } +}; + +struct common_speculative { + std::vector> impls; // list of implementations to use and their states + common_speculative_state * curr_impl = nullptr; // current implementation in use (for stats) +}; + +static common_ngram_map get_common_ngram_map(const common_speculative_config & config) { + uint16_t size_key = config.params.ngram_size_n; + uint16_t size_value = config.params.ngram_size_m; + bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); + uint16_t check_rate = config.params.ngram_check_rate; + uint16_t min_hits = config.params.ngram_min_hits; + + return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits); +} + +static common_speculative_state_ngram_cache create_state_ngram_cache( + const std::string & path_static, const std::string & path_dynamic, + const common_speculative_config & config) { + uint16_t n_draft = 8; // TODO get from config? + + // TODO bool param in common/common.h to set save_static/save_dynamic? 
+ bool save_static = false; + bool save_dynamic = false; + + common_speculative_state_ngram_cache state(config.type, path_static, path_dynamic, n_draft, save_static, save_dynamic); + + return state; +} + +std::string common_speculative_type_name_str() { + std::string result; + for (size_t i = 0; i < common_speculative_types.size(); i++) { + if (i > 0) { + result += ", "; + } + result += common_speculative_type_to_str(common_speculative_types[i]); } return result; } + +std::string common_speculative_type_to_str(enum common_speculative_type type) { + switch (type) { + case COMMON_SPECULATIVE_TYPE_NONE: return "none"; + case COMMON_SPECULATIVE_TYPE_DRAFT: return "draft"; + case COMMON_SPECULATIVE_TYPE_EAGLE3: return "eagle3"; + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; + default: return "unknown"; + } +} + +enum common_speculative_type common_speculative_type_from_name(const std::string & name) { + const auto it = common_speculative_type_from_name_map.find(name); + if (it == common_speculative_type_from_name_map.end()) { + return COMMON_SPECULATIVE_TYPE_COUNT; + } + return it->second; +} + +// initialization of the speculative decoding system +// +common_speculative * common_speculative_init( + const common_params_speculative & params, + llama_context * ctx_tgt) { + llama_context * ctx_dft = nullptr; + if (params.model_dft) { + ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft); + if (ctx_dft == nullptr) { + LOG_ERR("%s", "failed to create draft context\n"); + return nullptr; + } + } + + // Compute the implementations to use based on the config and their order of preference + std::vector configs = {}; // list of speculative configs to try + { + bool has_draft = !params.mparams_dft.path.empty(); + bool has_draft_eagle3 = false; // TODO PR-18039: if params.speculative.eagle3 + + bool has_ngram_cache = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_CACHE); + bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); + bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); + bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); + + // In a more complex implementation we could use the same implementation but with different parameters. + // This was initially used in PR-18471 but removed to simplify the code. + if (has_ngram_simple) { + // This implementation can guess a lot of tokens without any draft model. + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE, params)); + } + if (has_ngram_map_k) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K, params)); + } + if (has_ngram_map_k4v) { + // This implementation can guess tokens with high acceptance rate but is more expensive. 
+ configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); + } + if (has_ngram_cache) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); + } + if (has_draft) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_DRAFT, params)); + } + if (has_draft_eagle3) { + configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_EAGLE3, params)); + } + } + + std::vector> impls = {}; + + for (const common_speculative_config & config : configs) { + LOG_DBG("%s: adding implementation %s\n", __func__, common_speculative_type_to_str(config.type).c_str()); + switch (config.type) { + case COMMON_SPECULATIVE_TYPE_NONE: + break; + case COMMON_SPECULATIVE_TYPE_DRAFT: { + impls.push_back(std::make_unique(config.type, + /* .ctx_tgt = */ ctx_tgt, + /* .ctx_dft = */ ctx_dft, + /* .replacements = */ params.replacements + )); + break; + } + case COMMON_SPECULATIVE_TYPE_EAGLE3: { + impls.push_back(std::make_unique(config.type)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: { + common_ngram_map ngram_map = get_common_ngram_map(config); + + uint16_t ngram_size_key = ngram_map.size_key; + uint16_t mgram_size_value = ngram_map.size_value; + uint16_t check_rate = ngram_map.check_rate; + + auto config_simple = common_ngram_simple_config{ + /* .size_ngram = */ ngram_size_key, + /* .size_mgram = */ mgram_size_value, + /* .check_rate = */ check_rate + }; + auto state = std::make_unique( + /* .type = */ config.type, + /* .state = */ common_ngram_simple_state(config_simple) + ); + impls.push_back(std::move(state)); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: + case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: { + impls.push_back(std::make_unique( + (config.type), + get_common_ngram_map(config) + )); + break; + } + case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { + auto state = create_state_ngram_cache( + params.lookup_cache_static, params.lookup_cache_dynamic, config); + impls.push_back(std::make_unique(state)); + break; + } + default: + break; + } + } + + if (impls.empty()) { + LOG_WRN("%s", "no implementations specified for speculative decoding\n"); + return nullptr; + } + + auto * result = new common_speculative { + /* .impls = */ std::move(impls) + }; + + return result; +} + +void common_speculative_free(common_speculative * spec) { + if (spec == nullptr) { + return; + } + + delete spec; +} + +void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt) { + if (spec == nullptr) { + return; + } + + for (auto & impl : spec->impls) { + impl->begin(prompt); + } +} + +llama_tokens common_speculative_draft( + common_speculative * spec, + const common_params_speculative & params, + const llama_tokens & prompt_tgt, // specified in target model vocab + llama_token id_last) { + llama_tokens result; + + spec->curr_impl = nullptr; // reset current implementation + + for (auto & impl : spec->impls) { + { + const int64_t t_start_us = impl->gen_perf ? ggml_time_us() : 0; + + impl->draft(params, prompt_tgt, id_last, result); + + const int64_t t_now_us = impl->gen_perf ? 
ggml_time_us() : 0; + + impl->drafts_call_count++; + impl->gen_duration_us += t_now_us - t_start_us; // accumulate duration for this implementation + } + + if (!result.empty()) { + LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, + common_speculative_type_to_str(impl.get()->type).c_str(), + prompt_tgt.size(), + impl.get()->drafts_call_count, result.size()); + + spec->curr_impl = impl.get(); // set current implementation for stats + impl->drafts_generated_count++; + impl->drafts_generated_tokens += result.size(); + + break; // We have a draft, so break out of the loop and return it. + } + } + + return result; +} + +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { + if (n_accepted == 0) { + return; + } + + common_speculative_state * impl = spec->curr_impl; + + GGML_ASSERT(impl); + + if (n_accepted > 0) { + impl->drafts_accepted_count++; + impl->drafts_accepted_tokens += n_accepted; + } + + impl->accept(n_accepted); +} + +void common_speculative_print_stats(const common_speculative * spec) { + if (spec == nullptr) { + return; + } + + for (const auto & impl : spec->impls) { + std::string str_perf; + if (impl->gen_perf) { + std::ostringstream oss; + oss << std::fixed << std::setprecision(3) << impl->gen_duration_us / 1000.0; + str_perf = ", dur = " + oss.str() + " ms"; + } else { + str_perf = ""; + } + + LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", + common_speculative_type_to_str(impl->type).c_str(), + impl->drafts_call_count, + impl->drafts_generated_count, + impl->drafts_accepted_count, + impl->drafts_generated_tokens, + impl->drafts_accepted_tokens, + str_perf.c_str()); + } +} diff --git a/common/speculative.h b/common/speculative.h index e69d7aaa1e..9e1888e4be 100644 --- a/common/speculative.h +++ b/common/speculative.h @@ -5,31 +5,33 @@ struct common_speculative; -struct common_speculative_params { - int n_draft = 16; // max drafted tokens - int n_reuse = 256; +// comma separated list of all types +std::string common_speculative_type_name_str(); - float p_min = 0.75f; // min probability required to accept a token in the draft -}; +// convert string to type +enum common_speculative_type common_speculative_type_from_name(const std::string & name); -struct common_speculative * common_speculative_init( - struct llama_context * ctx_tgt, - struct llama_context * ctx_dft -); +// convert type to string +std::string common_speculative_type_to_str(enum common_speculative_type type); -void common_speculative_free(struct common_speculative * spec); +common_speculative * common_speculative_init( + const common_params_speculative & params, + llama_context * ctx_tgt); -bool common_speculative_are_compatible( - const struct llama_context * ctx_tgt, - const struct llama_context * ctx_dft); +void common_speculative_free(common_speculative * spec); -void common_speculative_add_replacement_tgt_dft( - struct common_speculative * spec, - const char *source, const char *dest); +// optionally call once at the beginning of a new generation +void common_speculative_begin(common_speculative * spec, const llama_tokens & prompt); // sample up to n_draft tokens and add them to the batch using the draft model -llama_tokens common_speculative_gen_draft( - struct common_speculative * spec, - struct common_speculative_params params, - const llama_tokens & prompt, - llama_token id_last); +llama_tokens common_speculative_draft( + common_speculative * spec, + const 
common_params_speculative & params, + const llama_tokens & prompt, + llama_token id_last); + +// informs the speculative decoder that n_accepted tokens were accepted by the target model +void common_speculative_accept(common_speculative * spec, uint16_t n_accepted); + +// print statistics about the speculative decoding +void common_speculative_print_stats(const common_speculative * spec); diff --git a/docs/speculative.md b/docs/speculative.md new file mode 100644 index 0000000000..8281eaa2d3 --- /dev/null +++ b/docs/speculative.md @@ -0,0 +1,120 @@ +# Speculative Decoding + +llama.cpp supports speculative decoding, a technique that can significantly accelerate token generation by predicting multiple tokens ahead of the main model. + +[Speculative decoding](https://en.wikipedia.org/wiki/Transformer_(deep_learning)#Speculative_decoding) leverages the fact that computing n tokens in a batch (as in prompt processing) is more efficient than computing n tokens sequentially (as in response generation). By generating draft tokens quickly and then verifying them with the target model in a single batch, this approach can achieve substantial speedups when the draft predictions are frequently correct. + +## Implementations + +The `llama-server` application supports several implementations of speculative decoding: + +### Draft Model (`draft`) + +A much smaller model (called the _draft model_) generates drafts. +Using a draft model is the most common approach to speculative decoding. + +### n-gram Cache (`ngram-cache`) + +An n-gram is a sequence of n tokens. The n-gram cache implementation maintains statistics about short n-gram sequences. +A draft is computed using probabilities derived from these statistics. External statistics can also be loaded from files for improved accuracy. + +See: + +- #5479, #6828, #6848 + +### n-gram Map (`ngram-simple`, `ngram-map-*`) + +These implementations search the token history for patterns and use matching sequences as draft candidates. +They require no additional model but rely on patterns that have already appeared in the generated text. +A typical use case for this approach is rewriting source code with an LLM. + +#### n-gram Map (`ngram-simple`) + +This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead. + +#### n-gram Map Key (`ngram-map-k`) + +This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts. + +The number of accepted tokens is stored for each used n-gram. + +#### n-gram Map Key-4-Values (`ngram-map-k4v`) + +This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft. + +The number of accepted tokens is stored for each used n-gram. + +**Example:** Server options suited to workloads with many long repetitions. +```bash +llama-server [...] 
--spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 +``` + + +## Command-Line Options + +If a draft model is combined with draftless decoding, the draftless decoding takes precedence (see the combined example at the end of this section). + +``` +--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v] + type of speculative decoding to use when no draft model is provided + (default: none) +--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length + of lookup n-gram (default: 12) +--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length + of draft m-gram (default: 48) +--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding + (default: 1) +--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1) +``` + +### `--spec-type TYPE` + +Specifies the type of speculative decoding to use when no draft model is provided. + +| Type | Description | +|------|-------------| +| `none` | No speculative decoding (default) | +| `ngram-cache` | Use n-gram cache lookup | +| `ngram-simple` | Use simple n-gram pattern matching | +| `ngram-map-k` | Use n-gram pattern matching with n-gram keys | +| `ngram-map-k4v` | Use n-gram pattern matching with n-gram keys and up to four m-gram values (experimental) | + +**Example:** Server instance used to refactor source code. +```bash +./llama-server [...] --spec-type ngram-simple +``` + +### `--spec-ngram-size-n N` + +Sets the size N of the lookup n-gram for n-gram map based speculative decoding. +The n-gram size N determines how many of the most recent tokens are used as the lookup pattern when searching the history. + +### `--spec-ngram-size-m M` + +Sets the size M of the draft m-gram for n-gram map based speculative decoding. +The m-gram size determines how many tokens to draft when a match is found. +Larger values can provide more speedup but may reduce the acceptance rate. + +### `--spec-ngram-check-rate R` + +This option improves performance when the n-gram lookup in the history is too costly. A lookup is executed only every R tokens (default is 1, i.e. every token). + +### `--spec-ngram-min-hits H` + +This option defines how often a key has to appear in the token history before it is used for a draft (default is 1).
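+ +**Example:** A server that combines a draft model with draftless decoding; the draftless implementation is tried first and the draft model is used when it does not produce a draft. The model file names here are placeholders. +```bash +llama-server -m model-7b-q4_0.gguf -md model-draft-q4_0.gguf --spec-type ngram-simple +``` + +The n-gram settings can also be overridden per request. The sketch below uses the JSON keys read by the server (see `tools/server/server-task.cpp`); the endpoint, port and values are only illustrative, and the `speculative.type` value is written in the underscore form returned by `common_speculative_type_to_str`. +```bash +curl http://localhost:8080/completion -H "Content-Type: application/json" -d '{ + "prompt": "...", + "speculative.type": "ngram_simple", + "speculative.ngram_size_n": 12, + "speculative.ngram_size_m": 48, + "speculative.ngram_c_rate": 1, + "speculative.ngram_m_hits": 1 +}' +```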
+ +## Statistics + +Each speculative decoding implementation prints statistics. + +``` +draft acceptance rate = 0.57576 ( 171 accepted / 297 generated) +statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tokens = 187, #acc tokens = 73 +statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98 +``` + +- `#calls`: number of calls to this implementation +- `#gen drafts`: number of drafts generated by this implementation +- `#acc drafts`: number of drafts at least partially accepted by the main model +- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens) +- `#acc tokens`: number of tokens accepted by the main model + diff --git a/examples/lookup/lookup-create.cpp b/examples/lookup/lookup-create.cpp index bb94a8fe06..f7b6ea1b19 100644 --- a/examples/lookup/lookup-create.cpp +++ b/examples/lookup/lookup-create.cpp @@ -32,9 +32,9 @@ int main(int argc, char ** argv){ common_ngram_cache ngram_cache; common_ngram_cache_update(ngram_cache, LLAMA_NGRAM_STATIC, LLAMA_NGRAM_STATIC, inp, inp.size(), true); - fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.lookup_cache_static.c_str()); + fprintf(stderr, "%s: hashing done, writing file to %s\n", __func__, params.speculative.lookup_cache_static.c_str()); - common_ngram_cache_save(ngram_cache, params.lookup_cache_static); + common_ngram_cache_save(ngram_cache, params.speculative.lookup_cache_static); return 0; } diff --git a/examples/lookup/lookup-stats.cpp b/examples/lookup/lookup-stats.cpp index 135f6fcab9..ae28b2e6e8 100644 --- a/examples/lookup/lookup-stats.cpp +++ b/examples/lookup/lookup-stats.cpp @@ -46,18 +46,18 @@ int main(int argc, char ** argv){ { const int64_t t_start_draft_us = ggml_time_us(); - if (!params.lookup_cache_static.empty()) { + if (!params.speculative.lookup_cache_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static); + ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static); } catch (std::ifstream::failure const &) { - LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); + LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str()); exit(1); } } - if (!params.lookup_cache_dynamic.empty()) { + if (!params.speculative.lookup_cache_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic); + ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic); } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program } diff --git a/examples/lookup/lookup.cpp b/examples/lookup/lookup.cpp index 27f159940a..8e73138a5f 100644 --- a/examples/lookup/lookup.cpp +++ b/examples/lookup/lookup.cpp @@ -51,18 +51,18 @@ int main(int argc, char ** argv){ const int64_t t_start_draft_us = ggml_time_us(); common_ngram_cache_update(ngram_cache_context, LLAMA_NGRAM_MIN, LLAMA_NGRAM_MAX, inp, inp.size(), false); - if (!params.lookup_cache_static.empty()) { + if (!params.speculative.lookup_cache_static.empty()) { try { - ngram_cache_static = common_ngram_cache_load(params.lookup_cache_static); + ngram_cache_static = common_ngram_cache_load(params.speculative.lookup_cache_static); } catch (std::ifstream::failure const &) { - LOG_ERR("failed to open static lookup cache: %s", params.lookup_cache_static.c_str()); + LOG_ERR("failed to open static lookup cache: %s", params.speculative.lookup_cache_static.c_str()); exit(1); } } - if 
(!params.lookup_cache_dynamic.empty()) { + if (!params.speculative.lookup_cache_dynamic.empty()) { try { - ngram_cache_dynamic = common_ngram_cache_load(params.lookup_cache_dynamic); + ngram_cache_dynamic = common_ngram_cache_load(params.speculative.lookup_cache_dynamic); } catch (std::ifstream::failure const &) {} // if the file does not exist it will simply be created at the end of the program } @@ -210,7 +210,7 @@ int main(int argc, char ** argv){ // Update dynamic ngram cache with context ngram cache and save it to disk: common_ngram_cache_merge(ngram_cache_dynamic, ngram_cache_context); - common_ngram_cache_save(ngram_cache_dynamic, params.lookup_cache_dynamic); + common_ngram_cache_save(ngram_cache_dynamic, params.speculative.lookup_cache_dynamic); LOG("\n\n"); diff --git a/examples/speculative-simple/speculative-simple.cpp b/examples/speculative-simple/speculative-simple.cpp index 8141052a22..d8b1f5a480 100644 --- a/examples/speculative-simple/speculative-simple.cpp +++ b/examples/speculative-simple/speculative-simple.cpp @@ -24,7 +24,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.speculative.model.path.empty()) { + if (params.speculative.mparams_dft.path.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } @@ -34,10 +34,8 @@ int main(int argc, char ** argv) { llama_numa_init(params.numa); llama_model * model_tgt = NULL; - //llama_model * model_dft = NULL; llama_context * ctx_tgt = NULL; - llama_context * ctx_dft = NULL; // load the target model auto llama_init_tgt = common_init_from_params(params); @@ -48,26 +46,38 @@ int main(int argc, char ** argv) { const llama_vocab * vocab = llama_model_get_vocab(model_tgt); // load the draft model - params.devices = params.speculative.devices; - params.model = params.speculative.model; - params.n_ctx = params.speculative.n_ctx; - params.n_batch = params.speculative.n_ctx > 0 ? params.speculative.n_ctx : params.n_batch; - params.n_gpu_layers = params.speculative.n_gpu_layers; + llama_model_ptr model_dft; - if (params.speculative.cpuparams.n_threads > 0) { - params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; - } + // TODO: simplify this logic + { + const auto & params_spec = params.speculative; - params.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; - params.tensor_buft_overrides = params.speculative.tensor_buft_overrides; + auto params_dft = params; - auto llama_init_dft = common_init_from_params(params); + params_dft.n_parallel = 1; + params_dft.n_ctx = params_spec.n_ctx; + params_dft.n_batch = llama_n_ctx_seq(ctx_tgt); + params_dft.devices = params_spec.devices; + params_dft.model = params_spec.mparams_dft; + params_dft.n_gpu_layers = params_spec.n_gpu_layers; - //model_dft = llama_init_dft->model(); - ctx_dft = llama_init_dft->context(); + if (params_spec.cpuparams.n_threads > 0) { + params_dft.cpuparams.n_threads = params.speculative.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params.speculative.cpuparams_batch.n_threads; + } - if (!common_speculative_are_compatible(ctx_tgt, ctx_dft)) { - LOG_INF("the draft model '%s' is not compatible with the target model '%s'. 
tokens will be translated between the draft and target models.\n", params.speculative.model.path.c_str(), params.model.path.c_str()); + params_dft.tensor_buft_overrides = params.speculative.tensor_buft_overrides; + + auto mparams_dft = common_model_params_to_llama(params_dft); + + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); + if (model_dft == nullptr) { + LOG_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); + return 1; + } + + params.speculative.model_dft = model_dft.get(); + params.speculative.cparams_dft = common_context_params_to_llama(params_dft); } // Tokenize the prompt @@ -92,12 +102,6 @@ int main(int argc, char ** argv) { LOG("%s", common_token_to_piece(ctx_tgt, id).c_str()); } - // how many tokens to draft each time - int n_draft = params.speculative.n_max; - int n_draft_min = params.speculative.n_min; - - float p_min = params.speculative.p_min; - int n_predict = 0; int n_drafted = 0; int n_accept = 0; @@ -127,15 +131,11 @@ int main(int argc, char ** argv) { int n_past = inp.size() - 1; // init the speculator - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft; - params_spec.n_reuse = llama_n_ctx(ctx_dft) - n_draft; - params_spec.p_min = p_min; + const auto & params_spec = params.speculative; - struct common_speculative * spec = common_speculative_init(ctx_tgt, ctx_dft); - for (auto &pair : params.speculative.replacements) { - common_speculative_add_replacement_tgt_dft(spec, pair.first.c_str(), pair.second.c_str()); - } + struct common_speculative * spec = common_speculative_init(params.speculative, ctx_tgt); + + common_speculative_begin(spec, prompt_tgt); llama_batch batch_tgt = llama_batch_init(llama_n_batch(ctx_tgt), 0, 1); @@ -151,7 +151,7 @@ int main(int argc, char ** argv) { // offloaded to a remote device. it doesn't even have to be based on an LLM. instead, it can provide tokens // from a cache or lookup tables. 
// - llama_tokens draft = common_speculative_gen_draft(spec, params_spec, prompt_tgt, id_last); + llama_tokens draft = common_speculative_draft(spec, params_spec, prompt_tgt, id_last); //LOG_DBG("draft: %s\n", string_from(ctx_dft, draft).c_str()); @@ -162,7 +162,7 @@ int main(int argc, char ** argv) { // evaluate the target model on [id_last, draft0, draft1, ..., draftN-1] { // do not waste time on small drafts - if (draft.size() < (size_t) n_draft_min) { + if (draft.size() < (size_t) params_spec.n_min) { draft.clear(); } @@ -240,7 +240,7 @@ int main(int argc, char ** argv) { LOG_INF("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f)); LOG_INF("\n"); - LOG_INF("n_draft = %d\n", n_draft); + LOG_INF("n_draft = %d\n", params_spec.n_max); LOG_INF("n_predict = %d\n", n_predict); LOG_INF("n_drafted = %d\n", n_drafted); LOG_INF("n_accept = %d\n", n_accept); @@ -249,8 +249,6 @@ int main(int argc, char ** argv) { LOG_INF("\n"); LOG_INF("draft:\n\n"); - llama_perf_context_print(ctx_dft); - LOG_INF("\n"); LOG_INF("target:\n\n"); common_perf_print(ctx_tgt, smpl); diff --git a/examples/speculative/speculative.cpp b/examples/speculative/speculative.cpp index 89d3249431..3e5cf5f46b 100644 --- a/examples/speculative/speculative.cpp +++ b/examples/speculative/speculative.cpp @@ -46,7 +46,7 @@ int main(int argc, char ** argv) { common_init(); - if (params.speculative.model.path.empty()) { + if (params.speculative.mparams_dft.path.empty()) { LOG_ERR("%s: --model-draft is required\n", __func__); return 1; } @@ -78,7 +78,7 @@ int main(int argc, char ** argv) { // load the draft model params.devices = params.speculative.devices; - params.model = params.speculative.model; + params.model = params.speculative.mparams_dft; params.n_gpu_layers = params.speculative.n_gpu_layers; if (params.speculative.cpuparams.n_threads > 0) { params.cpuparams.n_threads = params.speculative.cpuparams.n_threads; diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 73cb4c75b3..1ca4e3cc0e 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -48,11 +48,8 @@ enum server_state { struct server_slot { int id; - llama_batch batch_spec = {}; - // TODO: change to unique_ptrs for consistency: llama_context * ctx = nullptr; - llama_context * ctx_dft = nullptr; // multimodal mtmd_context * mctx = nullptr; @@ -259,7 +256,7 @@ struct server_slot { } bool can_speculate() const { - return ctx_dft; + return !!spec; } void add_token(const completion_token_output & token) { @@ -295,6 +292,7 @@ struct server_slot { SLT_DBG(*this, "the max possible draft is too small: %d < %d - skipping speculative decoding\n", n_draft_max, task->params.speculative.n_min); n_draft_max = 0; } + return n_draft_max; } @@ -397,6 +395,8 @@ struct server_slot { draft_ratio, n_draft_accepted, n_draft_total ); } + + common_speculative_print_stats(spec); } json to_json(bool only_metrics = false) const { @@ -553,18 +553,13 @@ private: // note: keep these alive - they determine the lifetime of the model, context, etc. 
common_init_result_ptr llama_init; - common_init_result_ptr llama_init_dft; llama_context * ctx = nullptr; - bool vocab_dft_compatible = true; - - llama_model * model_dft = nullptr; - - llama_context_params cparams_dft; - llama_batch batch {}; + llama_model_ptr model_dft; + bool add_bos_token = true; int32_t n_ctx; // total context for all clients / slots @@ -597,13 +592,8 @@ private: // Clear any sampling context for (server_slot & slot : slots) { - llama_free(slot.ctx_dft); - slot.ctx_dft = nullptr; - common_speculative_free(slot.spec); slot.spec = nullptr; - - llama_batch_free(slot.batch_spec); } llama_batch_free(batch); @@ -648,44 +638,39 @@ private: add_bos_token = llama_vocab_get_add_bos(vocab); - if (params_base.has_speculative()) { - SRV_INF("loading draft model '%s'\n", params_base.speculative.model.path.c_str()); + if (params_base.speculative.has_dft()) { + SRV_INF("loading draft model '%s'\n", params_base.speculative.mparams_dft.path.c_str()); + + const auto & params_spec = params_base.speculative; auto params_dft = params_base; - params_dft.devices = params_base.speculative.devices; - params_dft.model = params_base.speculative.model; - params_dft.n_ctx = params_base.speculative.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_base.speculative.n_ctx; - params_dft.n_gpu_layers = params_base.speculative.n_gpu_layers; params_dft.n_parallel = 1; - params_dft.cache_type_k = params_base.speculative.cache_type_k; - params_dft.cache_type_v = params_base.speculative.cache_type_v; + params_dft.n_ctx = params_spec.n_ctx == 0 ? llama_n_ctx_seq(ctx) : params_spec.n_ctx; + params_dft.n_batch = llama_n_ctx_seq(ctx); + params_dft.devices = params_spec.devices; + params_dft.model = params_spec.mparams_dft; + params_dft.n_gpu_layers = params_spec.n_gpu_layers; + params_dft.cache_type_k = params_spec.cache_type_k; + params_dft.cache_type_v = params_spec.cache_type_v; - params_dft.cpuparams.n_threads = params_base.speculative.cpuparams.n_threads; - params_dft.cpuparams_batch.n_threads = params_base.speculative.cpuparams_batch.n_threads; - params_dft.tensor_buft_overrides = params_base.speculative.tensor_buft_overrides; + if (params_spec.cpuparams.n_threads > 0) { + params_dft.cpuparams.n_threads = params_spec.cpuparams.n_threads; + params_dft.cpuparams_batch.n_threads = params_spec.cpuparams_batch.n_threads; + } - llama_init_dft = common_init_from_params(params_dft); + params_dft.tensor_buft_overrides = params_spec.tensor_buft_overrides; - model_dft = llama_init_dft->model(); + auto mparams_dft = common_model_params_to_llama(params_dft); + model_dft.reset(llama_model_load_from_file(params_dft.model.path.c_str(), mparams_dft)); if (model_dft == nullptr) { - SRV_ERR("failed to load draft model, '%s'\n", params_base.speculative.model.path.c_str()); + SRV_ERR("failed to load draft model, '%s'\n", params_dft.model.path.c_str()); return false; } - vocab_dft_compatible = common_speculative_are_compatible(ctx, llama_init_dft->context()); - if (!vocab_dft_compatible) { - SRV_INF("the draft model '%s' is not compatible with the target model '%s'. 
tokens will be translated between the draft and target models.\n", params_base.speculative.model.path.c_str(), params_base.model.path.c_str()); - } - - const int n_ctx_dft = llama_n_ctx(llama_init_dft->context()); - - cparams_dft = common_context_params_to_llama(params_dft); - cparams_dft.n_batch = n_ctx_dft; - - // the context is not needed - we will create one for each slot - llama_init_dft->free_context(); + params_base.speculative.model_dft = model_dft.get(); + params_base.speculative.cparams_dft = common_context_params_to_llama(params_dft); } std::string & mmproj_path = params_base.mmproj.path; @@ -695,6 +680,7 @@ private: } mtmd_context_params mparams = mtmd_context_params_default(); + mparams.use_gpu = params_base.mmproj_use_gpu; mparams.print_timings = false; mparams.n_threads = params_base.cpuparams.n_threads; @@ -702,6 +688,7 @@ private: mparams.warmup = params_base.warmup; mparams.image_min_tokens = params_base.image_min_tokens; mparams.image_max_tokens = params_base.image_max_tokens; + mctx = mtmd_init_from_file(mmproj_path.c_str(), model, mparams); if (mctx == nullptr) { SRV_ERR("failed to load multimodal model, '%s'\n", mmproj_path.c_str()); @@ -718,11 +705,6 @@ private: params_base.n_cache_reuse = 0; SRV_WRN("%s\n", "cache_reuse is not supported by multimodal, it will be disabled"); } - - if (params_base.has_speculative()) { - SRV_ERR("%s\n", "err: speculative decode is not supported by multimodal"); - return false; - } } if (!llama_memory_can_shift(llama_get_memory(ctx))) { @@ -757,29 +739,24 @@ private: for (int i = 0; i < params_base.n_parallel; i++) { server_slot slot; - slot.id = i; - slot.ctx = ctx; + slot.id = i; + slot.ctx = ctx; slot.n_ctx = n_ctx_slot; - slot.mctx = mctx; + + slot.mctx = mctx; slot.prompt.tokens.has_mtmd = mctx != nullptr; - if (model_dft) { - slot.batch_spec = llama_batch_init(params_base.speculative.n_max + 1, 0, 1); - - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - slot.ctx_dft = llama_init_from_model(model_dft, cparams_dft); - if (slot.ctx_dft == nullptr) { - SRV_ERR("%s", "failed to create draft context\n"); - return false; - } - - slot.spec = common_speculative_init(slot.ctx, slot.ctx_dft); - if (slot.spec == nullptr) { - SRV_ERR("%s", "failed to create speculator\n"); - return false; - } - for (auto & pair : params_base.speculative.replacements) { - common_speculative_add_replacement_tgt_dft(slot.spec, pair.first.c_str(), pair.second.c_str()); + // try speculative decoding + { + slot.spec = common_speculative_init(params_base.speculative, slot.ctx); + if (slot.spec) { + if (mctx) { + SRV_ERR("%s\n", "speculative decoding is not supported with multimodal"); + return false; + } + SRV_WRN("%s", "speculative decoding context initialized\n"); + } else { + SRV_WRN("%s", "speculative decoding context not initialized\n"); } } @@ -1059,7 +1036,7 @@ private: return res; } - std::vector construct_lora_list(const std::map & config) { + std::vector construct_lora_list(const std::map & config) const { std::vector output = params_base.lora_adapters; // copy for (size_t i = 0; i < output.size(); ++i) { auto it = config.find(i); @@ -1162,7 +1139,7 @@ private: backend_sampling &= task.params.sampling.backend_sampling; // TODO: speculative decoding requires multiple samples per batch - not supported yet - backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0); + backend_sampling &= !(slot.spec && task.params.speculative.n_max > 0); // TODO: getting post/pre sampling logits is not yet supported with backend sampling 
backend_sampling &= !need_logits; @@ -1179,14 +1156,6 @@ private: slot.smpl.reset(); } - // initialize draft batch - // TODO: rework speculative decoding [TAG_SERVER_SPEC_REWORK] - if (slot.ctx_dft) { - llama_batch_free(slot.batch_spec); - - slot.batch_spec = llama_batch_init(task.params.speculative.n_max + 1, 0, 1); - } - slot.task = std::make_unique(std::move(task)); slot.state = slot.task->is_child() @@ -2059,19 +2028,23 @@ private: // generate draft tokens in speculative decoding mode // TODO: rework to have a single draft llama_context shared across all slots [TAG_SERVER_SPEC_REWORK] // perform the speculative drafting for all sequences at the same time in a single batch - int n_draft_max = slot.get_n_draft_max(); + const int n_draft_max = slot.get_n_draft_max(); if (n_draft_max > 0) { if (mctx) { // we should never reach this, as speculative is automatically disabled if mmproj is loaded GGML_ABORT("not supported by multimodal"); } - struct common_speculative_params params_spec; - params_spec.n_draft = n_draft_max; - params_spec.n_reuse = llama_n_ctx(slot.ctx_dft) - slot.task->params.speculative.n_max; - params_spec.p_min = slot.task->params.speculative.p_min; const llama_tokens & cached_text_tokens = slot.prompt.tokens.get_text_tokens(); - llama_tokens draft = common_speculative_gen_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); + + const auto & params_spec = slot.task->params.speculative; + + llama_tokens draft = common_speculative_draft(slot.spec, params_spec, cached_text_tokens, slot.sampled); + + if (draft.size() > (size_t) n_draft_max) { + SLT_WRN(slot, "draft size %d exceeds max %d, truncating\n", (int) draft.size(), n_draft_max); + draft.resize(n_draft_max); + } // add the sampled token to the batch slot.i_batch_dft.push_back(batch.n_tokens); @@ -2742,6 +2715,10 @@ private: // prompt evaluated for next-token prediction slot.state = SLOT_STATE_GENERATING; + + if (slot.can_speculate()) { + common_speculative_begin(slot.spec, slot.prompt.tokens.get_text_tokens()); + } } else if (slot.state != SLOT_STATE_GENERATING) { continue; // continue loop of slots } @@ -2813,6 +2790,9 @@ private: // update how many tokens out of those tested were accepted slot.n_draft_accepted += ids.size() - 1; + // inform the speculative decoding about the number of accepted tokens + common_speculative_accept(slot.spec, ids.size() - 1); + // rollback to the state before sampling the draft tokens slot.prompt.tokens.keep_first(slot.prompt.n_tokens() - n_draft); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 799e341d37..2d25db63b7 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -5,6 +5,7 @@ #include "llama.h" #include "chat.h" #include "sampling.h" +#include "speculative.h" #include "json-schema-to-grammar.h" using json = nlohmann::ordered_json; @@ -76,6 +77,11 @@ json task_params::to_json(bool only_metrics) const { {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, {"speculative.p_min", speculative.p_min}, + {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.ngram_size_n", speculative.ngram_size_n}, + {"speculative.ngram_size_m", speculative.ngram_size_m}, + {"speculative.ngram_c_rate", speculative.ngram_check_rate}, + {"speculative.ngram_m_hits", speculative.ngram_min_hits}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -135,6 +141,11 @@ json 
task_params::to_json(bool only_metrics) const { {"speculative.n_max", speculative.n_max}, {"speculative.n_min", speculative.n_min}, {"speculative.p_min", speculative.p_min}, + {"speculative.type", common_speculative_type_to_str(speculative.type)}, + {"speculative.ngram_size_n", speculative.ngram_size_n}, + {"speculative.ngram_size_m", speculative.ngram_size_m}, + {"speculative.ngram_c_rate", speculative.ngram_check_rate}, + {"speculative.ngram_m_hits", speculative.ngram_min_hits}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, {"backend_sampling", sampling.backend_sampling}, @@ -242,6 +253,18 @@ task_params server_task::params_from_json_cmpl( params.speculative.n_min = std::max(params.speculative.n_min, 0); params.speculative.n_max = std::max(params.speculative.n_max, 0); + params.speculative.type = common_speculative_type_from_name(json_value(data, "speculative.type", common_speculative_type_to_str(defaults.speculative.type))); + + params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n); + params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m); + params.speculative.ngram_check_rate = json_value(data, "speculative.ngram_c_rate", defaults.speculative.ngram_check_rate); + params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits); + + params.speculative.ngram_size_n = std::min(std::max(1, (int) params.speculative.ngram_size_n), 1024); + params.speculative.ngram_size_m = std::min(std::max(1, (int) params.speculative.ngram_size_m), 1024); + params.speculative.ngram_check_rate = std::min(std::max(1, (int) params.speculative.ngram_check_rate), 1024); + params.speculative.ngram_min_hits = std::min(std::max(1, (int) params.speculative.ngram_min_hits), 1024); + // Use OpenAI API logprobs only if n_probs wasn't provided if (data.contains("logprobs") && params.sampling.n_probs == defaults.sampling.n_probs){ params.sampling.n_probs = json_value(data, "logprobs", defaults.sampling.n_probs);