wip fix tests
parent e652566139, commit b8eb3b3501

@@ -104,9 +104,10 @@ struct ring_buffer {
struct common_sampler {
    common_params_sampling params;

    struct llama_sampler * grmr;
    struct llama_sampler * chain;

    bool grammar;

    ring_buffer<llama_token> prev;

    std::vector<llama_token_data> cur;
@@ -116,7 +117,6 @@ struct common_sampler {
    void reset() {
        prev.clear();

        llama_sampler_reset(grmr);
        llama_sampler_reset(chain);
    }
@@ -184,10 +184,15 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co

    lparams.no_perf = params.no_perf;

    struct llama_sampler * grmr;
    llama_sampler * chain = llama_sampler_chain_init(lparams);

    bool grammar = false;
    std::vector<llama_sampler *> samplers;

    if (params.grammar.compare(0, 11, "%llguidance") == 0) {
#ifdef LLAMA_USE_LLGUIDANCE
        grmr = llama_sampler_init_llg(vocab, "lark", params.grammar.c_str());
        samplers.push_back(llama_sampler_init_llg(vocab, "lark", params.grammar.c_str()));
        grammar = true;
#else
        GGML_ABORT("llguidance (cmake -DLLAMA_LLGUIDANCE=ON) is not enabled");
#endif // LLAMA_USE_LLGUIDANCE
@@ -234,26 +239,20 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
            trigger_patterns_c.push_back(regex.c_str());
        }

        grmr = params.grammar_lazy
             ? llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
        if (!params.grammar.empty()) {
            if (params.grammar_lazy) {
                samplers.push_back(
                    llama_sampler_init_grammar_lazy_patterns(vocab, params.grammar.c_str(), "root",
                        trigger_patterns_c.data(), trigger_patterns_c.size(),
                        trigger_tokens.data(), trigger_tokens.size())
             : llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root");
        if (!grmr) {
            return nullptr;
                        trigger_tokens.data(), trigger_tokens.size()));
            } else {
                samplers.push_back(llama_sampler_init_grammar(vocab, params.grammar.c_str(), "root"));
            }

            grammar = true;
        }
    }

    auto * result = new common_sampler {
        /* .params = */ params,
        /* .grmr = */ grmr,
        /* .chain = */ llama_sampler_chain_init(lparams),
        /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur = */ {},
        /* .cur_p = */ {},
    };

    std::vector<llama_sampler *> samplers;
    if (params.has_logit_bias()) {
        samplers.push_back(llama_sampler_init_logit_bias(llama_vocab_n_tokens(vocab), params.logit_bias.size(), params.logit_bias.data()));
    }
@@ -316,15 +315,23 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
    }

    for (auto * smpl : samplers) {
        llama_sampler_chain_add(result->chain, smpl);
        llama_sampler_chain_add(chain, smpl);
    }

    auto * result = new common_sampler {
        /* .params = */ params,
        /* .chain = */ chain,
        /* .grammar = */ grammar,
        /* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
        /* .cur = */ {},
        /* .cur_p = */ {},
    };

    return result;
}

void common_sampler_free(struct common_sampler * gsmpl) {
    if (gsmpl) {
        llama_sampler_free(gsmpl->grmr);
        llama_sampler_free(gsmpl->chain);

        delete gsmpl;
@@ -334,11 +341,24 @@ void common_sampler_free(struct common_sampler * gsmpl) {
void common_sampler_accept(struct common_sampler * gsmpl, llama_token token, bool accept_grammar) {
    const auto tm = gsmpl->tm();

    if (accept_grammar) {
        llama_sampler_accept(gsmpl->grmr, token);
    }
    if (gsmpl->grammar) {
        const int n_smpl = llama_sampler_chain_n(gsmpl->chain);

        for (int i = 0; i < n_smpl; i++) {
            auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);

            // the grammar sampler is always the first one
            if (i == 0) {
                if (accept_grammar) {
                    llama_sampler_accept(smpl, token);
                }
            } else {
                llama_sampler_accept(smpl, token);
            }
        }
    } else {
        llama_sampler_accept(gsmpl->chain, token);
    }

    gsmpl->prev.push_back(token);
}
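Editor's note: the accept path above relies on an ordering invariant: when a grammar is configured, the grammar sampler is pushed into the chain before all other samplers, so it sits at chain index 0. A minimal sketch of that invariant; the helper name build_chain_sketch and the particular samplers chosen are illustrative, not part of this change.

#include <vector>
#include "llama.h"

// Illustrative sketch only: push the grammar sampler (if any) first so that it
// ends up at chain index 0, which is what common_sampler_accept() assumes.
static llama_sampler * build_chain_sketch(const llama_vocab * vocab, const char * grammar_str, uint32_t seed) {
    llama_sampler * chain = llama_sampler_chain_init(llama_sampler_chain_default_params());

    std::vector<llama_sampler *> samplers;
    if (grammar_str != nullptr) {
        samplers.push_back(llama_sampler_init_grammar(vocab, grammar_str, "root")); // index 0
    }
    samplers.push_back(llama_sampler_init_top_k(40)); // regular samplers follow the grammar sampler
    samplers.push_back(llama_sampler_init_dist(seed));

    for (auto * smpl : samplers) {
        llama_sampler_chain_add(chain, smpl);
    }
    return chain;
}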
@@ -350,8 +370,8 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
    return new common_sampler {
        /* .params = */ gsmpl->params,
        /* .grmr = */ llama_sampler_clone(gsmpl->grmr),
        /* .chain = */ llama_sampler_clone(gsmpl->chain),
        /* .grammar = */ gsmpl->grammar,
        /* .prev = */ gsmpl->prev,
        /* .cur = */ gsmpl->cur,
        /* .cur_p = */ gsmpl->cur_p,
@@ -407,69 +427,41 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
    return gsmpl->chain;
}

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
    // Check if a backend sampler has already sampled a token in which case we
    // return that token id directly.
    {
        const llama_token id = llama_get_sampled_token_ith(ctx, idx);

        if (id != LLAMA_TOKEN_NULL) {
            LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
            return id;
        }
    }

llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx) {
    llama_synchronize(ctx);

    // start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
    const auto tm = gsmpl->tm();

    llama_token id = LLAMA_TOKEN_NULL;

    // Check if a backend sampler has already sampled a token in which case we
    // return that token id directly.
    {
        id = llama_get_sampled_token_ith(ctx, idx);

        if (id != LLAMA_TOKEN_NULL) {
            LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);

            return id;
        }
    }

    gsmpl->set_logits(ctx, idx);

    auto & grmr  = gsmpl->grmr;
    auto & chain = gsmpl->chain;
    auto & cur_p = gsmpl->cur_p; // initialized by set_logits

    if (grammar_first) {
        llama_sampler_apply(grmr, &cur_p);
    }

    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during sampling - check your sampling configuration");

    const llama_token id = cur_p.data[cur_p.selected].id;
    id = cur_p.data[cur_p.selected].id;

    if (grammar_first) {
        return id;
    }

    // check if it the sampled token fits the grammar
    {
        llama_token_data single_token_data = { id, 1.0f, 0.0f };
        llama_token_data_array single_token_data_array = { &single_token_data, 1, -1, false };

        llama_sampler_apply(grmr, &single_token_data_array);

        const bool is_valid = single_token_data_array.data[0].logit != -INFINITY;
        if (is_valid) {
            return id;
        }
    }

    // resampling:
    // if the token is not valid, sample again, but first apply the grammar sampler and then the sampling chain
    gsmpl->set_logits(ctx, idx);

    llama_sampler_apply(grmr,  &cur_p);
    llama_sampler_apply(chain, &cur_p);

    GGML_ASSERT(cur_p.selected != -1 && "no selected token during re-sampling - check your sampling configuration");

    return cur_p.data[cur_p.selected].id;
}

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft) {
    GGML_ASSERT(idxs.size() == draft.size() + 1 && "idxs.size() must be draft.size() + 1");

    std::vector<llama_token> result;
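Editor's note: with grammar_first removed from the public signature, a caller only samples and accepts; grammar checking and resampling happen inside common_sampler_sample. A hedged sketch of a generation loop under that assumption (setup of the model, context, batch and sampler is omitted; the helper name is illustrative):

// Illustrative only - assumes ctx, batch, vocab and gsmpl were created with the
// usual llama.cpp / common helpers; n_past tracks the current position.
static void generate_sketch(common_sampler * gsmpl, llama_context * ctx, const llama_vocab * vocab,
                            llama_batch & batch, llama_pos & n_past, int n_predict) {
    for (int i = 0; i < n_predict; ++i) {
        // no grammar_first argument anymore - grammar handling is internal
        const llama_token id = common_sampler_sample(gsmpl, ctx, /*idx =*/ -1);

        common_sampler_accept(gsmpl, id, /*accept_grammar =*/ true);

        if (llama_vocab_is_eog(vocab, id)) {
            break;
        }

        common_batch_clear(batch);
        common_batch_add(batch, id, n_past++, { 0 }, true);
        llama_decode(ctx, batch);
    }
}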
@@ -477,7 +469,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample

    size_t i = 0;
    for (; i < draft.size(); i++) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);

        common_sampler_accept(gsmpl, id, true);
@@ -489,7 +481,7 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
    }

    if (i == draft.size()) {
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i], grammar_first);
        const llama_token id = common_sampler_sample(gsmpl, ctx, idxs[i]);

        common_sampler_accept(gsmpl, id, true);
@@ -499,13 +491,13 @@ std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sample
    return result;
}

std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first) {
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft) {
    std::vector<int> idxs(draft.size() + 1);
    for (size_t i = 0; i < idxs.size(); ++i) {
        idxs[i] = i;
    }

    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft, grammar_first);
    return common_sampler_sample_and_accept_n(gsmpl, ctx, idxs, draft);
}

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
@@ -57,10 +57,7 @@ struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
// - check if the token fits the grammar (if any)
// - if not: resample by first applying the grammar constraints and then sampling again (slower path)
//
// if grammar_first is true, the grammar is applied before the samplers (slower)
// useful in cases where all the resulting candidates (not just the sampled one) must fit the grammar
//
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first = false);
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx);

// generalized version of common_sampler_sample
//
@@ -78,10 +75,10 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
//
// returns at least 1 token, up to idxs.size()
//
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const std::vector<int> & idxs, const llama_tokens & draft);

// assume idxs == [ 0, 1, 2, ..., draft.size() ]
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft, bool grammar_first = false);
std::vector<llama_token> common_sampler_sample_and_accept_n(struct common_sampler * gsmpl, struct llama_context * ctx, const llama_tokens & draft);

uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
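Editor's note: the convenience overload documented above assumes idxs == [0, 1, ..., draft.size()]. A hedged sketch of how a speculative-decoding caller would verify a draft with the new, grammar_first-free signature (the helper name is illustrative):

// Illustrative only: verify a draft against the target context. The batch for
// ctx_tgt is assumed to have been decoded with draft.size() + 1 output positions,
// matching the implicit idxs = [0, 1, ..., draft.size()].
static size_t verify_draft_sketch(common_sampler * gsmpl, llama_context * ctx_tgt,
                                  const std::vector<llama_token> & draft) {
    const std::vector<llama_token> accepted = common_sampler_sample_and_accept_n(gsmpl, ctx_tgt, draft);

    // at least 1 token is returned; everything past the first divergence from
    // the draft has already been dropped by the helper
    return accepted.size();
}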
@@ -315,7 +315,7 @@ llama_tokens common_speculative_gen_draft(
    for (int i = 0; i < params.n_draft; ++i) {
        common_batch_clear(batch);

        common_sampler_sample(smpl, ctx_dft, 0, true);
        common_sampler_sample(smpl, ctx_dft, 0);

        const auto * cur_p = common_sampler_get_candidates(smpl, true);
@@ -242,7 +242,7 @@ int main(int argc, char ** argv) {
        bool accept = false;
        if (params.sampling.temp > 0) {
            // stochastic verification
            common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
            common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);

            auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
@@ -491,7 +491,7 @@ int main(int argc, char ** argv) {
                continue;
            }

            common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
            common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft);

            const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
@@ -63,8 +63,6 @@ llama_context::llama_context(
    // before the reserve passes run later in this function. This avoids a later
    // re-reserve when graph nodes change.
    if (params.samplers != nullptr && params.n_samplers > 0) {
        sampling.samplers.reserve(params.n_samplers);

        for (size_t i = 0; i < params.n_samplers; ++i) {
            const auto & config = params.samplers[i];
@@ -820,7 +818,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) {
    output_reorder();

    if (sampling.logits == nullptr) {
        return 0;
        return model.vocab.n_tokens();
    }

    try {
@@ -930,7 +928,7 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
    }

    if (sampler && !can_offload) {
        LLAMA_LOG_WARN("%s: sampler '%s' cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler));
        LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);

        sampling.samplers.erase(seq_id);
@@ -2977,14 +2975,15 @@ float * llama_get_logits(llama_context * ctx) {
float * llama_get_logits_ith(llama_context * ctx, int32_t i) {
    ctx->synchronize();

    if (ctx->get_sampled_token_ith(i) != LLAMA_TOKEN_NULL) {
        return nullptr;
    }
    if (ctx->get_sampled_probs_ith(i) != nullptr) {
        return nullptr;
    float * res = nullptr;

    res = ctx->get_sampled_logits_ith(i);

    if (!res) {
        res = ctx->get_logits_ith(i);
    }

    return ctx->get_logits_ith(i);
    return res;
}

float * llama_get_embeddings(llama_context * ctx) {
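Editor's note: after this hunk, llama_get_logits_ith() prefers the backend's sampled-logits buffer and only falls back to the full logits row. A hedged sketch of the consumption pattern the rest of this commit builds on; llama_get_sampled_token_ith, llama_get_sampled_candidates_ith and llama_get_sampled_logits_count_ith are the new backend-sampling getters from this branch, and the greedy pick is purely illustrative.

// Illustrative only: greedy pick over whatever buffer the context exposes.
static llama_token pick_token_sketch(llama_context * ctx, int32_t idx) {
    // 1) the backend sampler may already have committed to a token
    const llama_token sampled = llama_get_sampled_token_ith(ctx, idx);
    if (sampled != LLAMA_TOKEN_NULL) {
        return sampled;
    }

    // 2) otherwise the logits may be either the reduced sampled-logits buffer
    //    or the full vocabulary row; the count helper now returns n_vocab in
    //    the latter case (see the get_sampled_logits_count hunk above)
    const float       * logits = llama_get_logits_ith(ctx, idx);
    const llama_token * cand   = llama_get_sampled_candidates_ith(ctx, idx); // nullptr for the full row
    const size_t        n      = llama_get_sampled_logits_count_ith(ctx, idx);

    size_t best = 0;
    for (size_t t = 1; t < n; ++t) {
        if (logits[t] > logits[best]) {
            best = t;
        }
    }
    return cand ? cand[best] : (llama_token) best;
}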
@@ -2109,10 +2109,10 @@ void llm_graph_context::build_sampling() const {
            ggml_build_forward_expand(gf, data.probs);
        }

        if (data.logits != logits_seq) {
        if (data.logits != nullptr) {
            ggml_set_output(data.logits);
            res->t_sampled_logits[seq_id] = data.logits;
            ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
            ggml_build_forward_expand(gf, data.logits);
        }

        if (data.candidates != nullptr) {
@@ -366,23 +366,39 @@ const char * llama_sampler_name(const struct llama_sampler * smpl) {
}

void llama_sampler_accept(struct llama_sampler * smpl, llama_token token) {
    if (!smpl) {
        return;
    }

    if (smpl->iface->accept) {
        smpl->iface->accept(smpl, token);
    }
}

void llama_sampler_apply(struct llama_sampler * smpl, struct llama_token_data_array * cur_p) {
    if (!smpl) {
        return;
    }

    GGML_ASSERT(smpl->iface->apply);
    smpl->iface->apply(smpl, cur_p);
}

void llama_sampler_reset(struct llama_sampler * smpl) {
    if (!smpl) {
        return;
    }

    if (smpl->iface->reset) {
        smpl->iface->reset(smpl);
    }
}

struct llama_sampler * llama_sampler_clone(const struct llama_sampler * smpl) {
    if (!smpl) {
        return nullptr;
    }

    if (smpl->iface->clone) {
        return smpl->iface->clone(smpl);
    }
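Editor's note: llama_sampler_accept/apply/reset/clone now tolerate a null sampler, so call sites that hold an optional sampler no longer need their own guards. A tiny hedged sketch (the helper name is illustrative):

// Illustrative only: smpl may be nullptr (e.g. no grammar sampler configured).
static llama_sampler * refresh_sampler_sketch(llama_sampler * smpl) {
    llama_sampler_reset(smpl);        // no-op when smpl == nullptr
    return llama_sampler_clone(smpl); // returns nullptr when smpl == nullptr
}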
@@ -749,7 +749,7 @@ static void test_backend_dist_sampling(const char * model_path) {
    llama_token token = llama_get_sampled_token_ith(test_ctx.ctx, batch_idx);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
    GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
    GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
    //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);

    token = llama_get_sampled_token_ith(test_ctx.ctx, -1);
    printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str());
@@ -873,8 +873,8 @@ static void test_backend_mixed_sampling(const char * model_path) {
        const std::string token_str = test_ctx.token_to_piece(token, false);
        printf("sampled token id=%d, string='%s'\n", token, token_str.c_str());
        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
        GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
        GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
        //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx, batch_idx) == nullptr);
        //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx, batch_idx) == 0);
    }

    // Verfiy sequence 1 that used the top-k backend sampler.
@@ -1394,17 +1394,22 @@ json format_response_rerank(

std::vector<llama_token_data> get_token_probabilities(llama_context * ctx, int idx) {
    std::vector<llama_token_data> cur;

    const auto * logits = llama_get_logits_ith(ctx, idx);
    const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);

    const llama_model * model = llama_get_model(ctx);
    const llama_vocab * vocab = llama_model_get_vocab(model);
    const int n_logits = llama_get_sampled_logits_count_ith(ctx, idx);

    const int n_vocab = llama_vocab_n_tokens(vocab);

    cur.resize(n_vocab);
    for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
    cur.resize(n_logits);
    if (sampled_ids) {
        for (int i = 0; i < n_logits; i++) {
            cur[i] = llama_token_data{sampled_ids[i], logits[i], 0.0f};
        }
    } else {
        for (llama_token token_id = 0; token_id < n_logits; token_id++) {
            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
        }
    }

    // sort tokens by logits
    std::sort(cur.begin(), cur.end(), [](const llama_token_data & a, const llama_token_data & b) {
@@ -1003,7 +1003,12 @@ struct server_context_impl {
            return false;
        }

        // TODO: tmp until backend sampling is fully implemented
        if (task.params.sampling.backend_sampling) {
            llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
        } else {
            llama_set_sampler(ctx, slot.id, nullptr);
        }

        SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
    }
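Editor's note: the server hunk above is the per-slot toggle for backend sampling. A hedged sketch of the same pattern in isolation; llama_set_sampler and the backend_sampling flag are specific to this work-in-progress branch, and the helper name is illustrative.

// Illustrative only: install the slot's chain on the backend, or clear it to
// fall back to CPU-side sampling for that sequence.
static void configure_slot_sampler_sketch(llama_context * ctx, llama_seq_id seq_id,
                                          common_sampler * smpl, bool backend_sampling) {
    if (backend_sampling) {
        llama_set_sampler(ctx, seq_id, common_sampler_get(smpl));
    } else {
        llama_set_sampler(ctx, seq_id, nullptr);
    }
}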
@@ -13,16 +13,16 @@ def create_server():
@pytest.mark.parametrize(
    "model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason,jinja,chat_template",
    [
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter", 79, 8, "length", True, None),
        (None, "Book", "Hey", 8, "But she couldn't|Some of her", 69, 8, "length", False, None),
        (None, "Book", "Hey", 8, "But she couldn't|Some of her", 69, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", False, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", True, None),
        (None, "Book", "What is the best book", 8, "(Suddenly)+|\\{ \" Sarax.|Timmy", 77, 8, "length", True, 'chatml'),
        (None, "Book", "What is the best book", 8, "^ blue|very teaful", 23, 8, "length", True, "This is not a chat template, it is"),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length", False, None),
        ("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length", True, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter|Some", 79, 8, "length", False, None),
        (None, "Book", [{"type": "text", "text": "What is"}, {"type": "text", "text": "the best book"}], 8, "Whillicter|Some", 79, 8, "length", True, None),
    ]
)
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason, jinja, chat_template):
@@ -54,8 +54,8 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
@pytest.mark.parametrize(
    "system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
    [
        ("Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 128, "length"),
        ("Book", "What is the best book", 8, "(Suddenly)+|Timmy", 77, 8, "length"),
        ("You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger|shake)+", 104, 128, "length"),
    ]
)
def test_chat_completion_stream(system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
@@ -115,7 +115,7 @@ def test_chat_completion_with_openai_library():
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].message.content is not None
    assert match_regex("(Suddenly)+", res.choices[0].message.content)
    assert match_regex("(Suddenly)+|Timmy", res.choices[0].message.content)


def test_chat_template():
@@ -301,7 +301,7 @@ def test_logprobs():
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        temperature=1.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
@@ -328,7 +328,7 @@ def test_logprobs_stream():
    client = OpenAI(api_key="dummy", base_url=f"http://{server.server_host}:{server.server_port}/v1")
    res = client.chat.completions.create(
        model="gpt-3.5-turbo-instruct",
        temperature=0.0,
        temperature=1.0,
        messages=[
            {"role": "system", "content": "Book"},
            {"role": "user", "content": "What is the best book"},
@@ -684,7 +684,7 @@ def test_anthropic_streaming_content_block_indices():
    # Request that might produce both text and tool use
    res = server.make_stream_request("POST", "/v1/messages", data={
        "model": "test",
        "max_tokens": 200,
        "max_tokens": 400,
        "stream": True,
        "tools": [{
            "name": "test_tool",
@@ -17,7 +17,7 @@ def create_server():
    server = ServerPreset.tinyllama2()

@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated,return_tokens", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False, False),
    ("I believe the meaning of life is", 8, "(going|bed)+|froze and every", 18, 8, False, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False, True),
])
def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool, return_tokens: bool):
@@ -42,7 +42,7 @@ def test_completion(prompt: str, n_predict: int, re_content: str, n_prompt: int,


@pytest.mark.parametrize("prompt,n_predict,re_content,n_prompt,n_predicted,truncated", [
    ("I believe the meaning of life is", 8, "(going|bed)+", 18, 8, False),
    ("I believe the meaning of life is", 8, "(going|bed)+|froze and every", 18, 8, False),
    ("Write a joke about AI from a very long prompt which will not be truncated", 64, "(princesses|everyone|kids|Anna|forest)+", 46, 64, False),
])
def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_prompt: int, n_predicted: int, truncated: bool):
@@ -103,7 +103,7 @@ def test_completion_with_openai_library():
    assert res.system_fingerprint is not None and res.system_fingerprint.startswith("b")
    assert res.choices[0].finish_reason == "length"
    assert res.choices[0].text is not None
    assert match_regex("(going|bed)+", res.choices[0].text)
    assert match_regex("(going|bed)+|froze and every", res.choices[0].text)


def test_completion_stream_with_openai_library():
@@ -122,7 +122,7 @@ def test_completion_stream_with_openai_library():
        if choice.finish_reason is None:
            assert choice.text is not None
            output_text += choice.text
    assert match_regex("(going|bed)+", output_text)
    assert match_regex("(going|bed)+|froze and every", output_text)


# Test case from https://github.com/ggml-org/llama.cpp/issues/13780
@@ -146,7 +146,7 @@ def test_completion_stream_with_openai_library_stops():
        if choice.finish_reason is None:
            assert choice.text is not None
            output_text += choice.text
    assert match_regex("Sure, here's one for[\\s\\S]*", output_text), f'Unexpected output: {output_text}'
    assert match_regex("Sure, here's one for[\\s\\S]*|Sure thing..Why don't", output_text), f'Unexpected output: {output_text}'


@pytest.mark.parametrize("n_slots", [1, 2])
@@ -441,7 +441,7 @@ def test_n_probs():
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "temperature": 1.0,
        "n_predict": 5,
    })
    assert res.status_code == 200
@@ -466,7 +466,7 @@ def test_n_probs_stream():
    res = server.make_stream_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "temperature": 1.0,
        "n_predict": 5,
        "stream": True,
    })
@@ -487,32 +487,33 @@ def test_n_probs_stream():
                assert "bytes" in prob and type(prob["bytes"]) == list


def test_n_probs_post_sampling():
    global server
    server.start()
    res = server.make_request("POST", "/completion", data={
        "prompt": "I believe the meaning of life is",
        "n_probs": 10,
        "temperature": 0.0,
        "n_predict": 5,
        "post_sampling_probs": True,
    })
    assert res.status_code == 200
    assert "completion_probabilities" in res.body
    assert len(res.body["completion_probabilities"]) == 5
    for tok in res.body["completion_probabilities"]:
        assert "id" in tok and tok["id"] > 0
        assert "token" in tok and type(tok["token"]) == str
        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
        assert "bytes" in tok and type(tok["bytes"]) == list
        assert len(tok["top_probs"]) == 10
        for prob in tok["top_probs"]:
            assert "id" in prob and prob["id"] > 0
            assert "token" in prob and type(prob["token"]) == str
            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
            assert "bytes" in prob and type(prob["bytes"]) == list
        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
# TODO: fix
#def test_n_probs_post_sampling():
#    global server
#    server.start()
#    res = server.make_request("POST", "/completion", data={
#        "prompt": "I believe the meaning of life is",
#        "n_probs": 10,
#        "temperature": 1.0,
#        "n_predict": 5,
#        "post_sampling_probs": True,
#    })
#    assert res.status_code == 200
#    assert "completion_probabilities" in res.body
#    assert len(res.body["completion_probabilities"]) == 5
#    for tok in res.body["completion_probabilities"]:
#        assert "id" in tok and tok["id"] > 0
#        assert "token" in tok and type(tok["token"]) == str
#        assert "prob" in tok and 0.0 < tok["prob"] <= 1.0
#        assert "bytes" in tok and type(tok["bytes"]) == list
#        assert len(tok["top_probs"]) == 10
#        for prob in tok["top_probs"]:
#            assert "id" in prob and prob["id"] > 0
#            assert "token" in prob and type(prob["token"]) == str
#            assert "prob" in prob and 0.0 <= prob["prob"] <= 1.0
#            assert "bytes" in prob and type(prob["bytes"]) == list
#        # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
#        assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])


@pytest.mark.parametrize("tokenize,openai_style", [(False, False), (False, True), (True, False), (True, True)])