diff --git a/include/llama.h b/include/llama.h
index 2bc41f36d8..9c4862ad89 100644
--- a/include/llama.h
+++ b/include/llama.h
@@ -213,8 +213,8 @@ extern "C" {
     struct llama_sampler_ggml_data {
        struct ggml_tensor * logits;
        struct ggml_tensor * probs;
-        struct ggml_tensor * sampled_token;
-        struct ggml_tensor * filtered_ids;
+        struct ggml_tensor * sampled;
+        struct ggml_tensor * candidates;
     };
 
     typedef bool (*llama_progress_callback)(float progress, void * user_data);
diff --git a/src/llama-backend-sampler.cpp b/src/llama-backend-sampler.cpp
index 42c8d85aeb..cd6b8bb752 100644
--- a/src/llama-backend-sampler.cpp
+++ b/src/llama-backend-sampler.cpp
@@ -15,7 +15,7 @@ static void llama_sampler_backend_greedy_apply_ggml(
     GGML_UNUSED(smpl);
     struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits);
     ggml_set_name(argmax_result, "argmax_result");
-    ggml_data->sampled_token = argmax_result;
+    ggml_data->sampled = argmax_result;
 }
 
 static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) {
@@ -149,7 +149,7 @@ static void llama_sampler_backend_top_k_apply_ggml(
         fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
     }
 
-    ggml_data->filtered_ids = top_k;
+    ggml_data->candidates = top_k;
 
     struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
     struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
@@ -303,19 +303,19 @@ static void llama_sampler_backend_dist_apply_ggml(
     struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
     ggml_set_name(idx, "dist_index_i32");
 
-    // Map back to original vocab ids if a filtered id tensor is available.
+    // Map back to original vocab ids if a candidates tensor is available.
     struct ggml_tensor * sampled_token = idx;
-    if (ggml_data->filtered_ids != nullptr) {
-        struct ggml_tensor * filtered_ids = ggml_data->filtered_ids;
-        struct ggml_tensor * filtered_ids_reshaped = ggml_view_2d(ctx, filtered_ids, 1, ggml_nelements(filtered_ids),
-                                                                   ggml_type_size(filtered_ids->type), 0);
+    if (ggml_data->candidates != nullptr) {
+        struct ggml_tensor * candidates = ggml_data->candidates;
+        struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
+                                                                 ggml_type_size(candidates->type), 0);
 
-        sampled_token = ggml_get_rows(ctx, filtered_ids_reshaped, idx);
+        sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
         ggml_set_name(sampled_token, "dist_sampled_token");
     }
     ggml_set_output(sampled_token);
 
-    ggml_data->sampled_token = sampled_token;
+    ggml_data->sampled = sampled_token;
 }
 
 static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) {
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 5868f6246f..7bebf58b9e 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -60,11 +60,11 @@ llama_context::llama_context(
 
     // backend samplers
     if (params.samplers != nullptr && params.n_samplers > 0) {
-        samplers.reserve(params.n_samplers);
+        sampling.samplers.reserve(params.n_samplers);
 
         for (size_t i = 0; i < params.n_samplers; ++i) {
             const auto & config = params.samplers[i];
-            samplers[config.seq_id] = config.sampler;
+            sampling.samplers[config.seq_id] = config.sampler;
         }
     }
 
@@ -435,9 +435,9 @@ llama_context::llama_context(
     {
         const llama_vocab * vocab = llama_model_get_vocab(&model);
         const int n_vocab = llama_vocab_n_tokens(vocab);
-        sampled_token_ids_full_vocab.resize(n_vocab);
+        sampling.token_ids_full_vocab.resize(n_vocab);
         for (int i = 0; i < n_vocab; ++i) {
-            sampled_token_ids_full_vocab[i] = i;
+            sampling.token_ids_full_vocab[i] = i;
         }
     }
 }
@@ -445,7 +445,7 @@ llama_context::llama_context(
 llama_context::~llama_context() {
     ggml_opt_free(opt_ctx);
     // TODO: perhaps use a smart pointer for samplers
-    for (auto const& [seq_id, sampler] : samplers) {
+    for (auto const& [seq_id, sampler] : sampling.samplers) {
         llama_sampler_free(sampler);
     }
 }
@@ -635,7 +635,7 @@ float * llama_context::get_embeddings() {
 }
 
 llama_token * llama_context::get_backend_sampled_tokens() {
-    return sampled_tokens;
+    return sampling.sampled;
 }
 
 float * llama_context::get_embeddings_ith(int32_t i) {
@@ -691,15 +691,15 @@ llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
     // Handle special case where idx == -1 (single sequence exists) which is
     // a valid index when using common_sampler_sample.
     if (idx == -1) {
-        if (sampled_tokens_map.size() == 1) {
-            auto it = sampled_tokens_map.begin();
+        if (sampling.map_sampled.size() == 1) {
+            auto it = sampling.map_sampled.begin();
             return it->second;
         }
         return LLAMA_TOKEN_NULL;
     }
 
-    auto it = sampled_tokens_map.find(idx);
-    if (it == sampled_tokens_map.end()) {
+    auto it = sampling.map_sampled.find(idx);
+    if (it == sampling.map_sampled.end()) {
         return LLAMA_TOKEN_NULL;
     }
 
@@ -708,13 +708,13 @@ llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
 
 float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
     if (idx == -1) {
-        if (sampled_probs_map.size() == 1) {
-            return sampled_probs_map.begin()->second.data();
+        if (sampling.map_probs.size() == 1) {
+            return sampling.map_probs.begin()->second.data();
         }
     }
 
-    auto it = sampled_probs_map.find(idx);
-    if (it == sampled_probs_map.end()) {
+    auto it = sampling.map_probs.find(idx);
+    if (it == sampling.map_probs.end()) {
         return nullptr;
     }
 
@@ -723,12 +723,12 @@ float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
 
 float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
     if (idx == -1) {
-        if (sampled_logits_map.size() == 1) {
-            return sampled_logits_map.begin()->second.data();
+        if (sampling.map_logits.size() == 1) {
+            return sampling.map_logits.begin()->second.data();
         }
     }
 
-    auto it = sampled_logits_map.find(idx);
-    if (it == sampled_logits_map.end()) {
+    auto it = sampling.map_logits.find(idx);
+    if (it == sampling.map_logits.end()) {
         return nullptr;
     }
 
@@ -737,29 +737,29 @@ float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
 
 const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
     if (idx == -1) {
-        if (sampled_token_ids_map.size() == 1) {
-            const auto & vec = sampled_token_ids_map.begin()->second;
+        if (sampling.map_candidates.size() == 1) {
+            const auto & vec = sampling.map_candidates.begin()->second;
             if (!vec.empty()) {
                 return vec.data();
             }
         }
     }
 
-    auto it = sampled_token_ids_map.find(idx);
-    if (it != sampled_token_ids_map.end() && !it->second.empty()) {
+    auto it = sampling.map_candidates.find(idx);
+    if (it != sampling.map_candidates.end() && !it->second.empty()) {
         return it->second.data();
     }
 
-    return sampled_token_ids_full_vocab.data();
+    return sampling.token_ids_full_vocab.data();
 }
 
 size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
     if (idx == -1) {
-        if (sampled_logits_map.size() == 1) {
-            return sampled_logits_map.begin()->second.size();
+        if (sampling.map_logits.size() == 1) {
+            return sampling.map_logits.begin()->second.size();
         }
     }
 
-    auto it = sampled_logits_map.find(idx);
-    if (it == sampled_logits_map.end()) {
+    auto it = sampling.map_logits.find(idx);
+    if (it == sampling.map_logits.end()) {
         return 0;
     }
 
@@ -768,14 +768,14 @@ size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
 
 size_t llama_context::get_backend_sampled_probs_count(int32_t idx) const {
     if (idx == -1) {
-        if (sampled_probs_map.size() == 1) {
-            return sampled_probs_map.begin()->second.size();
+        if (sampling.map_probs.size() == 1) {
+            return sampling.map_probs.begin()->second.size();
         }
         return 0;
     }
 
-    auto it = sampled_probs_map.find(idx);
-    if (it == sampled_probs_map.end()) {
+    auto it = sampling.map_probs.find(idx);
+    if (it == sampling.map_probs.end()) {
         return 0;
     }
 
@@ -841,8 +841,8 @@ void llama_context::set_warmup(bool value) {
 
 void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
     LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
-    auto it = samplers.find(seq_id);
-    if (it != samplers.end()) {
+    auto it = sampling.samplers.find(seq_id);
+    if (it != sampling.samplers.end()) {
         // If the sampler to be set is the same that is already set, do nothing.
         if (it->second == sampler) {
             return;
@@ -853,7 +853,7 @@ void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sam
         // If sampler is nullptr, we remove the samppler chain for this seq_id.
         // chain for this seq_id.
         if (sampler == nullptr) {
-            samplers.erase(it);
+            sampling.samplers.erase(it);
             return;
         }
 
@@ -865,7 +865,7 @@ void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sam
     // If there is no sampler for this seq_id and the caller provides a non-null
     // sampler, we set it.
     if (sampler != nullptr) {
-        samplers[seq_id] = sampler;
+        sampling.samplers[seq_id] = sampler;
     }
 }
 
@@ -1202,7 +1202,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
 
-    const bool has_backend_samplers = !samplers.empty();
+    const bool has_backend_samplers = !sampling.samplers.empty();
 
     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
                       cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
@@ -1235,10 +1235,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
-    sampled_probs_map.clear();
-    sampled_logits_map.clear();
-    sampled_tokens_map.clear();
-    sampled_token_ids_map.clear();
+    sampling.map_probs.clear();
+    sampling.map_logits.clear();
+    sampling.map_sampled.clear();
+    sampling.map_candidates.clear();
 
     output_swaps.clear();
 
     bool did_optimize = false;
@@ -1361,27 +1361,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
+        backend_has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
 
         if (has_backend_samplers && backend_has_sampled) {
             const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
 
             // If a backend sampler has sampled a token we only want to copy the
             // sampled tokens and avoid copying logits and probabilites.
-            if (!res->t_sampled_tokens.empty()) {
+            if (!res->t_sampled.empty()) {
                 // async copy the sampled tokens from the backend to the host.
-                copy_tensor_async_int(res->t_sampled_tokens, sampled_tokens_map, seq_to_batch_idx, sched.get());
+                copy_tensor_async_int(res->t_sampled, sampling.map_sampled, seq_to_batch_idx, sched.get());
             } else {
                 // async copy the sampled logits/probs from the backend to the host.
-                copy_tensor_async_floats(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());
-                copy_tensor_async_floats(res->t_sampled_probs, sampled_probs_map, seq_to_batch_idx, sched.get());
+                copy_tensor_async_floats(res->t_sampled_logits, sampling.map_logits, seq_to_batch_idx, sched.get());
+                copy_tensor_async_floats(res->t_sampled_probs, sampling.map_probs, seq_to_batch_idx, sched.get());
             }
 
-            // async copy the filtered token ids from the backend to the host.
+            // async copy the candidate token ids from the backend to the host.
             // These are needed for:
             // 1) Backend dist sampler to map indices to vocab token ids.
-            // 2) CPU samplers to associate filtered logits with their token ids.
-            copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
+            // 2) CPU samplers to associate candidate logits with their token ids.
+            copy_tensor_async_token_ids(res->t_candidates, sampling.map_candidates, seq_to_batch_idx, sched.get());
         }
 
@@ -1589,8 +1589,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
     logits = has_logits ? output_base : nullptr;
     embd = has_embd ? output_base + logits_size : nullptr;
 
-    sampled_tokens = !samplers.empty() ? s_output_base : nullptr;
-    sampled_probs = !samplers.empty() ? embd : nullptr;
+
+    sampling.sampled = !sampling.samplers.empty() ? s_output_base : nullptr;
+    sampling.probs = !sampling.samplers.empty() ? embd : nullptr;
 
     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1700,7 +1701,7 @@ llm_graph_params llama_context::graph_params(
        /*.loras       =*/ &loras,
        /*.mctx        =*/ mctx,
        /*.cross       =*/ &cross,
-        /*.samplers    =*/ samplers,
+        /*.samplers    =*/ sampling.samplers,
        /*.n_outputs   =*/ n_outputs,
        /*.cb          =*/ graph_get_cb(),
        /*.res         =*/ res,
diff --git a/src/llama-context.h b/src/llama-context.h
index aba62e6e38..8e6a111e61 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -254,16 +254,21 @@ private:
     size_t logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;
 
-    std::unordered_map<llama_seq_id, llama_sampler *> samplers;
-    llama_token * sampled_tokens = nullptr;
-    std::unordered_map<int32_t, llama_token> sampled_tokens_map;
+    struct sampling_info {
+        std::unordered_map<llama_seq_id, llama_sampler *> samplers;
 
-    float * sampled_probs = nullptr;
-    std::unordered_map<int32_t, std::vector<float>> sampled_probs_map;
+        llama_token * sampled = nullptr;
+        float * probs = nullptr;
 
-    std::unordered_map<int32_t, std::vector<float>> sampled_logits_map;
-    std::unordered_map<int32_t, std::vector<llama_token>> sampled_token_ids_map;
-    std::vector<llama_token> sampled_token_ids_full_vocab;
+        std::unordered_map<int32_t, llama_token> map_sampled;
+        std::unordered_map<int32_t, std::vector<float>> map_probs;
+        std::unordered_map<int32_t, std::vector<float>> map_logits;
+        std::unordered_map<int32_t, std::vector<llama_token>> map_candidates;
+
+        std::vector<llama_token> token_ids_full_vocab;
+    };
+
+    sampling_info sampling;
 
     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index fbf4475aa4..8af9188d05 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -504,10 +504,10 @@ void llm_graph_result::reset() {
     t_logits = nullptr;
     t_embd = nullptr;
     t_embd_pooled = nullptr;
-    t_sampled_tokens.clear();
+    t_sampled.clear();
     t_sampled_probs.clear();
     t_sampled_logits.clear();
-    t_sampled_token_ids.clear();
+    t_candidates.clear();
 
     params = {};
 
@@ -2098,17 +2098,17 @@ void llm_graph_context::build_sampling() const {
         ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
 
         struct llama_sampler_ggml_data ggml_data = {
-            /*.logits        =*/ logits_seq,
-            /*.probs         =*/ nullptr,
-            /*.sampled_token =*/ nullptr,
-            /*.filtered_ids  =*/ nullptr,
+            /*.logits     =*/ logits_seq,
+            /*.probs      =*/ nullptr,
+            /*.sampled    =*/ nullptr,
+            /*.candidates =*/ nullptr,
         };
 
         llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data);
 
-        if (ggml_data.sampled_token != nullptr) {
-            res->t_sampled_tokens[seq_id] = ggml_data.sampled_token;
-            ggml_build_forward_expand(gf, ggml_data.sampled_token);
+        if (ggml_data.sampled != nullptr) {
+            res->t_sampled[seq_id] = ggml_data.sampled;
+            ggml_build_forward_expand(gf, ggml_data.sampled);
         }
 
         if (ggml_data.probs != nullptr) {
@@ -2121,16 +2121,16 @@ void llm_graph_context::build_sampling() const {
             ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
         }
 
-        if (ggml_data.filtered_ids != nullptr) {
-            res->t_sampled_token_ids[seq_id] = ggml_data.filtered_ids;
-            ggml_build_forward_expand(gf, ggml_data.filtered_ids);
+        if (ggml_data.candidates != nullptr) {
+            res->t_candidates[seq_id] = ggml_data.candidates;
+            ggml_build_forward_expand(gf, ggml_data.candidates);
         }
     }
 
     // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
     /*
     for (const auto & [seq_id, sampler] : samplers) {
-        if (auto it = res->t_sampled_tokens.find(seq_id); it != res->t_sampled_tokens.end()) {
+        if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
             ggml_tensor * selected_token = it->second;
             if (selected_token != nullptr) {
                 llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
diff --git a/src/llama-graph.h b/src/llama-graph.h
index f7508046e4..6797d78a20 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -543,8 +543,8 @@ public:
     ggml_tensor * t_embd_pooled = nullptr;
 
     std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled_logits;
-    std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled_token_ids;
-    std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled_tokens;
+    std::unordered_map<llama_seq_id, ggml_tensor *> t_candidates;
+    std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled;
     std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled_probs;
 
     std::vector inputs;
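After the rename, `llama_sampler_ggml_data` reads as: `logits` in, an optional `candidates` tensor holding the surviving token ids after a filtering sampler such as top-k, and either `sampled` (a single token id) or `probs` out. The sketch below is illustrative only and not part of this diff: the function and tensor names prefixed `example_` are made up, and the parameter list is an assumption based on the truncated signatures and the `llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data)` call above. It combines the greedy sampler's argmax with the dist sampler's candidates-to-vocab-id mapping to show how the renamed fields fit together.

```cpp
#include "llama.h"
#include "ggml.h"

// Hypothetical example (not part of the diff): a greedy backend-sampler step
// written against the renamed llama_sampler_ggml_data fields. The exact
// callback signature is an assumption; the diff truncates it.
static void example_greedy_apply_ggml(
        struct llama_sampler           * smpl,
        struct ggml_context            * ctx,
        struct ggml_cgraph             * gf,
        struct llama_sampler_ggml_data * ggml_data) {
    GGML_UNUSED(smpl);
    GGML_UNUSED(gf);

    // Pick the highest-scoring entry of the (possibly already filtered) logits.
    struct ggml_tensor * best = ggml_argmax(ctx, ggml_data->logits);
    ggml_set_name(best, "example_argmax");

    // If an earlier sampler (e.g. top-k) recorded a candidates tensor, the
    // argmax index refers to the filtered set and has to be mapped back to
    // vocab ids, mirroring the dist sampler above.
    if (ggml_data->candidates != nullptr) {
        struct ggml_tensor * cand_rows = ggml_view_2d(ctx, ggml_data->candidates,
                1, ggml_nelements(ggml_data->candidates),
                ggml_type_size(ggml_data->candidates->type), 0);
        best = ggml_get_rows(ctx, cand_rows, best);
        ggml_set_name(best, "example_sampled_token");
    }

    ggml_set_output(best);

    // build_sampling() picks this up via res->t_sampled[seq_id].
    ggml_data->sampled = best;
}
```

Keeping `candidates` separate from `sampled` is what lets the dist sampler map a drawn index back to a vocab id on the backend, and lets the CPU path in `llama_context::decode` associate candidate logits with their token ids without copying the full vocabulary.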