llama : reserve on sampler changes
This commit is contained in:
parent
5260bb79c0
commit
0c0d0fdc30
|
|
@ -1172,7 +1172,6 @@ common_init_result::common_init_result(common_params & params) :
|
|||
pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
|
||||
}
|
||||
|
||||
// TODO: temporarily gated behind a flag
|
||||
if (params.sampling.backend_sampling) {
|
||||
cparams.samplers = pimpl->samplers_seq_config.data();
|
||||
cparams.n_samplers = pimpl->samplers_seq_config.size();
|
||||
|
|
|
|||
|
|
@ -81,7 +81,6 @@ int main(int argc, char ** argv) {
|
|||
sampler_configs.push_back({ i, smpl });
|
||||
}
|
||||
|
||||
// TODO: temporarily gated behind a flag
|
||||
if (params.sampling.backend_sampling) {
|
||||
ctx_params.samplers = sampler_configs.data();
|
||||
ctx_params.n_samplers = sampler_configs.size();
|
||||
|
|
|
|||
|
|
@ -1255,7 +1255,6 @@ extern "C" {
|
|||
// [EXPERIMENTAL]
|
||||
// attach a sampler to the context
|
||||
// note: prefer initializing the context with llama_context_params.samplers when possible
|
||||
// note: changing the samplers of a context can cause graph reallocations and degraded performance
|
||||
LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl);
|
||||
|
||||
// mirror of llama_sampler_i:
|
||||
|
|
|
|||
|
|
@ -340,7 +340,7 @@ llama_context::llama_context(
|
|||
LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
|
||||
}
|
||||
|
||||
reserve();
|
||||
sched_reserve();
|
||||
|
||||
if (!cparams.flash_attn) {
|
||||
if (ggml_is_quantized(params.type_v)) {
|
||||
|
|
@ -380,7 +380,13 @@ llama_context::~llama_context() {
|
|||
ggml_opt_free(opt_ctx);
|
||||
}
|
||||
|
||||
void llama_context::reserve() {
|
||||
void llama_context::sched_reserve() {
|
||||
if (!sched_need_reserve) {
|
||||
return;
|
||||
}
|
||||
|
||||
sched_need_reserve = false;
|
||||
|
||||
LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
|
||||
|
||||
synchronize();
|
||||
|
|
@ -408,10 +414,8 @@ void llama_context::reserve() {
|
|||
}
|
||||
}
|
||||
|
||||
cross.v_embd.clear();
|
||||
|
||||
// avoid reserving graphs with zero outputs - assume one output per sequence
|
||||
n_outputs = n_seqs;
|
||||
const int n_outputs = n_seqs;
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
|
||||
|
||||
|
|
@ -983,7 +987,7 @@ void llama_context::set_embeddings(bool value) {
|
|||
cparams.embeddings = value;
|
||||
|
||||
// TODO: not sure yet if we want to reserve here
|
||||
//reserve();
|
||||
//sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_causal_attn(bool value) {
|
||||
|
|
@ -995,17 +999,27 @@ void llama_context::set_causal_attn(bool value) {
|
|||
|
||||
cparams.causal_attn = value;
|
||||
|
||||
reserve();
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
void llama_context::set_warmup(bool value) {
|
||||
LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
|
||||
|
||||
if (cparams.warmup == value) {
|
||||
return;
|
||||
}
|
||||
|
||||
cparams.warmup = value;
|
||||
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
||||
LLAMA_LOG_ERROR("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
|
||||
if (!sampler && sampling.samplers.count(seq_id) == 0) {
|
||||
return true;
|
||||
}
|
||||
|
||||
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
|
||||
|
||||
const bool can_offload =
|
||||
sampler &&
|
||||
|
|
@ -1024,12 +1038,18 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
|||
|
||||
sampling.samplers[seq_id] = sampler;
|
||||
|
||||
sched_need_reserve = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
if (sampler && !can_offload) {
|
||||
LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id);
|
||||
|
||||
if (sampling.samplers.count(seq_id) > 0) {
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
sampling.samplers.erase(seq_id);
|
||||
|
||||
return false;
|
||||
|
|
@ -1053,7 +1073,7 @@ void llama_context::set_adapter_lora(
|
|||
|
||||
loras[adapter] = scale;
|
||||
|
||||
reserve();
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
bool llama_context::rm_adapter_lora(
|
||||
|
|
@ -1064,7 +1084,7 @@ bool llama_context::rm_adapter_lora(
|
|||
if (it != loras.end()) {
|
||||
loras.erase(it);
|
||||
|
||||
reserve();
|
||||
sched_need_reserve = true;
|
||||
|
||||
return true;
|
||||
}
|
||||
|
|
@ -1081,7 +1101,7 @@ void llama_context::clear_adapter_lora() {
|
|||
|
||||
loras.clear();
|
||||
|
||||
reserve();
|
||||
sched_need_reserve = true;
|
||||
}
|
||||
|
||||
bool llama_context::apply_adapter_cvec(
|
||||
|
|
@ -1196,6 +1216,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
// TODO: this clear of the buffer can easily be forgotten - need something better
|
||||
embd_seq.clear();
|
||||
|
||||
sched_reserve();
|
||||
|
||||
n_queued_tokens += n_tokens;
|
||||
|
||||
// reserve output buffer
|
||||
|
|
@ -1235,7 +1257,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd();
|
||||
|
||||
// extract logits
|
||||
if (logits && t_logits) {
|
||||
if (logits && t_logits) {
|
||||
ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
|
||||
GGML_ASSERT(backend_res != nullptr);
|
||||
GGML_ASSERT(logits != nullptr);
|
||||
|
|
@ -1509,6 +1531,8 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
|||
embd_seq.clear();
|
||||
output_swaps.clear();
|
||||
|
||||
sched_reserve();
|
||||
|
||||
bool did_optimize = false;
|
||||
|
||||
// handle any pending shifts/copies
|
||||
|
|
|
|||
|
|
@ -40,14 +40,13 @@ struct llama_context {
|
|||
|
||||
~llama_context();
|
||||
|
||||
// reserve a new backend scheduler
|
||||
// recommended to call whenver the context changes in such a way that the compute graph is modified.
|
||||
// for example:
|
||||
// reserve a new backend scheduler (if needed)
|
||||
// for example, when:
|
||||
// - changing loras
|
||||
// - changing samplers
|
||||
// - changing attention type
|
||||
// - etc.
|
||||
void reserve();
|
||||
void sched_reserve();
|
||||
|
||||
void synchronize();
|
||||
|
||||
|
|
@ -323,6 +322,8 @@ private:
|
|||
|
||||
ggml_backend_sched_ptr sched;
|
||||
|
||||
bool sched_need_reserve = true;
|
||||
|
||||
ggml_backend_t backend_cpu = nullptr;
|
||||
std::vector<ggml_backend_ptr> backends;
|
||||
|
||||
|
|
|
|||
|
|
@ -1182,7 +1182,7 @@ private:
|
|||
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
||||
|
||||
// initialize samplers
|
||||
{
|
||||
if (task.uses_sampling()) {
|
||||
slot.smpl.reset(common_sampler_init(model, task.params.sampling));
|
||||
|
||||
if (slot.smpl == nullptr) {
|
||||
|
|
@ -1211,6 +1211,8 @@ private:
|
|||
}
|
||||
|
||||
SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str());
|
||||
} else {
|
||||
slot.smpl.reset();
|
||||
}
|
||||
|
||||
// initialize draft batch
|
||||
|
|
@ -2593,6 +2595,12 @@ private:
|
|||
llama_set_embeddings(ctx, slot_batched->need_embd());
|
||||
}
|
||||
|
||||
for (auto & slot : slots) {
|
||||
if (!slot.is_processing() || !slot.smpl) {
|
||||
llama_set_sampler(ctx, slot.id, nullptr);
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
SRV_WRN("%s", "no tokens to decode\n");
|
||||
}
|
||||
|
|
@ -2727,6 +2735,8 @@ private:
|
|||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
GGML_ASSERT(slot.task->uses_sampling());
|
||||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
|
|
|
|||
|
|
@ -156,6 +156,11 @@ struct server_task {
|
|||
return tokens.size();
|
||||
}
|
||||
|
||||
bool uses_sampling() const {
|
||||
return type != SERVER_TASK_TYPE_EMBEDDING &&
|
||||
type != SERVER_TASK_TYPE_RERANK;
|
||||
}
|
||||
|
||||
static task_params params_from_json_cmpl(
|
||||
const llama_vocab * vocab,
|
||||
const common_params & params_base,
|
||||
|
|
|
|||
Loading…
Reference in New Issue