sampling : support intermixed backend/cpu samplers
This commit updates the backend sampling implementation to support intermixed use of backend and CPU samplers within the same batch. The initial implementation was all-or-nothing: either perform backend sampling for the entire batch, or perform CPU sampling for the entire batch.

The motivation for this change is to support batches with mixed sequences. For example, a backend sampler may be configured for sequence 0 while sequence 1 in the same batch uses CPU sampling. This was not supported in the initial implementation. The issue manifested in llama-server with the webui: decoding with backend samplers would work initially, but after switching to CPU sampling, a slot (sequence) could still be using a backend sampler. As a result, logits in output_reserve would not be allocated, resulting in an error.

The solution in this commit inspects the batch to determine which sampling modes are needed and allocates buffers accordingly. A known inefficiency remains: when backend and CPU samplers are intermixed in the same batch, all logits are currently copied to the host, even for sequences that use backend samplers.

Added test_backend_cpu_mixed_batch to verify correct behavior with mixed backend/CPU samplers in a single batch, including dynamic sampler switching between decode calls.
parent f9889cf1c7
commit 74be332e24
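For orientation, the flow this change enables can be condensed into the following sketch. It is illustrative only: the helper name decode_mixed, the batch construction, and the output indices idx0/idx1 are assumptions made for the example, while the llama_* sampler calls are the ones exercised by test_backend_cpu_mixed_batch in the diff below.

#include "llama.h"

// Illustrative sketch: one decode call where sequence 0 is sampled by a
// backend sampler and sequence 1 falls back to CPU sampling.
static void decode_mixed(llama_context * ctx, llama_batch batch, int32_t idx0, int32_t idx1) {
    // Sequence 0: attach a backend dist sampler before decoding (seed 88 as in the test).
    llama_sampler * backend_chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(backend_chain, llama_sampler_backend_init_dist(88));
    llama_set_backend_sampler(ctx, 0, backend_chain);

    // Sequence 1: no backend sampler is configured, so its logits must be
    // copied to the host and sampled with a regular CPU chain.
    llama_sampler * cpu_chain = llama_sampler_chain_init(llama_sampler_chain_default_params());
    llama_sampler_chain_add(cpu_chain, llama_sampler_init_greedy());

    // batch is assumed to contain tokens for both sequences, with logits
    // requested for the last token of each; non-zero return is treated as failure here.
    if (llama_decode(ctx, batch) != 0) {
        llama_sampler_free(cpu_chain);
        return;
    }

    // Sequence 0: the token was already sampled on the backend.
    llama_token t0 = llama_get_backend_sampled_token_ith(ctx, idx0);

    // Sequence 1: the backend query yields LLAMA_TOKEN_NULL; sample on the CPU
    // from the host logits instead.
    llama_token t1 = llama_sampler_sample(cpu_chain, ctx, idx1);

    (void) t0;
    (void) t1;
    llama_sampler_free(cpu_chain);
}

The key point, matching the change below: output_reserve now inspects the batch to see which path each output token needs, so the host logits buffer is only allocated when at least one sequence in the batch has no backend sampler.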
@@ -212,7 +212,10 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
+        // Create a dummy batch for initialization.
+        llama_batch dummy_batch = {};
+        dummy_batch.n_tokens = 0;
+        if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
             throw std::runtime_error("failed to reserve initial output buffer");
         }
     }
@@ -1075,7 +1078,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     n_queued_tokens += n_tokens;

     // reserve output buffer
-    if (output_reserve(n_tokens) < n_tokens) {
+    if (output_reserve(n_tokens, batch_inp) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
         return -2;
     };
@@ -1403,7 +1406,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }

     // reserve output buffer
-    if (output_reserve(n_outputs_all) < n_outputs_all) {
+    if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
         return -2;
     };
@@ -1493,82 +1496,83 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }

-        auto * t_logits = res->get_logits();
-        auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
-
-        if (t_embd && res->get_embd_pooled()) {
-            t_embd = res->get_embd_pooled();
-        }
-
-        // extract logits
-        if (t_logits && n_outputs > 0) {
-            ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
-            GGML_ASSERT(backend_res != nullptr);
-            GGML_ASSERT(logits != nullptr);
-
-            float * logits_out = logits + n_outputs_prev*n_vocab;
-
-            if (n_outputs) {
-                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
-                ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
-            }
-        }
-
-        // extract embeddings
-        if (t_embd && n_outputs > 0) {
-            ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
-            GGML_ASSERT(backend_embd != nullptr);
-
-            switch (cparams.pooling_type) {
-                case LLAMA_POOLING_TYPE_NONE:
-                    {
-                        // extract token embeddings
-                        GGML_ASSERT(embd != nullptr);
-                        float * embd_out = embd + n_outputs_prev*n_embd;
-
-                        if (n_outputs) {
-                            GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
-                            GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_MEAN:
-                case LLAMA_POOLING_TYPE_CLS:
-                case LLAMA_POOLING_TYPE_LAST:
-                    {
-                        // extract sequence embeddings (cleared before processing each batch)
-                        auto & embd_seq_out = embd_seq;
-
-                        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
-                            const llama_seq_id seq_id = ubatch.seq_id_unq[s];
-                            const int32_t seq_idx = ubatch.seq_idx[seq_id];
-
-                            embd_seq_out[seq_id].resize(n_embd);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_RANK:
-                    {
-                        // extract the rerank score - n_cls_out floats per sequence
-                        auto & embd_seq_out = embd_seq;
-
-                        const uint32_t n_cls_out = hparams.n_cls_out;
-
-                        for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
-                            const llama_seq_id seq_id = ubatch.seq_id_unq[s];
-                            const int32_t seq_idx = ubatch.seq_idx[seq_id];
-
-                            embd_seq_out[seq_id].resize(n_cls_out);
-                            ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
-                        }
-                    } break;
-                case LLAMA_POOLING_TYPE_UNSPECIFIED:
-                    {
-                        GGML_ABORT("unknown pooling type");
-                    }
-            }
-        }
+        if (!backend_has_sampled) {
+            auto * t_logits = res->get_logits();
+            auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
+
+            if (t_embd && res->get_embd_pooled()) {
+                t_embd = res->get_embd_pooled();
+            }
+
+            // extract logits
+            // For multi-sequence batches that mix backend samplers and CPU samplers
+            // this is currently inefficient as we copy all logits even for the
+            // backend sampled tokens.
+            if (logits && t_logits && n_outputs > 0) {
+                ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits);
+                GGML_ASSERT(backend_res != nullptr);
+                GGML_ASSERT(logits != nullptr);
+
+                float * logits_out = logits + n_outputs_prev*n_vocab;
+
+                if (n_outputs) {
+                    GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                    GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size);
+                    ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float));
+                }
+            }
+
+            // extract embeddings
+            if (embd && t_embd && n_outputs > 0) {
+                ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd);
+                GGML_ASSERT(backend_embd != nullptr);
+
+                switch (cparams.pooling_type) {
+                    case LLAMA_POOLING_TYPE_NONE:
+                        {
+                            // extract token embeddings
+                            GGML_ASSERT(embd != nullptr);
+                            float * embd_out = embd + n_outputs_prev*n_embd;
+
+                            if (n_outputs) {
+                                GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all);
+                                GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size);
+                                ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float));
+                            }
+                        } break;
+                    case LLAMA_POOLING_TYPE_MEAN:
+                    case LLAMA_POOLING_TYPE_CLS:
+                    case LLAMA_POOLING_TYPE_LAST:
+                        {
+                            // extract sequence embeddings (cleared before processing each batch)
+                            auto & embd_seq_out = embd_seq;
+
+                            for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                                const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                                const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+                                embd_seq_out[seq_id].resize(n_embd);
+                                ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float));
+                            }
+                        } break;
+                    case LLAMA_POOLING_TYPE_RANK:
+                        {
+                            // extract the rerank score - n_cls_out floats per sequence
+                            auto & embd_seq_out = embd_seq;
+
+                            const uint32_t n_cls_out = hparams.n_cls_out;
+
+                            for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) {
+                                const llama_seq_id seq_id = ubatch.seq_id_unq[s];
+                                const int32_t seq_idx = ubatch.seq_idx[seq_id];
+
+                                embd_seq_out[seq_id].resize(n_cls_out);
+                                ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float));
+                            }
+                        } break;
+                    case LLAMA_POOLING_TYPE_UNSPECIFIED:
+                        {
+                            GGML_ABORT("unknown pooling type");
+                        }
+                }
+            }
+        }
@@ -1635,7 +1639,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
 // output
 //

-uint32_t llama_context::output_reserve(int32_t n_outputs) {
+uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) {
     const auto & hparams = model.hparams;
     const auto & vocab = model.vocab;
@@ -1654,23 +1658,37 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
         has_embd = true;
     }

-    const bool backend_sampling = !sampling.samplers.empty();
+    // Check which sampling modes are needed by sequences in the current batch.
+    bool batch_has_backend_sampling = false;
+    bool batch_needs_cpu_logits = false;
+
+    for (int32_t i = 0; i < batch.n_tokens; i++) {
+        if (!batch.logits[i]) {
+            continue;
+        }
+        for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
+            llama_seq_id seq_id = batch.seq_id[i][j];
+            if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
+                batch_has_backend_sampling = true;
+            } else {
+                batch_needs_cpu_logits = true;
+            }
+        }
+    }

     size_t backend_float_count = 0;
     size_t backend_token_count = 0;

-    if (!backend_sampling) {
-        logits_size = has_logits ? n_vocab*n_outputs_max : 0;
-        embd_size = has_embd ? n_embd*n_outputs_max : 0;
+    // Allocate CPU logits buffer only if needed by sequences in this batch
+    logits_size = (has_logits && batch_needs_cpu_logits) ? n_vocab*n_outputs_max : 0;
+    embd_size = has_embd ? n_embd*n_outputs_max : 0;

-        // reset backend sampling values.
+    // reset backend sampling values.
+    if (!batch_has_backend_sampling) {
         sampling.logits_size = 0;
         sampling.probs_size = 0;
         sampling.sampled_size = 0;
         sampling.candidates_size = 0;
     } else {
-        logits_size = 0;
-        embd_size = 0;
-
         sampling.logits_size = n_vocab*n_outputs_max;
         sampling.probs_size = n_vocab*n_outputs_max;
         sampling.sampled_size = n_outputs_max;
@@ -1727,15 +1745,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     sampling.sampled = nullptr;
     sampling.candidates = nullptr;

-    if (!backend_sampling) {
-        logits = has_logits ? output_base : nullptr;
-        embd = has_embd ? output_base + logits_size : nullptr;
-    } else {
-        // Allocate worst case (full vocabulary size) for backend sampled
-        // data in the pinned memory buffer.
-        size_t offset = 0;
-        uint8_t * base = (uint8_t *) output_base;
+    size_t offset = 0;
+    uint8_t * base = (uint8_t *) output_base;
+
+    logits = (has_logits && batch_needs_cpu_logits) ? output_base : nullptr;
+    offset += logits_size * sizeof(float);
+
+    embd = has_embd ? (float *) (base + offset) : nullptr;
+    offset += embd_size * sizeof(float);
+
+    if (batch_has_backend_sampling) {
+        // Allocate worst case (full vocabulary size) for backend sampled
+        // data in the pinned memory buffer.
         sampling.logits = (float *) (base + offset);
         offset += sampling.logits_size * sizeof(float);
@@ -2400,7 +2419,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) {
     auto n_outputs = this->n_outputs;
     io.read_to(&n_outputs, sizeof(n_outputs));

-    if (n_outputs > output_reserve(n_outputs)) {
+    // Create a dummy batch for state loading.
+    llama_batch dummy_batch = {};
+    dummy_batch.n_tokens = 0;
+    if (n_outputs > output_reserve(n_outputs, dummy_batch)) {
         throw std::runtime_error("could not reserve outputs");
     }
@@ -2631,7 +2653,7 @@ void llama_context::opt_epoch_iter(
         }

         // reserve output buffer
-        if (output_reserve(n_outputs_all) < n_outputs_all) {
+        if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
             LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
             GGML_ABORT("TODO: handle this error");
         };
@@ -200,7 +200,7 @@ private:

     // Make sure enough space is available for outputs.
     // Returns max number of outputs for which space was reserved.
-    uint32_t output_reserve(int32_t n_outputs);
+    uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch);

     void output_reorder();
     int64_t resolve_output_row(int32_t i) const;
@@ -44,7 +44,7 @@ struct test_model_context {
         return true;
     }

-    bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs) {
+    bool setup(const char * model_path, std::vector<llama_sampler_seq_config> & configs, int32_t n_seq_max = -1) {
         if (model == nullptr) {
             load_model(model_path);
         }
@@ -59,13 +59,18 @@
         cparams.samplers = configs.data();
         cparams.n_samplers = configs.size();

-        int32_t max_seq_id = 0;
-        for (const auto & config : configs) {
-            if (config.seq_id > max_seq_id) {
-                max_seq_id = config.seq_id;
-            }
-        }
-        cparams.n_seq_max = max_seq_id + 1;
+        // If n_seq_max is not specified, calculate it from configs
+        if (n_seq_max < 0) {
+            int32_t max_seq_id = 0;
+            for (const auto & config : configs) {
+                if (config.seq_id > max_seq_id) {
+                    max_seq_id = config.seq_id;
+                }
+            }
+            cparams.n_seq_max = max_seq_id + 1;
+        } else {
+            cparams.n_seq_max = n_seq_max;
+        }

         ctx = llama_init_from_model(model, cparams);
         if (ctx == nullptr) {
@@ -280,7 +285,7 @@ static void test_backend_greedy_sampling(const char * model_path) {
     }

     if (!test_ctx.decode({{seq_id, "Some"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
@@ -297,7 +302,9 @@
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
         llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, loop_idx);
         printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
-        test_ctx.decode_token(token, 0);
+        if (!test_ctx.decode_token(token, 0)) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
     }
 }
@@ -316,7 +323,7 @@ static void test_backend_top_k_sampling(const char * model_path) {
     }

     if (!test_ctx.decode({{seq_id, "Hello"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
@@ -373,7 +380,7 @@ static void test_backend_temp_sampling(const char * model_path) {
     }

     if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     // Verify sequence 0
@@ -431,7 +438,7 @@ static void test_backend_min_p_sampling(const char * model_path) {
     }

     if (!test_ctx.decode({{seq_id, "Hello"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
@@ -464,7 +471,9 @@
         int32_t loop_idx = test_ctx.idx_for_seq(seq_id);
         llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx);
         printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str());
-        test_ctx.decode_token(token, 0);
+        if (!test_ctx.decode_token(token, 0)) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
     }

     printf("min-p sampling test PASSED\n");
@@ -499,7 +508,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) {
     };

     if (!test_ctx.decode(prompts)) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     // Verify sequence 0
@@ -535,7 +544,7 @@

         // Decode all tokens in a single batch
         if (!test_ctx.decode_tokens(tokens)) {
-            break;
+            GGML_ASSERT(false && "Failed to decode token");
         }
     }
@@ -557,7 +566,7 @@ static void test_backend_dist_sampling(const char * model_path) {
     }

     if (!test_ctx.decode({{seq_id, "Some"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
@@ -586,7 +595,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) {
     }

     if (!test_ctx.decode({{seq_id, "Some"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
@@ -640,7 +649,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) {
     }

     if (!test_ctx.decode({{seq_id, "Hello"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
@@ -650,7 +659,7 @@
 }

 // This test verifies that it is possible to have two different backend samplers,
-// one that used the backend dist sampler, and another that uses CPU dist sampler.
+// one that uses the backend dist sampler, and another that uses CPU dist sampler.
 static void test_backend_mixed_sampling(const char * model_path) {
     test_model_context test_ctx;
@@ -678,7 +687,7 @@ static void test_backend_mixed_sampling(const char * model_path) {
     };

     if (!test_ctx.decode(prompts)) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     // Verify sequence 0 that used the dist backend sampler.
@@ -720,7 +729,7 @@ static void test_backend_set_sampler(const char * model_path) {
     }

     if (!test_ctx.decode({{seq_id, "Hello"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
    }

     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
@@ -741,7 +750,7 @@

     std::map<llama_seq_id, llama_token> tokens = { { seq_id, backend_token}, };
     if (!test_ctx.decode_tokens(tokens)) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     // Should not have any sampled token or probs after clearing the backend sampler.
@@ -763,7 +772,7 @@
     llama_set_backend_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);

     if (!test_ctx.decode_tokens(tokens2)) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }

     llama_token new_backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
@@ -771,6 +780,101 @@ static void test_backend_set_sampler(const char * model_path) {
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 }

+static void test_backend_cpu_mixed_batch(const char * model_path) {
+    test_model_context test_ctx;
+
+    // Sequence 0 uses backend sampling
+    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
+    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
+    llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_dist(88));
+
+    std::vector<llama_sampler_seq_config> backend_sampler_configs = {
+        { 0, sampler_chain_0 },
+    };
+
+    // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
+    if (!test_ctx.setup(model_path, backend_sampler_configs, 2)) {
+        return;
+    }
+
+    std::map<llama_seq_id, std::string> prompts = {
+        {0, "Hello"}, // Will use backend sampling
+        {1, "Some"}   // Will use CPU sampling
+    };
+
+    if (!test_ctx.decode(prompts)) {
+        GGML_ASSERT(false && "Failed to decode token");
+    }
+
+    // Verify sequence 0 (backend sampled)
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+    }
+
+    // Verify sequence 1 (CPU sampled)
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(1);
+
+        llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
+
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+
+        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+        llama_sampler_free(chain);
+    }
+
+    // Clear/remove the backend sampler, and sample again
+    {
+        // clear the backend sampler for seq 0 so that there are no backend
+        // samplers.
+        llama_set_backend_sampler(test_ctx.ctx, 0, nullptr);
+
+        // Create a CPU sampler and verify we can sample from it.
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+
+        int32_t batch_idx = test_ctx.idx_for_seq(1);
+        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        if (!test_ctx.decode_token(token, 1)) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        llama_sampler_free(chain);
+    }
+
+    // Set a backend sampler so that we can verify that it can be reset
+    {
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * sampler_chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(sampler_chain, llama_sampler_backend_init_dist(88));
+
+        llama_set_backend_sampler(test_ctx.ctx, 0, sampler_chain);
+
+        if (!test_ctx.decode_token(3834, 0)) {
+            GGML_ASSERT(false && "Failed to decode token");
+        }
+
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+    }
+
+    printf("backend-cpu mixed batch test PASSED\n");
+}
+
 static void test_backend_max_outputs(const char * model_path) {
     test_model_context test_ctx;

@@ -829,6 +933,7 @@ static const backend_test_case BACKEND_TESTS[] = {
     { "max_outputs", test_backend_max_outputs, true },
     { "mixed", test_backend_mixed_sampling, true },
     { "min_p", test_backend_min_p_sampling, true },
+    { "cpu_mixed", test_backend_cpu_mixed_batch, true },
 };

 struct backend_cli_args {