diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 094ef0481b..d8fb5d782b 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -212,7 +212,10 @@ llama_context::llama_context(
     // graph outputs buffer
     {
         // resized during inference when a batch uses more outputs
-        if (output_reserve(params.n_seq_max) < params.n_seq_max) {
+        // Create a dummy batch for initialization.
+        llama_batch dummy_batch = {};
+        dummy_batch.n_tokens = 0;
+        if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) {
             throw std::runtime_error("failed to reserve initial output buffer");
         }
 
@@ -1075,7 +1078,7 @@ int llama_context::encode(const llama_batch & batch_inp) {
     n_queued_tokens += n_tokens;
 
     // reserve output buffer
-    if (output_reserve(n_tokens) < n_tokens) {
+    if (output_reserve(n_tokens, batch_inp) < n_tokens) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens);
         return -2;
     };
@@ -1403,7 +1406,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     }
 
     // reserve output buffer
-    if (output_reserve(n_outputs_all) < n_outputs_all) {
+    if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) {
         LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all);
         return -2;
     };
@@ -1493,82 +1496,83 @@ int llama_context::decode(const llama_batch & batch_inp) {
         }
 
-        if (!backend_has_sampled) {
-            auto * t_logits = res->get_logits();
-            auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
+        auto * t_logits = res->get_logits();
+        auto * t_embd = cparams.embeddings ? res->get_embd() : nullptr;
 
-            if (t_embd && res->get_embd_pooled()) {
-                t_embd = res->get_embd_pooled();
+        if (t_embd && res->get_embd_pooled()) {
+            t_embd = res->get_embd_pooled();
+        }
+
+        // extract logits
+        // For multi-sequence batches that mix backend samplers and the CPU sampler,
+        // this is currently inefficient as we copy all logits even for the
+        // backend sampled tokens.
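+        // `logits` is the CPU-side output buffer; output_reserve() leaves it null
+        // when no output in this batch needs CPU logits (e.g. every output belongs
+        // to a sequence with a backend sampler), so this copy is then skipped.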
+ if (logits && t_logits && n_outputs > 0) { + ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); + GGML_ASSERT(backend_res != nullptr); + GGML_ASSERT(logits != nullptr); + + float * logits_out = logits + n_outputs_prev*n_vocab; + + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } + } - // extract logits - if (t_logits && n_outputs > 0) { - ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); - GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + // extract embeddings + if (embd && t_embd && n_outputs > 0) { + ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); + GGML_ASSERT(backend_embd != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + switch (cparams.pooling_type) { + case LLAMA_POOLING_TYPE_NONE: + { + // extract token embeddings + GGML_ASSERT(embd != nullptr); + float * embd_out = embd + n_outputs_prev*n_embd; - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); - ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); - } - } - - // extract embeddings - if (t_embd && n_outputs > 0) { - ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); - GGML_ASSERT(backend_embd != nullptr); - - switch (cparams.pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // extract token embeddings - GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; - - if (n_outputs) { - GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_MEAN: - case LLAMA_POOLING_TYPE_CLS: - case LLAMA_POOLING_TYPE_LAST: - { - // extract sequence embeddings (cleared before processing each batch) - auto & embd_seq_out = embd_seq; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_embd); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_RANK: - { - // extract the rerank score - n_cls_out floats per sequence - auto & embd_seq_out = embd_seq; - - const uint32_t n_cls_out = hparams.n_cls_out; - - for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { - const llama_seq_id seq_id = ubatch.seq_id_unq[s]; - const int32_t seq_idx = ubatch.seq_idx[seq_id]; - - embd_seq_out[seq_id].resize(n_cls_out); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); - } - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ABORT("unknown pooling type"); + if (n_outputs) { + GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); } - } 
+ } break; + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + // extract sequence embeddings (cleared before processing each batch) + auto & embd_seq_out = embd_seq; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_embd); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_embd*seq_idx)*sizeof(float), n_embd*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_RANK: + { + // extract the rerank score - n_cls_out floats per sequence + auto & embd_seq_out = embd_seq; + + const uint32_t n_cls_out = hparams.n_cls_out; + + for (uint32_t s = 0; s < ubatch.n_seqs_unq; ++s) { + const llama_seq_id seq_id = ubatch.seq_id_unq[s]; + const int32_t seq_idx = ubatch.seq_idx[seq_id]; + + embd_seq_out[seq_id].resize(n_cls_out); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_seq_out[seq_id].data(), (n_cls_out*seq_idx)*sizeof(float), n_cls_out*sizeof(float)); + } + } break; + case LLAMA_POOLING_TYPE_UNSPECIFIED: + { + GGML_ABORT("unknown pooling type"); + } } } @@ -1635,7 +1639,7 @@ int llama_context::decode(const llama_batch & batch_inp) { // output // -uint32_t llama_context::output_reserve(int32_t n_outputs) { +uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1654,23 +1658,37 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - const bool backend_sampling = !sampling.samplers.empty(); + // Check which sampling modes are needed by sequences in the current batch. + bool batch_has_backend_sampling = false; + bool batch_needs_cpu_logits = false; + + for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + llama_seq_id seq_id = batch.seq_id[i][j]; + if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { + batch_has_backend_sampling = true; + } else { + batch_needs_cpu_logits = true; + } + } + } + size_t backend_float_count = 0; size_t backend_token_count = 0; - if (!backend_sampling) { - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd*n_outputs_max : 0; + // Allocate CPU logits buffer only if needed by sequences in this batch + logits_size = (has_logits && batch_needs_cpu_logits) ? n_vocab*n_outputs_max : 0; + embd_size = has_embd ? n_embd*n_outputs_max : 0; - // reset backend sampling values. + if (!batch_has_backend_sampling) { sampling.logits_size = 0; sampling.probs_size = 0; sampling.sampled_size = 0; sampling.candidates_size = 0; } else { - logits_size = 0; - embd_size = 0; - sampling.logits_size = n_vocab*n_outputs_max; sampling.probs_size = n_vocab*n_outputs_max; sampling.sampled_size = n_outputs_max; @@ -1727,15 +1745,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { sampling.sampled = nullptr; sampling.candidates = nullptr; - if (!backend_sampling) { - logits = has_logits ? output_base : nullptr; - embd = has_embd ? output_base + logits_size : nullptr; - } else { - // Allocate worst case (full vocabulary size) for backend sampled - // data in the pinned memory buffer. - size_t offset = 0; - uint8_t * base = (uint8_t *) output_base; + size_t offset = 0; + uint8_t * base = (uint8_t *) output_base; + logits = (has_logits && batch_needs_cpu_logits) ? 
output_base : nullptr; + offset += logits_size * sizeof(float); + + embd = has_embd ? (float *) (base + offset) : nullptr; + offset += embd_size * sizeof(float); + + if (batch_has_backend_sampling) { sampling.logits = (float *) (base + offset); offset += sampling.logits_size * sizeof(float); @@ -2400,7 +2419,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { auto n_outputs = this->n_outputs; io.read_to(&n_outputs, sizeof(n_outputs)); - if (n_outputs > output_reserve(n_outputs)) { + // Create a dummy batch for state loading. + llama_batch dummy_batch = {}; + dummy_batch.n_tokens = 0; + if (n_outputs > output_reserve(n_outputs, dummy_batch)) { throw std::runtime_error("could not reserve outputs"); } @@ -2631,7 +2653,7 @@ void llama_context::opt_epoch_iter( } // reserve output buffer - if (output_reserve(n_outputs_all) < n_outputs_all) { + if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); GGML_ABORT("TODO: handle this error"); }; diff --git a/src/llama-context.h b/src/llama-context.h index 2bdbf8a553..1dcd3bf419 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -200,7 +200,7 @@ private: // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - uint32_t output_reserve(int32_t n_outputs); + uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch); void output_reorder(); int64_t resolve_output_row(int32_t i) const; diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp index b668b88485..cd9aa003b5 100644 --- a/tests/test-backend-sampler.cpp +++ b/tests/test-backend-sampler.cpp @@ -44,7 +44,7 @@ struct test_model_context { return true; } - bool setup(const char * model_path, std::vector & configs) { + bool setup(const char * model_path, std::vector & configs, int32_t n_seq_max = -1) { if (model == nullptr) { load_model(model_path); } @@ -59,13 +59,18 @@ struct test_model_context { cparams.samplers = configs.data(); cparams.n_samplers = configs.size(); - int32_t max_seq_id = 0; - for (const auto & config : configs) { - if (config.seq_id > max_seq_id) { - max_seq_id = config.seq_id; + // If n_seq_max is not specified, calculate it from configs + if (n_seq_max < 0) { + int32_t max_seq_id = 0; + for (const auto & config : configs) { + if (config.seq_id > max_seq_id) { + max_seq_id = config.seq_id; + } } + cparams.n_seq_max = max_seq_id + 1; + } else { + cparams.n_seq_max = n_seq_max; } - cparams.n_seq_max = max_seq_id + 1; ctx = llama_init_from_model(model, cparams); if (ctx == nullptr) { @@ -280,7 +285,7 @@ static void test_backend_greedy_sampling(const char * model_path) { } if (!test_ctx.decode({{seq_id, "Some"}})) { - return; + GGML_ASSERT(false && "Failed to decode token"); } int32_t batch_idx = test_ctx.idx_for_seq(seq_id); @@ -297,7 +302,9 @@ static void test_backend_greedy_sampling(const char * model_path) { int32_t loop_idx = test_ctx.idx_for_seq(seq_id); llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, loop_idx); printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); - test_ctx.decode_token(token, 0); + if (!test_ctx.decode_token(token, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } } } @@ -316,7 +323,7 @@ static void test_backend_top_k_sampling(const char * model_path) { } if (!test_ctx.decode({{seq_id, "Hello"}})) { - return; + GGML_ASSERT(false && 
"Failed to decode token"); } int32_t batch_idx = test_ctx.idx_for_seq(seq_id); @@ -373,7 +380,7 @@ static void test_backend_temp_sampling(const char * model_path) { } if (!test_ctx.decode({{0, "Some where over"}, {1, "Once upon a"}})) { - return; + GGML_ASSERT(false && "Failed to decode token"); } // Verfify sequence 0 @@ -431,7 +438,7 @@ static void test_backend_min_p_sampling(const char * model_path) { } if (!test_ctx.decode({{seq_id, "Hello"}})) { - return; + GGML_ASSERT(false && "Failed to decode token"); } int32_t batch_idx = test_ctx.idx_for_seq(seq_id); @@ -464,7 +471,9 @@ static void test_backend_min_p_sampling(const char * model_path) { int32_t loop_idx = test_ctx.idx_for_seq(seq_id); llama_token token = llama_sampler_sample(chain, test_ctx.ctx, loop_idx); printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); - test_ctx.decode_token(token, 0); + if (!test_ctx.decode_token(token, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } } printf("min-p sampling test PASSED\n"); @@ -499,7 +508,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) { }; if (!test_ctx.decode(prompts)) { - return; + GGML_ASSERT(false && "Failed to decode token"); } // Verfiy sequence 0 @@ -535,7 +544,7 @@ static void test_backend_multi_sequence_sampling(const char * model_path) { // Decode all tokens in a single batch if (!test_ctx.decode_tokens(tokens)) { - break; + GGML_ASSERT(false && "Failed to decode token"); } } @@ -557,7 +566,7 @@ static void test_backend_dist_sampling(const char * model_path) { } if (!test_ctx.decode({{seq_id, "Some"}})) { - return; + GGML_ASSERT(false && "Failed to decode token"); } int32_t batch_idx = test_ctx.idx_for_seq(seq_id); @@ -586,7 +595,7 @@ static void test_backend_dist_sampling_and_cpu(const char * model_path) { } if (!test_ctx.decode({{seq_id, "Some"}})) { - return; + GGML_ASSERT(false && "Failed to decode token"); } int32_t batch_idx = test_ctx.idx_for_seq(seq_id); @@ -640,7 +649,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) { } if (!test_ctx.decode({{seq_id, "Hello"}})) { - return; + GGML_ASSERT(false && "Failed to decode token"); } llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id)); @@ -650,7 +659,7 @@ static void test_backend_logit_bias_sampling(const char * model_path) { } // This test verifies that it is possible to have two different backend sampler, -// one that used the backend dist sampler, and another that uses CPU dist sampler. +// one that uses the backend dist sampler, and another that uses CPU dist sampler. static void test_backend_mixed_sampling(const char * model_path) { test_model_context test_ctx; @@ -678,7 +687,7 @@ static void test_backend_mixed_sampling(const char * model_path) { }; if (!test_ctx.decode(prompts)) { - return; + GGML_ASSERT(false && "Failed to decode token"); } // Verfiy sequence 0 that used the dist backend sampler. 
@@ -720,7 +729,7 @@ static void test_backend_set_sampler(const char * model_path) {
     }
 
     if (!test_ctx.decode({{seq_id, "Hello"}})) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }
 
     int32_t batch_idx = test_ctx.idx_for_seq(seq_id);
@@ -741,7 +750,7 @@
 
     std::map tokens = { { seq_id, backend_token}, };
     if (!test_ctx.decode_tokens(tokens)) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
    }
 
     // Should not have any sampled token or probs after clearing the backend sampler.
@@ -763,7 +772,7 @@
     llama_set_backend_sampler(test_ctx.ctx, seq_id, new_backend_sampler_chain);
 
     if (!test_ctx.decode_tokens(tokens2)) {
-        return;
+        GGML_ASSERT(false && "Failed to decode token");
     }
 
     llama_token new_backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, test_ctx.idx_for_seq(seq_id));
@@ -771,6 +780,101 @@
     printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str());
 }
 
+static void test_backend_cpu_mixed_batch(const char * model_path) {
+    test_model_context test_ctx;
+
+    // Sequence 0 uses backend sampling
+    struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params();
+    struct llama_sampler * sampler_chain_0 = llama_sampler_chain_init(chain_params_0);
+    llama_sampler_chain_add(sampler_chain_0, llama_sampler_backend_init_dist(88));
+
+    std::vector backend_sampler_configs = {
+        { 0, sampler_chain_0 },
+    };
+
+    // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling
+    if (!test_ctx.setup(model_path, backend_sampler_configs, 2)) {
+        return;
+    }
+
+    std::map prompts = {
+        {0, "Hello"}, // Will use backend sampling
+        {1, "Some"}   // Will use CPU sampling
+    };
+
+    if (!test_ctx.decode(prompts)) {
+        GGML_ASSERT(false && "Failed to decode token");
+    }
+
+    // Verify sequence 0 (backend sampled)
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(0);
+        llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+    }
+
+    // Verify sequence 1 (CPU sampled)
+    {
+        int32_t batch_idx = test_ctx.idx_for_seq(1);
+
+        llama_token backend_token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx);
+        GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL);
+
+        struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params();
+        struct llama_sampler * chain = llama_sampler_chain_init(chain_params);
+        llama_sampler_chain_add(chain, llama_sampler_init_greedy());
+
+        llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx);
+        const std::string token_str = test_ctx.token_to_piece(token, false);
+        printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str());
+        GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab);
+        llama_sampler_free(chain);
+    }
+
+    // Clear/remove the backend sampler, and sample again
+    {
+        // clear the backend sampler for seq 0 so that there are no backend
+        // samplers.
+        llama_set_backend_sampler(test_ctx.ctx, 0, nullptr);
+
+        // Create a CPU sampler and verify we can sample from it.
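+        // With no backend sampler attached any more, sampling goes through the
+        // regular CPU path, reading the logits the previous decode stored for
+        // this output index.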
+ struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * chain = llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(chain, llama_sampler_init_greedy()); + + int32_t batch_idx = test_ctx.idx_for_seq(1); + llama_token token = llama_sampler_sample(chain, test_ctx.ctx, batch_idx); + if (!test_ctx.decode_token(token, 1)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + llama_sampler_free(chain); + } + + // Set a backend sampler so that we can verify that it can be reset + { + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + struct llama_sampler * sampler_chain= llama_sampler_chain_init(chain_params); + llama_sampler_chain_add(sampler_chain, llama_sampler_backend_init_dist(88)); + + llama_set_backend_sampler(test_ctx.ctx, 0, sampler_chain); + + if (!test_ctx.decode_token(3834, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_backend_sampled_token_ith(test_ctx.ctx, batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + printf("backend-cpu mixed batch test PASSED\n"); +} + static void test_backend_max_outputs(const char * model_path) { test_model_context test_ctx; @@ -829,6 +933,7 @@ static const backend_test_case BACKEND_TESTS[] = { { "max_outputs", test_backend_max_outputs, true }, { "mixed", test_backend_mixed_sampling, true }, { "min_p", test_backend_min_p_sampling, true }, + { "cpu_mixed", test_backend_cpu_mixed_batch, true }, }; struct backend_cli_args {