feat: apply logits + greedy sampler

This commit is contained in:
samuel 2025-09-06 00:21:18 -03:00
parent 5a5bce8577
commit 8742ce0e39
3 changed files with 19 additions and 6 deletions

View File

@@ -582,3 +582,7 @@ std::vector<common_sampler_type> common_sampler_types_from_chars(const std::stri
return samplers;
}
// Applies the sampler chain held by `gsmpl` to the candidate array `cur_p`
// by forwarding both to llama_sampler_apply().
//
// gsmpl : sampler wrapper whose `chain` member is used; it is dereferenced
//         without a null check, so callers must pass a valid pointer.
// cur_p : candidate token array handed to the chain as-is.
void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p) {
    auto & chain = gsmpl->chain;  // named alias for the member; passed through unchanged
    llama_sampler_apply(chain, cur_p);
}

View File

@@ -105,3 +105,5 @@ std::vector<enum common_sampler_type> common_sampler_types_from_chars(const std:
// Declaration only: takes a vocab plus two C strings (grammar kind and grammar
// data) and returns a llama_sampler pointer.
// NOTE(review): presumably constructs a grammar-constrained sampler; ownership
// of the returned pointer is not visible here — confirm against the definition.
llama_sampler * llama_sampler_init_llg(const llama_vocab * vocab,
const char * grammar_kind, const char * grammar_data);
// Applies the sampler chain of `gsmpl` to the candidate array `cur_p`
// (the definition forwards gsmpl->chain and cur_p to llama_sampler_apply).
void common_sampler_apply_chain(struct common_sampler * gsmpl, struct llama_token_data_array * cur_p);

View File

@@ -379,15 +379,22 @@ llama_token mtp_speculative_gen_draft(
llama_build_and_execute_mtp_graph(ctx, batch, id_last, n_past, last_tok_idx);
llama_token id = common_sampler_sample(smpl, ctx, last_tok_idx, true);
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const int n_vocab = llama_n_vocab(vocab);
const auto * cur_p = common_sampler_get_candidates(smpl);
for (int k = 0; k < std::min(3, (int)cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",
k, 0, cur_p->data[k].id, cur_p->data[k].p, common_token_to_piece(ctx, cur_p->data[k].id).c_str());
llama_token_data_array * cur_p = common_sampler_get_candidates(smpl);
cur_p->size = n_vocab;
for (int i = 0; i < n_vocab; ++i) {
cur_p->data[i].id = i;
cur_p->data[i].logit = llama_get_logits_ith(ctx, last_tok_idx)[i];
}
cur_p->sorted = false;
common_sampler_accept(smpl, id, true);
common_sampler_apply_chain(smpl, cur_p);
const llama_token id = cur_p->data[0].id;
llama_batch_free(batch);