Add common_sampler_apply_chain() and use it for MTP draft sampling.

common_sampler_apply_chain() forwards a caller-provided candidate array to
llama_sampler_apply() on the sampler's chain. mtp_speculative_gen_draft() now
fills the sampler's candidate array from the logits at last_tok_idx, applies
the chain, and takes the first remaining candidate as the draft token. It no
longer calls common_sampler_sample() / common_sampler_accept() and no longer
logs the top draft candidates.

Review notes:
  * the logits row pointer is fetched once before the fill loop instead of
    once per vocab token
  * llama_vocab_n_tokens() replaces the deprecated llama_n_vocab()
  * both touched files now end with a newline
  * this patch was rebuilt from a whitespace-collapsed copy: indentation of
    context/removed lines is reconstructed, the trailing blank context line of
    the last hunk was dropped, and the index blob ids are those of the
    original patch (the post-image ids do not describe this revision)

diff --git a/common/sampling.cpp b/common/sampling.cpp
index a5824ebeed..452cefee3b 100644
--- a/common/sampling.cpp
+++ b/common/sampling.cpp
@@ -582,3 +582,9 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
 
     return samplers;
 }
+
+// run gsmpl's sampler chain over the caller-provided candidates in cur_p
+// thin wrapper around llama_sampler_apply(); gsmpl must be non-null
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
+    llama_sampler_apply(gsmpl->chain, cur_p);
+}
diff --git a/common/sampling.h b/common/sampling.h
index 2064421db4..b424d7d6d7 100644
--- a/common/sampling.h
+++ b/common/sampling.h
@@ -105,3 +105,6 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
 
 llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
                 const char * grammar_kind, const char * grammar_data);
+
+// apply the sampler chain of gsmpl to the candidates in cur_p (forwards to llama_sampler_apply)
+void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);
diff --git a/common/speculative.cpp b/common/speculative.cpp
index 5edd4aa815..77ed75913d 100644
--- a/common/speculative.cpp
+++ b/common/speculative.cpp
@@ -379,14 +379,26 @@ llama_token mtp_speculative_gen_draft(
 
     llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
 
-    llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true);
+    const llama_model * model = llama_get_model(ctx);
+    const llama_vocab * vocab = llama_model_get_vocab(model);
+    const int n_vocab = llama_vocab_n_tokens(vocab);
 
-    const auto * cur_p = common_sampler_get_candidates(smpl);
-    for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
-        LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
-                k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
+    llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
+
+    // the logits row does not depend on the token index, so fetch it once up front
+    const float * logits = llama_get_logits_ith(ctx, last_tok_idx);
+
+    // refill the candidates with one entry per vocab token (assumes room for n_vocab entries - TODO confirm)
+    cur_p->size = n_vocab;
+    for (int i = 0; i < n_vocab; ++i) {
+        cur_p->data[i].id = i;
+        cur_p->data[i].logit = logits[i];
     }
+    cur_p->sorted = false;
 
-    common_sampler_accept(smpl, id, true);
+    common_sampler_apply_chain(smpl, cur_p);
+
+    // NOTE(review): reads the first candidate left by the chain rather than cur_p->selected - confirm this is the intended pick
+    const llama_token id = cur_p->data[0].id;
 
     llama_batch_free(batch);