llama : reserve on sampler changes

parent 5260bb79c0
commit 0c0d0fdc30
@@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
         pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
     }
 
-    // TODO: temporarily gated behind a flag
     if (params.sampling.backend_sampling) {
         cparams.samplers   = pimpl->samplers_seq_config.data();
         cparams.n_samplers = pimpl->samplers_seq_config.size();
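Both this call site and the example below fill an array of per-sequence sampler configs and hand it to the context params. A minimal sketch of the pattern, assuming an element type of the `{ seq_id, sampler }` shape used above (the exact type name is not shown in this diff):

    // assumed: llama_sampler_seq_config is the element type; the diff only
    // shows aggregate initialization of the form { i, smpl }
    std::vector<llama_sampler_seq_config> configs;
    for (llama_seq_id i = 0; i < n_seq; ++i) {
        configs.push_back({ i, smpl });       // one sampler per sequence
    }

    if (params.sampling.backend_sampling) {   // opt-in flag for backend sampling
        ctx_params.samplers   = configs.data();
        ctx_params.n_samplers = configs.size();
    }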
@@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
        sampler_configs.push_back({ i, smpl });
    }
 
-    // TODO: temporarily gated behind a flag
    if (params.sampling.backend_sampling) {
        ctx_params.samplers   = sampler_configs.data();
        ctx_params.n_samplers = sampler_configs.size();
@@ -1255,7 +1255,6 @@ extern "C" {
     // [EXPERIMENTAL]
     // attach a sampler to the context
     // note: prefer initializing the context with llama_context_params.samplers when possible
-    // note: changing the samplers of a context can cause graph reallocations and degraded performance
     LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
 
     // mirror of llama_sampler_i:
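Given the declaration above, a hedged usage sketch; the interpretation of the bool return (false when the sampler cannot be attached) follows from the `set_sampler` implementation later in this diff:

    // prefer llama_context_params.samplers at init; attach later only if needed
    if (!llama_set_sampler(ctx, /*seq_id=*/0, smpl)) {
        // sampler could not be offloaded to the backend: sample on the CPU instead
    }

    // detach by passing nullptr; the next decode may re-reserve the graphs
    llama_set_sampler(ctx, 0, nullptr);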
@@ -340,7 +340,7 @@ llama_context::llama_context(
         LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
     }
 
-    reserve();
+    sched_reserve();
 
     if (!cparams.flash_attn) {
         if (ggml_is_quantized(params.type_v)) {
@@ -380,7 +380,13 @@ llama_context::~llama_context() {
     ggml_opt_free(opt_ctx);
 }
 
-void llama_context::reserve() {
+void llama_context::sched_reserve() {
+    if (!sched_need_reserve) {
+        return;
+    }
+
+    sched_need_reserve = false;
+
     LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
 
     synchronize();
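This is the core of the change: eager `reserve()` calls become a dirty flag that is reconciled lazily. The pattern in isolation, with illustrative names (not from the diff):

    struct lazy_sched {
        bool need_reserve = true;   // start dirty so the first use builds the graphs

        // cheap; every setter that can change the compute graph calls this
        void invalidate() { need_reserve = true; }

        // called once at the top of the hot path (encode/decode)
        void ensure_reserved() {
            if (!need_reserve) {
                return;
            }
            need_reserve = false;
            rebuild();              // the expensive worst-case graph reservation
        }

        void rebuild();
    };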
@@ -408,10 +414,8 @@ void llama_context::reserve() {
         }
     }
 
-    cross.v_embd.clear();
-
     // avoid reserving graphs with zero outputs - assume one output per sequence
-    n_outputs = n_seqs;
+    const int n_outputs = n_seqs;
 
     LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
 
@@ -983,7 +987,7 @@ void llama_context::set_embeddings(bool value) {
     cparams.embeddings = value;
 
     // TODO: not sure yet if we want to reserve here
-    //reserve();
+    //sched_need_reserve = true;
 }
 
 void llama_context::set_causal_attn(bool value) {
@@ -995,17 +999,27 @@ void llama_context::set_causal_attn(bool value) {
 
     cparams.causal_attn = value;
 
-    reserve();
+    sched_need_reserve = true;
 }
 
 void llama_context::set_warmup(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
+    if (cparams.warmup == value) {
+        return;
+    }
+
     cparams.warmup = value;
 
+    sched_need_reserve = true;
 }
 
 bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
-    LLAMA_LOG_ERROR("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
+    if (!sampler && sampling.samplers.count(seq_id) == 0) {
+        return true;
+    }
+
+    LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
+
     const bool can_offload =
         sampler &&
@@ -1024,12 +1038,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
 
         sampling.samplers[seq_id] = sampler;
 
+        sched_need_reserve = true;
+
         return true;
     }
 
     if (sampler && !can_offload) {
         LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
 
+        if (sampling.samplers.count(seq_id) > 0) {
+            sched_need_reserve = true;
+        }
+
         sampling.samplers.erase(seq_id);
 
         return false;
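Putting the early return and the two branches above together, `set_sampler` has three outcomes (a summary, not code from the diff):

    // 1. detach (nullptr) with nothing attached -> early return true, flag untouched
    // 2. offloadable sampler                    -> stored, sched_need_reserve = true, returns true
    // 3. non-offloadable sampler                -> warning, previous sampler erased
    //    (flag set only if one was attached), returns false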
@@ -1053,7 +1073,7 @@ void llama_context::set_adapter_lora(
 
     loras[adapter] = scale;
 
-    reserve();
+    sched_need_reserve = true;
 }
 
 bool llama_context::rm_adapter_lora(
@@ -1064,7 +1084,7 @@ bool llama_context::rm_adapter_lora(
     if (it != loras.end()) {
         loras.erase(it);
 
-        reserve();
+        sched_need_reserve = true;
 
         return true;
     }
@@ -1081,7 +1101,7 @@ void llama_context::clear_adapter_lora() {
 
     loras.clear();
 
-    reserve();
+    sched_need_reserve = true;
 }
 
 bool llama_context::apply_adapter_cvec(
@@ -1196,6 +1216,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
 
+    sched_reserve();
+
     n_queued_tokens += n_tokens;
 
     // reserve output buffer
@@ -1235,7 +1257,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
 
     // extract logits
     if (logits && t_logits) {
         ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
         GGML_ASSERT(backend_res != nullptr);
         GGML_ASSERT(logits != nullptr);
@@ -1509,6 +1531,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
     embd_seq.clear();
     output_swaps.clear();
 
+    sched_reserve();
+
     bool did_optimize = false;
 
     // handle any pending shifts/copies
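With both `encode()` and `decode()` reconciling the flag up front, the runtime sequence looks like this (an assumed call trace, not from the diff):

    llama_set_sampler(ctx, 0, smpl);   // setter: sched_need_reserve = true, no work yet
    llama_decode(ctx, batch);          // decode() runs sched_reserve() once, clears the flag
    llama_decode(ctx, batch2);         // flag already clear: no reallocation on the hot path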
@@ -40,14 +40,13 @@ struct llama_context {
 
     ~llama_context();
 
-    // reserve a new backend scheduler
-    // recommended to call whenver the context changes in such a way that the compute graph is modified.
-    // for example:
+    // reserve a new backend scheduler (if needed)
+    // for example, when:
     // - changing loras
     // - changing samplers
     // - changing attention type
     // - etc.
-    void reserve();
+    void sched_reserve();
 
     void synchronize();
@@ -323,6 +322,8 @@ private:
 
     ggml_backend_sched_ptr sched;
 
+    bool sched_need_reserve = true;
+
     ggml_backend_t backend_cpu = nullptr;
     std::vector<ggml_backend_ptr> backends;
@@ -1182,7 +1182,7 @@ private:
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
 
         // initialize samplers
-        {
+        if (task.uses_sampling()) {
             slot.smpl.reset(common_sampler_init(model, task.params.sampling));
 
             if (slot.smpl == nullptr) {
@@ -1211,6 +1211,8 @@ private:
             }
 
             SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
+        } else {
+            slot.smpl.reset();
         }
 
         // initialize draft batch
@@ -2593,6 +2595,12 @@ private:
             llama_set_embeddings(ctx, slot_batched->need_embd());
         }
 
+        for (auto & slot : slots) {
+            if (!slot.is_processing() || !slot.smpl) {
+                llama_set_sampler(ctx, slot.id, nullptr);
+            }
+        }
+
         if (batch.n_tokens == 0) {
             SRV_WRN("%s", "no tokens to decode\n");
         }
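Note the interaction with the `set_sampler` early return shown earlier: detaching a sequence that has no sampler attached is a no-op, so this loop can run every batch without dirtying the scheduler flag for idle slots. In sketch form:

    // safe to call unconditionally each batch: detaching an already-detached
    // sequence early-returns in set_sampler and leaves sched_need_reserve alone
    llama_set_sampler(ctx, slot.id, nullptr);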
@@ -2727,6 +2735,8 @@ private:
                     continue; // continue loop of slots
                 }
 
+                GGML_ASSERT(slot.task->uses_sampling());
+
                 // prompt evaluated for next-token prediction
                 slot.state = SLOT_STATE_GENERATING;
             } else if (slot.state != SLOT_STATE_GENERATING) {
||||||
|
|
@ -156,6 +156,11 @@ struct server_task {
|
||||||
return tokens.size();
|
return tokens.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bool uses_sampling() const {
|
||||||
|
return type != SERVER_TASK_TYPE_EMBEDDING &&
|
||||||
|
type != SERVER_TASK_TYPE_RERANK;
|
||||||
|
}
|
||||||
|
|
||||||
static task_params params_from_json_cmpl(
|
static task_params params_from_json_cmpl(
|
||||||
const llama_vocab * vocab,
|
const llama_vocab * vocab,
|
||||||
const common_params & params_base,
|
const common_params & params_base,
|
||||||
|
|
|
||||||
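Tying the server hunks together, the new predicate gates sampler setup per slot; a sketch of the intended flow using the names from the hunks above:

    if (task.uses_sampling()) {
        slot.smpl.reset(common_sampler_init(model, task.params.sampling));
    } else {
        slot.smpl.reset();   // embedding/rerank tasks never sample; the slot is
                             // also detached from the llama_context each batch
    }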