sampling : check backend support during init
This commit is contained in:
parent
1bde70785d
commit
6958d41366
|
|
@ -1098,8 +1098,7 @@ common_init_result::common_init_result(common_params & params) :
|
||||||
|
|
||||||
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
|
||||||
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
|
pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
|
||||||
llama_sampler * backend_chain = common_sampler_chain_backend(pimpl->samplers[i].get());
|
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
|
||||||
pimpl->samplers_seq_config[i] = { i, backend_chain };
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cparams.samplers = pimpl->samplers_seq_config.data();
|
cparams.samplers = pimpl->samplers_seq_config.data();
|
||||||
|
|
|
||||||
|
|
@ -106,7 +106,6 @@ struct common_sampler {
|
||||||
|
|
||||||
struct llama_sampler * grmr;
|
struct llama_sampler * grmr;
|
||||||
struct llama_sampler * chain;
|
struct llama_sampler * chain;
|
||||||
struct llama_sampler * chain_backend;
|
|
||||||
|
|
||||||
ring_buffer<llama_token> prev;
|
ring_buffer<llama_token> prev;
|
||||||
|
|
||||||
|
|
@ -119,7 +118,6 @@ struct common_sampler {
|
||||||
|
|
||||||
llama_sampler_reset(grmr);
|
llama_sampler_reset(grmr);
|
||||||
llama_sampler_reset(chain);
|
llama_sampler_reset(chain);
|
||||||
llama_sampler_reset(chain_backend);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void set_logits(struct llama_context * ctx, int idx) {
|
void set_logits(struct llama_context * ctx, int idx) {
|
||||||
|
|
@ -247,13 +245,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * result = new common_sampler {
|
auto * result = new common_sampler {
|
||||||
/* .params = */ params,
|
/* .params = */ params,
|
||||||
/* .grmr = */ grmr,
|
/* .grmr = */ grmr,
|
||||||
/* .chain = */ llama_sampler_chain_init(lparams),
|
/* .chain = */ llama_sampler_chain_init(lparams),
|
||||||
/* .chain_backend = */ llama_sampler_chain_init(lparams),
|
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
||||||
/* .prev = */ ring_buffer<llama_token>(std::max(32, params.n_prev)),
|
/* .cur = */ {},
|
||||||
/* .cur = */ {},
|
/* .cur_p = */ {},
|
||||||
/* .cur_p = */ {},
|
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<llama_sampler *> samplers;
|
std::vector<llama_sampler *> samplers;
|
||||||
|
|
@ -318,15 +315,8 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co
|
||||||
GGML_ASSERT(false && "unknown mirostat version");
|
GGML_ASSERT(false && "unknown mirostat version");
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_backend = params.backend_sampling;
|
|
||||||
|
|
||||||
// split in two chains: backend -> CPU
|
|
||||||
for (auto * smpl : samplers) {
|
for (auto * smpl : samplers) {
|
||||||
if (!smpl->iface->backend_apply) {
|
llama_sampler_chain_add(result->chain, smpl);
|
||||||
is_backend = false;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_sampler_chain_add(is_backend ? result->chain_backend : result->chain, smpl);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -336,7 +326,6 @@ void common_sampler_free(struct common_sampler * gsmpl) {
|
||||||
if (gsmpl) {
|
if (gsmpl) {
|
||||||
llama_sampler_free(gsmpl->grmr);
|
llama_sampler_free(gsmpl->grmr);
|
||||||
llama_sampler_free(gsmpl->chain);
|
llama_sampler_free(gsmpl->chain);
|
||||||
llama_sampler_free(gsmpl->chain_backend);
|
|
||||||
|
|
||||||
delete gsmpl;
|
delete gsmpl;
|
||||||
}
|
}
|
||||||
|
|
@ -360,13 +349,12 @@ void common_sampler_reset(struct common_sampler * gsmpl) {
|
||||||
|
|
||||||
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
struct common_sampler * common_sampler_clone(common_sampler * gsmpl) {
|
||||||
return new common_sampler {
|
return new common_sampler {
|
||||||
/* .params = */ gsmpl->params,
|
/* .params = */ gsmpl->params,
|
||||||
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
/* .grmr = */ llama_sampler_clone(gsmpl->grmr),
|
||||||
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
/* .chain = */ llama_sampler_clone(gsmpl->chain),
|
||||||
/* .chain_backend = */ llama_sampler_clone(gsmpl->chain_backend),
|
/* .prev = */ gsmpl->prev,
|
||||||
/* .prev = */ gsmpl->prev,
|
/* .cur = */ gsmpl->cur,
|
||||||
/* .cur = */ gsmpl->cur,
|
/* .cur_p = */ gsmpl->cur_p,
|
||||||
/* .cur_p = */ gsmpl->cur_p,
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -415,8 +403,8 @@ void common_perf_print(const struct llama_context * ctx, const struct common_sam
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl) {
|
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl) {
|
||||||
return gsmpl->chain_backend;
|
return gsmpl->chain;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_context * ctx, int idx, bool grammar_first) {
|
||||||
|
|
@ -424,11 +412,13 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
|
||||||
// return that token id directly.
|
// return that token id directly.
|
||||||
{
|
{
|
||||||
const llama_token id = llama_get_sampled_token_ith(ctx, idx);
|
const llama_token id = llama_get_sampled_token_ith(ctx, idx);
|
||||||
|
|
||||||
if (id != LLAMA_TOKEN_NULL) {
|
if (id != LLAMA_TOKEN_NULL) {
|
||||||
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
|
LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
|
||||||
return id;
|
return id;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_synchronize(ctx);
|
llama_synchronize(ctx);
|
||||||
|
|
||||||
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
// start measuring sampling time after the llama_context synchronization in order to not measure any ongoing async operations
|
||||||
|
|
@ -556,16 +546,12 @@ llama_token common_sampler_last(const struct common_sampler * gsmpl) {
|
||||||
}
|
}
|
||||||
|
|
||||||
std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
std::string common_sampler_print(const struct common_sampler * gsmpl) {
|
||||||
std::string result = llama_sampler_chain_n(gsmpl->chain_backend) > 0 ? "*logits " : "logits ";
|
std::string result = "logits ";
|
||||||
|
|
||||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain_backend); i++) {
|
|
||||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain_backend, i);
|
|
||||||
result += std::string("-> *") + llama_sampler_name(smpl) + " ";
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
for (int i = 0; i < llama_sampler_chain_n(gsmpl->chain); i++) {
|
||||||
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
const auto * smpl = llama_sampler_chain_get(gsmpl->chain, i);
|
||||||
result += std::string("-> ") + llama_sampler_name(smpl) + " ";
|
result += std::string("-> ");
|
||||||
|
result += std::string(llama_sampler_name(smpl)) + " ";
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
|
||||||
|
|
@ -48,7 +48,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
|
||||||
// arguments can be nullptr to skip printing
|
// arguments can be nullptr to skip printing
|
||||||
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
struct llama_sampler * common_sampler_chain_backend(const struct common_sampler * gsmpl);
|
struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
|
||||||
|
|
||||||
// extended sampling implementation:
|
// extended sampling implementation:
|
||||||
//
|
//
|
||||||
|
|
|
||||||
|
|
@ -369,7 +369,8 @@ extern "C" {
|
||||||
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
// try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix
|
||||||
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
// ref: https://github.com/ggml-org/llama.cpp/pull/14363
|
||||||
|
|
||||||
// backend sampler chain configuration (does not keep a reference, so make sure the caller keeps the samplers alive)
|
// backend sampler chain configuration (make sure the caller keeps the sampler chains alive)
|
||||||
|
// note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init)
|
||||||
struct llama_sampler_seq_config * samplers;
|
struct llama_sampler_seq_config * samplers;
|
||||||
size_t n_samplers;
|
size_t n_samplers;
|
||||||
};
|
};
|
||||||
|
|
@ -1193,21 +1194,27 @@ extern "C" {
|
||||||
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
||||||
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL
|
||||||
|
|
||||||
// backend sampling interface
|
// backend sampling interface:
|
||||||
void (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
|
|
||||||
|
|
||||||
|
// return true if the backend supports all ops needed by the sampler
|
||||||
|
// note: call once per sampler
|
||||||
|
bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft);
|
||||||
|
|
||||||
|
// call after .backend_apply()
|
||||||
void (*backend_accept)(
|
void (*backend_accept)(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_cgraph * gf,
|
struct ggml_cgraph * gf,
|
||||||
struct ggml_tensor * selected_token);
|
struct ggml_tensor * selected_token);
|
||||||
|
|
||||||
|
// call after .backend_init()
|
||||||
void (*backend_apply)(
|
void (*backend_apply)(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_cgraph * gf,
|
struct ggml_cgraph * gf,
|
||||||
struct llama_sampler_data * data);
|
struct llama_sampler_data * data);
|
||||||
|
|
||||||
|
// call before .backend_apply()
|
||||||
void (*backend_set_input)(struct llama_sampler * smpl);
|
void (*backend_set_input)(struct llama_sampler * smpl);
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,8 @@ llama_context::llama_context(
|
||||||
for (size_t i = 0; i < params.n_samplers; ++i) {
|
for (size_t i = 0; i < params.n_samplers; ++i) {
|
||||||
const auto & config = params.samplers[i];
|
const auto & config = params.samplers[i];
|
||||||
|
|
||||||
|
// TODO: assert this is a llama_sampler_chain instance
|
||||||
|
|
||||||
if (set_sampler(config.seq_id, config.sampler)) {
|
if (set_sampler(config.seq_id, config.sampler)) {
|
||||||
const int n_samplers = llama_sampler_chain_n(config.sampler);
|
const int n_samplers = llama_sampler_chain_n(config.sampler);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -504,11 +504,13 @@ static void llama_sampler_empty_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_empty *) smpl->ctx;
|
delete (llama_sampler_empty *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_empty_backend_init(
|
static bool llama_sampler_empty_backend_init(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
ggml_backend_buffer_type_t buft) {
|
ggml_backend_buffer_type_t buft) {
|
||||||
GGML_UNUSED(smpl);
|
GGML_UNUSED(smpl);
|
||||||
GGML_UNUSED(buft);
|
GGML_UNUSED(buft);
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_empty_backend_accept(
|
static void llama_sampler_empty_backend_accept(
|
||||||
|
|
@ -559,6 +561,43 @@ struct llama_sampler * llama_sampler_init_empty(const char * name) {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// common backend sampler functionality
|
||||||
|
//
|
||||||
|
// +name : means that the sampler is supported and will run on the backend
|
||||||
|
// -name : means that a ggml operator is not supported by the backend
|
||||||
|
//
|
||||||
|
struct llama_sampler_backend {
|
||||||
|
llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {}
|
||||||
|
|
||||||
|
const char * get_name() {
|
||||||
|
if (!is_init) {
|
||||||
|
return name.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (support) {
|
||||||
|
name_ext = "+" + name;
|
||||||
|
} else {
|
||||||
|
name_ext = "-" + name;
|
||||||
|
}
|
||||||
|
|
||||||
|
return name_ext.c_str();
|
||||||
|
}
|
||||||
|
|
||||||
|
void init(bool support) {
|
||||||
|
GGML_ASSERT(this->is_init == false);
|
||||||
|
|
||||||
|
this->is_init = true;
|
||||||
|
this->support = support;
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
std::string name;
|
||||||
|
std::string name_ext;
|
||||||
|
|
||||||
|
bool is_init;
|
||||||
|
bool support;
|
||||||
|
};
|
||||||
|
|
||||||
// sampler chain
|
// sampler chain
|
||||||
|
|
||||||
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) {
|
||||||
|
|
@ -570,8 +609,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token
|
||||||
|
|
||||||
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
||||||
|
|
||||||
for (auto * smpl : chain->samplers) {
|
for (auto & smpl : chain->samplers) {
|
||||||
llama_sampler_accept(smpl, token);
|
llama_sampler_accept(smpl.ptr, token);
|
||||||
}
|
}
|
||||||
|
|
||||||
chain->n_sample++;
|
chain->n_sample++;
|
||||||
|
|
@ -582,20 +621,28 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d
|
||||||
|
|
||||||
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
time_meas tm(chain->t_sample_us, chain->params.no_perf);
|
||||||
|
|
||||||
for (auto * smpl : chain->samplers) {
|
bool is_backend = chain->is_init;
|
||||||
if (smpl->iface->apply == nullptr) {
|
|
||||||
|
for (auto & smpl : chain->samplers) {
|
||||||
|
if (is_backend && smpl.is_backend) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_sampler_apply(smpl, cur_p);
|
is_backend = false;
|
||||||
|
|
||||||
|
if (smpl.ptr->iface->apply == nullptr) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_sampler_apply(smpl.ptr, cur_p);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
|
static void llama_sampler_chain_reset(struct llama_sampler * smpl) {
|
||||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||||
|
|
||||||
for (auto * smpl : chain->samplers) {
|
for (auto & smpl : chain->samplers) {
|
||||||
llama_sampler_reset(smpl);
|
llama_sampler_reset(smpl.ptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -604,8 +651,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
|
||||||
|
|
||||||
auto * result = llama_sampler_chain_init(chain_src->params);
|
auto * result = llama_sampler_chain_init(chain_src->params);
|
||||||
|
|
||||||
for (auto * smpl : chain_src->samplers) {
|
for (const auto & smpl : chain_src->samplers) {
|
||||||
llama_sampler_chain_add(result, llama_sampler_clone(smpl));
|
llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr));
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -614,23 +661,44 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl
|
||||||
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
static void llama_sampler_chain_free(struct llama_sampler * smpl) {
|
||||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||||
|
|
||||||
for (auto * smpl : chain->samplers) {
|
for (auto & smpl : chain->samplers) {
|
||||||
llama_sampler_free(smpl);
|
llama_sampler_free(smpl.ptr);
|
||||||
}
|
}
|
||||||
|
|
||||||
delete chain;
|
delete chain;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_chain_backend_init(
|
static bool llama_sampler_chain_backend_init(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
ggml_backend_buffer_type_t buft) {
|
ggml_backend_buffer_type_t buft) {
|
||||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||||
|
|
||||||
for (auto * smpl : chain->samplers) {
|
GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice");
|
||||||
if (smpl->iface->backend_init) {
|
|
||||||
smpl->iface->backend_init(smpl,buft);
|
chain->is_init = true;
|
||||||
|
|
||||||
|
bool res = true;
|
||||||
|
|
||||||
|
for (auto & smpl : chain->samplers) {
|
||||||
|
bool res_cur = true;
|
||||||
|
|
||||||
|
// to be able to run a sampler on the backend, it has to:
|
||||||
|
// - have the .backend_init() API implemented
|
||||||
|
// - return true during .backend_init()
|
||||||
|
if (smpl.ptr->iface->backend_init) {
|
||||||
|
if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) {
|
||||||
|
res_cur = false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
res_cur = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
smpl.is_backend = res_cur;
|
||||||
|
|
||||||
|
res = res && res_cur;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_chain_backend_accept(
|
static void llama_sampler_chain_backend_accept(
|
||||||
|
|
@ -640,9 +708,13 @@ static void llama_sampler_chain_backend_accept(
|
||||||
struct ggml_tensor * selected_token) {
|
struct ggml_tensor * selected_token) {
|
||||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||||
|
|
||||||
for (auto * smpl : chain->samplers) {
|
for (auto & smpl : chain->samplers) {
|
||||||
if (smpl->iface->backend_accept) {
|
if (!smpl.is_backend) {
|
||||||
smpl->iface->backend_accept(smpl, ctx, gf, selected_token);
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (smpl.ptr->iface->backend_accept) {
|
||||||
|
smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -654,9 +726,15 @@ static void llama_sampler_chain_backend_apply(
|
||||||
struct llama_sampler_data * data) {
|
struct llama_sampler_data * data) {
|
||||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||||
|
|
||||||
for (auto * smpl : chain->samplers) {
|
GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called");
|
||||||
if (smpl->iface->backend_apply) {
|
|
||||||
smpl->iface->backend_apply(smpl, ctx, gf, data);
|
for (auto & smpl : chain->samplers) {
|
||||||
|
if (!smpl.is_backend) {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (smpl.ptr->iface->backend_apply) {
|
||||||
|
smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -664,9 +742,13 @@ static void llama_sampler_chain_backend_apply(
|
||||||
static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
|
static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) {
|
||||||
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
auto * chain = (llama_sampler_chain *) smpl->ctx;
|
||||||
|
|
||||||
for (auto * smpl : chain->samplers) {
|
for (auto & smpl : chain->samplers) {
|
||||||
if (smpl->iface->backend_set_input) {
|
if (!smpl.is_backend) {
|
||||||
smpl->iface->backend_set_input(smpl);
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (smpl.ptr->iface->backend_set_input) {
|
||||||
|
smpl.ptr->iface->backend_set_input(smpl.ptr);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -689,6 +771,7 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
|
||||||
/* .iface = */ &llama_sampler_chain_i,
|
/* .iface = */ &llama_sampler_chain_i,
|
||||||
/* .ctx = */ new llama_sampler_chain {
|
/* .ctx = */ new llama_sampler_chain {
|
||||||
/* .params = */ params,
|
/* .params = */ params,
|
||||||
|
/* .is_init = */ false,
|
||||||
/* .samplers = */ {},
|
/* .samplers = */ {},
|
||||||
/* .t_sample_us = */ 0,
|
/* .t_sample_us = */ 0,
|
||||||
/* .n_sample = */ 0,
|
/* .n_sample = */ 0,
|
||||||
|
|
@ -698,7 +781,10 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param
|
||||||
|
|
||||||
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) {
|
||||||
auto * p = (llama_sampler_chain *) chain->ctx;
|
auto * p = (llama_sampler_chain *) chain->ctx;
|
||||||
p->samplers.push_back(smpl);
|
p->samplers.push_back({
|
||||||
|
/* .is_backend = */ false,
|
||||||
|
/* .ptr = */ smpl,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) {
|
||||||
|
|
@ -708,7 +794,7 @@ struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chai
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
return p->samplers[i];
|
return p->samplers[i].ptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
|
struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) {
|
||||||
|
|
@ -718,7 +804,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain,
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * result = p->samplers[i];
|
auto * result = p->samplers[i].ptr;
|
||||||
p->samplers.erase(p->samplers.begin() + i);
|
p->samplers.erase(p->samplers.begin() + i);
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
|
|
@ -749,6 +835,15 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool llama_sampler_greedy_backend_init(
|
||||||
|
struct llama_sampler * smpl,
|
||||||
|
ggml_backend_buffer_type_t buft) {
|
||||||
|
GGML_UNUSED(smpl);
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static void llama_sampler_greedy_backend_apply(
|
static void llama_sampler_greedy_backend_apply(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
|
|
@ -768,7 +863,7 @@ static struct llama_sampler_i llama_sampler_greedy_i = {
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ nullptr,
|
/* .clone = */ nullptr,
|
||||||
/* .free = */ nullptr,
|
/* .free = */ nullptr,
|
||||||
/* .backend_init = */ nullptr,
|
/* .backend_init = */ llama_sampler_greedy_backend_init,
|
||||||
/* .backend_accept = */ nullptr,
|
/* .backend_accept = */ nullptr,
|
||||||
/* .backend_apply = */ llama_sampler_greedy_backend_apply,
|
/* .backend_apply = */ llama_sampler_greedy_backend_apply,
|
||||||
/* .backend_set_input = */ nullptr,
|
/* .backend_set_input = */ nullptr,
|
||||||
|
|
@ -783,23 +878,22 @@ struct llama_sampler * llama_sampler_init_greedy() {
|
||||||
|
|
||||||
// dist
|
// dist
|
||||||
|
|
||||||
struct llama_sampler_dist {
|
struct llama_sampler_dist : public llama_sampler_backend {
|
||||||
const uint32_t seed;
|
const uint32_t seed;
|
||||||
uint32_t seed_cur;
|
uint32_t seed_cur;
|
||||||
|
|
||||||
std::mt19937 rng;
|
std::mt19937 rng;
|
||||||
|
|
||||||
// Only required for checking operation support and can be removed later.
|
// backend input
|
||||||
ggml_backend_dev_t device;
|
|
||||||
|
|
||||||
struct ggml_tensor * inp_uniform;
|
struct ggml_tensor * inp_uniform;
|
||||||
|
|
||||||
ggml_context_ptr inp_ctx;
|
ggml_context_ptr inp_ctx;
|
||||||
ggml_backend_buffer_ptr inp_buf;
|
ggml_backend_buffer_ptr inp_buf;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) {
|
||||||
return "dist";
|
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
||||||
|
return sctx->get_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
|
@ -898,13 +992,61 @@ static void llama_sampler_dist_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_dist *) smpl->ctx;
|
delete (llama_sampler_dist *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
|
static bool llama_sampler_dist_backend_init(
|
||||||
|
struct llama_sampler * smpl,
|
||||||
|
ggml_backend_buffer_type_t buft) {
|
||||||
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
||||||
GGML_ASSERT(sctx->inp_uniform != nullptr);
|
|
||||||
|
|
||||||
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
|
bool res = true;
|
||||||
const float rnd = dist(sctx->rng);
|
|
||||||
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
|
// determine backend support
|
||||||
|
{
|
||||||
|
ggml_init_params params = {
|
||||||
|
/*.mem_size =*/ ggml_tensor_overhead()*8,
|
||||||
|
/*.mem_buffer =*/ NULL,
|
||||||
|
/*.no_alloc =*/ true,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||||
|
if (!ctx_ptr) {
|
||||||
|
throw std::runtime_error(format("failed to create ggml context"));
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_context * ctx = ctx_ptr.get();
|
||||||
|
|
||||||
|
ggml_tensor * probs = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
|
||||||
|
ggml_tensor * op = ggml_cumsum(ctx, probs);
|
||||||
|
|
||||||
|
auto * device = ggml_backend_buft_get_device(buft);
|
||||||
|
GGML_ASSERT(device);
|
||||||
|
|
||||||
|
if (!ggml_backend_dev_supports_op(device, op)) {
|
||||||
|
res = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
sctx->init(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (res) {
|
||||||
|
ggml_init_params params = {
|
||||||
|
/*.mem_size =*/ ggml_tensor_overhead(),
|
||||||
|
/*.mem_buffer =*/ nullptr,
|
||||||
|
/*.no_alloc =*/ true,
|
||||||
|
};
|
||||||
|
|
||||||
|
sctx->inp_ctx.reset(ggml_init(params));
|
||||||
|
|
||||||
|
// Create the uniform random scalar input tensor. This will be set by
|
||||||
|
// llama_sampler_dist_backend_set_input after this graph is built.
|
||||||
|
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
|
||||||
|
ggml_set_name(sctx->inp_uniform, "uniform");
|
||||||
|
ggml_set_input(sctx->inp_uniform);
|
||||||
|
|
||||||
|
// Allocate all tensors from our context to the backend
|
||||||
|
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_dist_backend_apply(
|
static void llama_sampler_dist_backend_apply(
|
||||||
|
|
@ -919,10 +1061,6 @@ static void llama_sampler_dist_backend_apply(
|
||||||
ggml_set_name(probs, "dist_probs");
|
ggml_set_name(probs, "dist_probs");
|
||||||
|
|
||||||
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
|
struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs);
|
||||||
if (sctx->device && !ggml_backend_dev_supports_op(sctx->device, cumsum)) {
|
|
||||||
fprintf(stderr, "Warning: backend does not support cumsum operation required for dist sampling\n");
|
|
||||||
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
|
|
||||||
}
|
|
||||||
ggml_set_name(cumsum, "cumsum");
|
ggml_set_name(cumsum, "cumsum");
|
||||||
|
|
||||||
// The uniform tensor has a random value and we subtract this tensor with
|
// The uniform tensor has a random value and we subtract this tensor with
|
||||||
|
|
@ -965,29 +1103,13 @@ static void llama_sampler_dist_backend_apply(
|
||||||
data->sampled = sampled_token;
|
data->sampled = sampled_token;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_dist_backend_init(
|
static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) {
|
||||||
struct llama_sampler * smpl,
|
|
||||||
ggml_backend_buffer_type_t buft) {
|
|
||||||
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
auto * sctx = (llama_sampler_dist *) smpl->ctx;
|
||||||
|
GGML_ASSERT(sctx->inp_uniform != nullptr);
|
||||||
|
|
||||||
sctx->device = ggml_backend_buft_get_device(buft);
|
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
|
||||||
|
const float rnd = dist(sctx->rng);
|
||||||
ggml_init_params params = {
|
ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float));
|
||||||
/*.mem_size =*/ ggml_tensor_overhead(),
|
|
||||||
/*.mem_buffer =*/ nullptr,
|
|
||||||
/*.no_alloc =*/ true,
|
|
||||||
};
|
|
||||||
|
|
||||||
sctx->inp_ctx.reset(ggml_init(params));
|
|
||||||
|
|
||||||
// Create the uniform random scalar input tensor. This will be set by
|
|
||||||
// llama_sampler_dist_backend_set_input after this graph is built.
|
|
||||||
sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1);
|
|
||||||
ggml_set_name(sctx->inp_uniform, "uniform");
|
|
||||||
ggml_set_input(sctx->inp_uniform);
|
|
||||||
|
|
||||||
// Allocate all tensors from our context to the backend
|
|
||||||
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_dist_i = {
|
static struct llama_sampler_i llama_sampler_dist_i = {
|
||||||
|
|
@ -1008,10 +1130,10 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_dist_i,
|
/* .iface = */ &llama_sampler_dist_i,
|
||||||
/* .ctx = */ new llama_sampler_dist {
|
/* .ctx = */ new llama_sampler_dist {
|
||||||
|
("dist"),
|
||||||
/* .seed = */ seed,
|
/* .seed = */ seed,
|
||||||
/* .seed_cur = */ seed_cur,
|
/* .seed_cur = */ seed_cur,
|
||||||
/* .rng = */ std::mt19937(seed_cur),
|
/* .rng = */ std::mt19937(seed_cur),
|
||||||
/* .device = */ nullptr,
|
|
||||||
/* .inp_uniform = */ nullptr,
|
/* .inp_uniform = */ nullptr,
|
||||||
/* .inp_ctx = */ nullptr,
|
/* .inp_ctx = */ nullptr,
|
||||||
/* .inp_buf = */ nullptr,
|
/* .inp_buf = */ nullptr,
|
||||||
|
|
@ -1021,15 +1143,13 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) {
|
||||||
|
|
||||||
// top-k
|
// top-k
|
||||||
|
|
||||||
struct llama_sampler_top_k {
|
struct llama_sampler_top_k : public llama_sampler_backend {
|
||||||
const int32_t k;
|
const int32_t k;
|
||||||
|
|
||||||
// Only required for checking operation support and can be removed later.
|
|
||||||
ggml_backend_dev_t device;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) {
|
||||||
return "top-k";
|
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
||||||
|
return sctx->get_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
|
@ -1046,11 +1166,42 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_top_k *) smpl->ctx;
|
delete (llama_sampler_top_k *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_top_k_backend_init(
|
static bool llama_sampler_top_k_backend_init(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
ggml_backend_buffer_type_t buft) {
|
ggml_backend_buffer_type_t buft) {
|
||||||
auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
|
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
||||||
ctx_data->device = ggml_backend_buft_get_device(buft);
|
|
||||||
|
bool res = true;
|
||||||
|
|
||||||
|
// determine backend support
|
||||||
|
{
|
||||||
|
ggml_init_params params = {
|
||||||
|
/*.mem_size =*/ ggml_tensor_overhead()*8,
|
||||||
|
/*.mem_buffer =*/ NULL,
|
||||||
|
/*.no_alloc =*/ true,
|
||||||
|
};
|
||||||
|
|
||||||
|
ggml_context_ptr ctx_ptr { ggml_init(params) };
|
||||||
|
if (!ctx_ptr) {
|
||||||
|
throw std::runtime_error(format("failed to create ggml context"));
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_context * ctx = ctx_ptr.get();
|
||||||
|
|
||||||
|
ggml_tensor * logits = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1024*1024);
|
||||||
|
ggml_tensor * op = ggml_top_k(ctx, logits, sctx->k);
|
||||||
|
|
||||||
|
auto * device = ggml_backend_buft_get_device(buft);
|
||||||
|
GGML_ASSERT(device);
|
||||||
|
|
||||||
|
if (!ggml_backend_dev_supports_op(device, op)) {
|
||||||
|
res = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
sctx->init(res);
|
||||||
|
}
|
||||||
|
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_top_k_backend_apply(
|
static void llama_sampler_top_k_backend_apply(
|
||||||
|
|
@ -1058,26 +1209,17 @@ static void llama_sampler_top_k_backend_apply(
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_cgraph * gf,
|
struct ggml_cgraph * gf,
|
||||||
struct llama_sampler_data * data) {
|
struct llama_sampler_data * data) {
|
||||||
|
auto * sctx = (llama_sampler_top_k *) smpl->ctx;
|
||||||
|
|
||||||
auto * ctx_data = (llama_sampler_top_k *) smpl->ctx;
|
struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k);
|
||||||
|
|
||||||
struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, ctx_data->k);
|
|
||||||
ggml_set_name(top_k, "top_k");
|
ggml_set_name(top_k, "top_k");
|
||||||
|
|
||||||
// top_k is a view of argsort - check if backend supports the underlying argsort operation
|
|
||||||
// by checking the source tensor (which is the argsort result)
|
|
||||||
if (ctx_data->device && top_k->src[0] && !ggml_backend_dev_supports_op(ctx_data->device, top_k->src[0])) {
|
|
||||||
fprintf(stderr, "Warning: backend does not support argsort operation required for top-k sampling\n");
|
|
||||||
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
|
|
||||||
}
|
|
||||||
|
|
||||||
data->candidates = top_k;
|
|
||||||
|
|
||||||
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]);
|
||||||
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
||||||
ggml_set_name(top_k_rows, "top_k_rows");
|
ggml_set_name(top_k_rows, "top_k_rows");
|
||||||
|
|
||||||
data->logits = ggml_reshape_1d(ctx, top_k_rows, ctx_data->k);
|
data->candidates = top_k;
|
||||||
|
data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k);
|
||||||
|
|
||||||
GGML_UNUSED(gf);
|
GGML_UNUSED(gf);
|
||||||
}
|
}
|
||||||
|
|
@ -1099,29 +1241,30 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) {
|
||||||
const bool is_empty = (k <= 0);
|
const bool is_empty = (k <= 0);
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("top-k?");
|
return llama_sampler_init_empty("?top-k");
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_top_k_i,
|
/* .iface = */ &llama_sampler_top_k_i,
|
||||||
/* .ctx = */ new llama_sampler_top_k {
|
/* .ctx = */ new llama_sampler_top_k {
|
||||||
/* .k = */ k,
|
("top-k"),
|
||||||
/* .device = */ nullptr,
|
/* .k = */ k,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// top-p
|
// top-p
|
||||||
|
|
||||||
struct llama_sampler_top_p {
|
struct llama_sampler_top_p : public llama_sampler_backend {
|
||||||
const float p;
|
const float p;
|
||||||
const size_t min_keep;
|
const size_t min_keep;
|
||||||
|
|
||||||
std::vector<llama_token_data> buf_sort;
|
std::vector<llama_token_data> buf_sort;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) {
|
||||||
return "top-p";
|
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
||||||
|
return sctx->get_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
|
@ -1188,11 +1331,15 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_top_p *) smpl->ctx;
|
delete (llama_sampler_top_p *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_top_p_backend_init(
|
static bool llama_sampler_top_p_backend_init(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
ggml_backend_buffer_type_t buft) {
|
ggml_backend_buffer_type_t buft) {
|
||||||
GGML_UNUSED(smpl);
|
|
||||||
GGML_UNUSED(buft);
|
GGML_UNUSED(buft);
|
||||||
|
|
||||||
|
auto * sctx = (llama_sampler_top_p *) smpl->ctx;
|
||||||
|
sctx->init(true);
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_top_p_backend_apply(
|
static void llama_sampler_top_p_backend_apply(
|
||||||
|
|
@ -1287,12 +1434,13 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
||||||
const bool is_empty = p >= 1.0f;
|
const bool is_empty = p >= 1.0f;
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("top-p?");
|
return llama_sampler_init_empty("?top-p");
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_top_p_i,
|
/* .iface = */ &llama_sampler_top_p_i,
|
||||||
/* .ctx = */ new llama_sampler_top_p {
|
/* .ctx = */ new llama_sampler_top_p {
|
||||||
|
("top-p"),
|
||||||
/* .p = */ p,
|
/* .p = */ p,
|
||||||
/* .min_keep = */ min_keep,
|
/* .min_keep = */ min_keep,
|
||||||
/* .buf_sort = */ {},
|
/* .buf_sort = */ {},
|
||||||
|
|
@ -1302,13 +1450,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) {
|
||||||
|
|
||||||
// min-p
|
// min-p
|
||||||
|
|
||||||
struct llama_sampler_min_p {
|
struct llama_sampler_min_p : public llama_sampler_backend {
|
||||||
const float p;
|
const float p;
|
||||||
const size_t min_keep;
|
const size_t min_keep;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) {
|
||||||
return "min-p";
|
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
|
||||||
|
return sctx->get_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
|
@ -1374,11 +1523,16 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_min_p *) smpl->ctx;
|
delete (llama_sampler_min_p *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_min_p_backend_init(
|
static bool llama_sampler_min_p_backend_init(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
ggml_backend_buffer_type_t buft) {
|
ggml_backend_buffer_type_t buft) {
|
||||||
GGML_UNUSED(smpl);
|
|
||||||
GGML_UNUSED(buft);
|
GGML_UNUSED(buft);
|
||||||
|
|
||||||
|
auto * sctx = (llama_sampler_min_p *) smpl->ctx;
|
||||||
|
|
||||||
|
sctx->init(true);
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_min_p_backend_apply(
|
static void llama_sampler_min_p_backend_apply(
|
||||||
|
|
@ -1441,12 +1595,13 @@ struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) {
|
||||||
const bool is_empty = (p <= 0.0f);
|
const bool is_empty = (p <= 0.0f);
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("min-p?");
|
return llama_sampler_init_empty("?min-p");
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_min_p_i,
|
/* .iface = */ &llama_sampler_min_p_i,
|
||||||
/* .ctx = */ new llama_sampler_min_p {
|
/* .ctx = */ new llama_sampler_min_p {
|
||||||
|
("min-p"),
|
||||||
/* .p = */ p,
|
/* .p = */ p,
|
||||||
/* .min_keep = */ min_keep,
|
/* .min_keep = */ min_keep,
|
||||||
}
|
}
|
||||||
|
|
@ -1550,7 +1705,7 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
||||||
const bool is_empty = (p >= 1.0f);
|
const bool is_empty = (p >= 1.0f);
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("typical?");
|
return llama_sampler_init_empty("?typical");
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
|
|
@ -1564,12 +1719,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) {
|
||||||
|
|
||||||
// temp
|
// temp
|
||||||
|
|
||||||
struct llama_sampler_temp {
|
struct llama_sampler_temp : public llama_sampler_backend {
|
||||||
const float temp;
|
const float temp;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) {
|
||||||
return "temp";
|
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
||||||
|
return sctx->get_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
|
@ -1616,13 +1772,25 @@ static void llama_sampler_backend_temp_sampling(
|
||||||
GGML_UNUSED(gf);
|
GGML_UNUSED(gf);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool llama_sampler_temp_backend_init(
|
||||||
|
struct llama_sampler * smpl,
|
||||||
|
ggml_backend_buffer_type_t buft) {
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
|
||||||
|
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
||||||
|
|
||||||
|
sctx->init(true);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static void llama_sampler_temp_backend_apply(
|
static void llama_sampler_temp_backend_apply(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_cgraph * gf,
|
struct ggml_cgraph * gf,
|
||||||
struct llama_sampler_data * data) {
|
struct llama_sampler_data * data) {
|
||||||
auto * ctx_data = (llama_sampler_temp *) smpl->ctx;
|
auto * sctx = (llama_sampler_temp *) smpl->ctx;
|
||||||
llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
|
llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_temp_i = {
|
static struct llama_sampler_i llama_sampler_temp_i = {
|
||||||
|
|
@ -1632,7 +1800,7 @@ static struct llama_sampler_i llama_sampler_temp_i = {
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_temp_clone,
|
/* .clone = */ llama_sampler_temp_clone,
|
||||||
/* .free = */ llama_sampler_temp_free,
|
/* .free = */ llama_sampler_temp_free,
|
||||||
/* .backend_init = */ nullptr,
|
/* .backend_init = */ llama_sampler_temp_backend_init,
|
||||||
/* .backend_accept = */ nullptr,
|
/* .backend_accept = */ nullptr,
|
||||||
/* .backend_apply = */ llama_sampler_temp_backend_apply,
|
/* .backend_apply = */ llama_sampler_temp_backend_apply,
|
||||||
/* .backend_set_input = */ nullptr,
|
/* .backend_set_input = */ nullptr,
|
||||||
|
|
@ -1642,12 +1810,13 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
|
||||||
const bool is_empty = temp == 1.0f;
|
const bool is_empty = temp == 1.0f;
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("temp?");
|
return llama_sampler_init_empty("?temp");
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_temp_i,
|
/* .iface = */ &llama_sampler_temp_i,
|
||||||
/* .ctx = */ new llama_sampler_temp {
|
/* .ctx = */ new llama_sampler_temp {
|
||||||
|
("temp"),
|
||||||
/*.temp = */ temp,
|
/*.temp = */ temp,
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
@ -1655,14 +1824,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) {
|
||||||
|
|
||||||
// temp-ext
|
// temp-ext
|
||||||
|
|
||||||
struct llama_sampler_temp_ext {
|
struct llama_sampler_temp_ext : public llama_sampler_backend {
|
||||||
const float temp;
|
const float temp;
|
||||||
const float delta;
|
const float delta;
|
||||||
const float exponent;
|
const float exponent;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) {
|
||||||
return "temp-ext";
|
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
||||||
|
return sctx->get_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
|
@ -1745,22 +1915,34 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) {
|
||||||
delete (llama_sampler_temp_ext *) smpl->ctx;
|
delete (llama_sampler_temp_ext *) smpl->ctx;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static bool llama_sampler_temp_ext_backend_init(
|
||||||
|
struct llama_sampler * smpl,
|
||||||
|
ggml_backend_buffer_type_t buft) {
|
||||||
|
GGML_UNUSED(buft);
|
||||||
|
|
||||||
|
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
||||||
|
|
||||||
|
sctx->init(true);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
static void llama_sampler_temp_ext_backend_apply(
|
static void llama_sampler_temp_ext_backend_apply(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
struct ggml_context * ctx,
|
struct ggml_context * ctx,
|
||||||
struct ggml_cgraph * gf,
|
struct ggml_cgraph * gf,
|
||||||
struct llama_sampler_data * data) {
|
struct llama_sampler_data * data) {
|
||||||
auto * ctx_data = (llama_sampler_temp_ext *) smpl->ctx;
|
auto * sctx = (llama_sampler_temp_ext *) smpl->ctx;
|
||||||
|
|
||||||
// Revert to standard temperature scaling if delta or temp are non-positive.
|
// Revert to standard temperature scaling if delta or temp are non-positive.
|
||||||
if (ctx_data->delta <= 0.0f || ctx_data->temp <= 0.0f) {
|
if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) {
|
||||||
llama_sampler_backend_temp_sampling(ctx, gf, data, ctx_data->temp);
|
llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate min_temp, max_temp, and max_entropy.
|
// Calculate min_temp, max_temp, and max_entropy.
|
||||||
const float min_temp = std::max(0.0f, ctx_data->temp - ctx_data->delta);
|
const float min_temp = std::max(0.0f, sctx->temp - sctx->delta);
|
||||||
const float max_temp = ctx_data->temp + ctx_data->delta;
|
const float max_temp = sctx->temp + sctx->delta;
|
||||||
const float max_entropy = logf(data->logits->ne[0]);
|
const float max_entropy = logf(data->logits->ne[0]);
|
||||||
|
|
||||||
// Calculate the probabilities.
|
// Calculate the probabilities.
|
||||||
|
|
@ -1791,7 +1973,7 @@ static void llama_sampler_temp_ext_backend_apply(
|
||||||
// Calculate powf(normalized_entropy, exponent) as
|
// Calculate powf(normalized_entropy, exponent) as
|
||||||
// norm_entropy^exponent = exp(exponent * log(norm_entropy))
|
// norm_entropy^exponent = exp(exponent * log(norm_entropy))
|
||||||
struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
|
struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy);
|
||||||
struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, ctx_data->exponent);
|
struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent);
|
||||||
struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
|
struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log);
|
||||||
// With pow_entropy computed we can now compute dyn_temp, scaling by
|
// With pow_entropy computed we can now compute dyn_temp, scaling by
|
||||||
// (max_temp - min_temp) and then adding min_temp.
|
// (max_temp - min_temp) and then adding min_temp.
|
||||||
|
|
@ -1815,7 +1997,7 @@ static struct llama_sampler_i llama_sampler_temp_ext_i = {
|
||||||
/* .reset = */ nullptr,
|
/* .reset = */ nullptr,
|
||||||
/* .clone = */ llama_sampler_temp_ext_clone,
|
/* .clone = */ llama_sampler_temp_ext_clone,
|
||||||
/* .free = */ llama_sampler_temp_ext_free,
|
/* .free = */ llama_sampler_temp_ext_free,
|
||||||
/* .backend_init = */ nullptr,
|
/* .backend_init = */ llama_sampler_temp_ext_backend_init,
|
||||||
/* .backend_accept = */ nullptr,
|
/* .backend_accept = */ nullptr,
|
||||||
/* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
|
/* .backend_apply = */ llama_sampler_temp_ext_backend_apply,
|
||||||
/* .backend_set_input = */ nullptr,
|
/* .backend_set_input = */ nullptr,
|
||||||
|
|
@ -1825,12 +2007,13 @@ struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, floa
|
||||||
const bool is_empty = temp == 1.0f && delta <= 0.0f;
|
const bool is_empty = temp == 1.0f && delta <= 0.0f;
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("temp-ext?");
|
return llama_sampler_init_empty("?temp-ext");
|
||||||
}
|
}
|
||||||
|
|
||||||
auto * res = llama_sampler_init(
|
auto * res = llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_temp_ext_i,
|
/* .iface = */ &llama_sampler_temp_ext_i,
|
||||||
/* .ctx = */ new llama_sampler_temp_ext {
|
/* .ctx = */ new llama_sampler_temp_ext {
|
||||||
|
("temp-ext"),
|
||||||
/* .temp = */ temp,
|
/* .temp = */ temp,
|
||||||
/* .delta = */ delta,
|
/* .delta = */ delta,
|
||||||
/* .exponent = */ exponent,
|
/* .exponent = */ exponent,
|
||||||
|
|
@ -1931,7 +2114,7 @@ struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep,
|
||||||
const bool is_empty = (p <= 0.0f || t > 0.5f);
|
const bool is_empty = (p <= 0.0f || t > 0.5f);
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("xtc?");
|
return llama_sampler_init_empty("?xtc");
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto seed_cur = get_rng_seed(seed);
|
const auto seed_cur = get_rng_seed(seed);
|
||||||
|
|
@ -2492,7 +2675,7 @@ struct llama_sampler * llama_sampler_init_penalties(
|
||||||
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
|
const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f));
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("penalties?");
|
return llama_sampler_init_empty("?penalties");
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
|
|
@ -2585,7 +2768,7 @@ struct llama_sampler * llama_sampler_init_top_n_sigma(float n) {
|
||||||
const bool is_empty = (n <= 0.0f);
|
const bool is_empty = (n <= 0.0f);
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("top-n-sigma?");
|
return llama_sampler_init_empty("?top-n-sigma");
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
|
|
@ -2930,7 +3113,7 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab,
|
||||||
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0);
|
||||||
|
|
||||||
if (!dry_enabled) {
|
if (!dry_enabled) {
|
||||||
return llama_sampler_init_empty("dry?");
|
return llama_sampler_init_empty("?dry");
|
||||||
}
|
}
|
||||||
|
|
||||||
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) {
|
||||||
|
|
@ -3003,7 +3186,7 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa
|
||||||
|
|
||||||
// logit-bias
|
// logit-bias
|
||||||
|
|
||||||
struct llama_sampler_logit_bias {
|
struct llama_sampler_logit_bias : public llama_sampler_backend {
|
||||||
const int32_t n_vocab;
|
const int32_t n_vocab;
|
||||||
|
|
||||||
const std::vector<llama_logit_bias> logit_bias;
|
const std::vector<llama_logit_bias> logit_bias;
|
||||||
|
|
@ -3016,8 +3199,9 @@ struct llama_sampler_logit_bias {
|
||||||
ggml_backend_buffer_ptr inp_buf;
|
ggml_backend_buffer_ptr inp_buf;
|
||||||
};
|
};
|
||||||
|
|
||||||
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) {
|
static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) {
|
||||||
return "logit-bias";
|
auto * ctx = (llama_sampler_logit_bias *) smpl->ctx;
|
||||||
|
return ctx->get_name();
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) {
|
||||||
|
|
@ -3097,13 +3281,15 @@ static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * sm
|
||||||
ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
|
ggml_backend_tensor_set(sctx->inp_logit_bias, logit_bias_sparse.data(), 0, ggml_nbytes(sctx->inp_logit_bias));
|
||||||
}
|
}
|
||||||
|
|
||||||
static void llama_sampler_logit_bias_backend_init(
|
static bool llama_sampler_logit_bias_backend_init(
|
||||||
struct llama_sampler * smpl,
|
struct llama_sampler * smpl,
|
||||||
ggml_backend_buffer_type_t buft) {
|
ggml_backend_buffer_type_t buft) {
|
||||||
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
|
auto * sctx = (llama_sampler_logit_bias *) smpl->ctx;
|
||||||
|
|
||||||
|
sctx->init(true);
|
||||||
|
|
||||||
if (sctx->logit_bias.empty()) {
|
if (sctx->logit_bias.empty()) {
|
||||||
return;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_init_params params = {
|
ggml_init_params params = {
|
||||||
|
|
@ -3120,6 +3306,8 @@ static void llama_sampler_logit_bias_backend_init(
|
||||||
|
|
||||||
// Allocate all tensors from our context to the backend
|
// Allocate all tensors from our context to the backend
|
||||||
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft));
|
||||||
|
|
||||||
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
static struct llama_sampler_i llama_sampler_logit_bias_i = {
|
||||||
|
|
@ -3142,12 +3330,13 @@ struct llama_sampler * llama_sampler_init_logit_bias(
|
||||||
const bool is_empty = n_logit_bias <= 0;
|
const bool is_empty = n_logit_bias <= 0;
|
||||||
|
|
||||||
if (is_empty) {
|
if (is_empty) {
|
||||||
return llama_sampler_init_empty("logit-bias?");
|
return llama_sampler_init_empty("?logit-bias");
|
||||||
}
|
}
|
||||||
|
|
||||||
return llama_sampler_init(
|
return llama_sampler_init(
|
||||||
/* .iface = */ &llama_sampler_logit_bias_i,
|
/* .iface = */ &llama_sampler_logit_bias_i,
|
||||||
/* .ctx = */ new llama_sampler_logit_bias {
|
/* .ctx = */ new llama_sampler_logit_bias {
|
||||||
|
("logit-bias"),
|
||||||
/* .n_vocab = */ n_vocab,
|
/* .n_vocab = */ n_vocab,
|
||||||
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
/* .logit_bias = */ std::vector<llama_logit_bias>(logit_bias, logit_bias + n_logit_bias),
|
||||||
/* .to_search = */ {},
|
/* .to_search = */ {},
|
||||||
|
|
@ -3407,7 +3596,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) {
|
||||||
if (smpl->iface == &llama_sampler_chain_i) {
|
if (smpl->iface == &llama_sampler_chain_i) {
|
||||||
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
const auto * ctx = (const llama_sampler_chain *) smpl->ctx;
|
||||||
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) {
|
||||||
const uint32_t seed = llama_sampler_get_seed(*it);
|
const uint32_t seed = llama_sampler_get_seed(it->ptr);
|
||||||
if (seed != LLAMA_DEFAULT_SEED) {
|
if (seed != LLAMA_DEFAULT_SEED) {
|
||||||
return seed;
|
return seed;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,16 @@ struct llama_grammar;
|
||||||
struct llama_sampler_chain {
|
struct llama_sampler_chain {
|
||||||
llama_sampler_chain_params params;
|
llama_sampler_chain_params params;
|
||||||
|
|
||||||
std::vector<struct llama_sampler *> samplers;
|
// has .backend_init() been called?
|
||||||
|
bool is_init = false;
|
||||||
|
|
||||||
|
struct info {
|
||||||
|
bool is_backend;
|
||||||
|
|
||||||
|
llama_sampler * ptr;
|
||||||
|
};
|
||||||
|
|
||||||
|
std::vector<info> samplers;
|
||||||
|
|
||||||
// timing
|
// timing
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1028,10 +1028,9 @@ struct server_context_impl {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get()));
|
||||||
|
|
||||||
llama_sampler * backend_chain = common_sampler_chain_backend(slot.smpl.get());
|
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
||||||
llama_set_sampler(ctx, slot.id, backend_chain);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// initialize draft batch
|
// initialize draft batch
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue