sampling : introduce sampling_info struct
This commit introduces a sampling_info struct to encapsulate all backend-sampling related data within the llama_context class. It also renames the sampled-token and candidate tensors in the backend sampler ggml data structure to more descriptive names.
parent ed4345bdd9
commit 0d28b16bdc
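For orientation before the hunks: the public sampler data structure ends up with the fields below (a summary sketch assembled from the diff; the input/output notes are inferred from how the samplers in this commit use each field, not stated in the commit itself).

    // llama_sampler_ggml_data after this commit:
    struct llama_sampler_ggml_data {
        struct ggml_tensor * logits;     // input: logits for one sequence
        struct ggml_tensor * probs;      // output: probabilities, if a sampler produces them
        struct ggml_tensor * sampled;    // output: sampled token id   (was: sampled_token)
        struct ggml_tensor * candidates; // output: candidate token ids (was: filtered_ids)
    };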
@@ -213,8 +213,8 @@ extern "C" {
     struct llama_sampler_ggml_data {
         struct ggml_tensor * logits;
         struct ggml_tensor * probs;
-        struct ggml_tensor * sampled_token;
-        struct ggml_tensor * filtered_ids;
+        struct ggml_tensor * sampled;
+        struct ggml_tensor * candidates;
     };

     typedef bool (*llama_progress_callback)(float progress, void * user_data);
@@ -15,7 +15,7 @@ static void llama_sampler_backend_greedy_apply_ggml(
     GGML_UNUSED(smpl);
     struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits);
     ggml_set_name(argmax_result, "argmax_result");
-    ggml_data->sampled_token = argmax_result;
+    ggml_data->sampled = argmax_result;
 }

 static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) {
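The greedy sampler remains a single ggml_argmax node over the logits; only the output field it writes to is renamed. A host-side analogue of what that node computes (illustrative only, not part of the commit):

    #include <algorithm>
    #include <vector>

    // The sampled token is simply the index of the largest logit.
    int greedy_sample(const std::vector<float> & logits) {
        return (int) std::distance(logits.begin(),
                                   std::max_element(logits.begin(), logits.end()));
    }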
@@ -149,7 +149,7 @@ static void llama_sampler_backend_top_k_apply_ggml(
         fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
     }

-    ggml_data->filtered_ids = top_k;
+    ggml_data->candidates = top_k;

     struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
     struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
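Here top_k holds the selected candidate ids, now published through candidates, and the reshape + ggml_get_rows pair gathers the matching logits. A host-side analogue of that gather (illustrative only):

    #include <vector>

    // Collect the logit of each candidate id, mirroring ggml_get_rows over
    // the logits viewed as one row per token.
    std::vector<float> gather_candidate_logits(const std::vector<float> & logits,
                                               const std::vector<int>   & candidates) {
        std::vector<float> out;
        out.reserve(candidates.size());
        for (int id : candidates) {
            out.push_back(logits[id]);
        }
        return out;
    }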
@@ -303,19 +303,19 @@ static void llama_sampler_backend_dist_apply_ggml(
     struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
     ggml_set_name(idx, "dist_index_i32");

-    // Map back to original vocab ids if a filtered id tensor is available.
+    // Map back to original vocab ids if a candidates tensor is available.
     struct ggml_tensor * sampled_token = idx;
-    if (ggml_data->filtered_ids != nullptr) {
-        struct ggml_tensor * filtered_ids = ggml_data->filtered_ids;
-        struct ggml_tensor * filtered_ids_reshaped = ggml_view_2d(ctx, filtered_ids, 1, ggml_nelements(filtered_ids),
-            ggml_type_size(filtered_ids->type), 0);
+    if (ggml_data->candidates != nullptr) {
+        struct ggml_tensor * candidates = ggml_data->candidates;
+        struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
+            ggml_type_size(candidates->type), 0);

-        sampled_token = ggml_get_rows(ctx, filtered_ids_reshaped, idx);
+        sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
         ggml_set_name(sampled_token, "dist_sampled_token");
     }

     ggml_set_output(sampled_token);
-    ggml_data->sampled_token = sampled_token;
+    ggml_data->sampled = sampled_token;
 }

 static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) {
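The dist sampler draws an index into whatever candidate set an earlier sampler produced, so the index must be translated back to a vocab id whenever candidates is set. A host-side analogue of the ggml_get_rows translation above (illustrative only):

    #include <vector>

    // With candidates = {42, 7, 99} and idx = 1, the sampled token is 7.
    // Without a candidates tensor, idx already is a vocab id.
    int map_sampled_index(int idx, const std::vector<int> & candidates) {
        if (candidates.empty()) {
            return idx;
        }
        return candidates[idx];
    }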
@@ -60,11 +60,11 @@ llama_context::llama_context(

     // backend samplers
     if (params.samplers != nullptr && params.n_samplers > 0) {
-        samplers.reserve(params.n_samplers);
+        sampling.samplers.reserve(params.n_samplers);

         for (size_t i = 0; i < params.n_samplers; ++i) {
             const auto & config = params.samplers[i];
-            samplers[config.seq_id] = config.sampler;
+            sampling.samplers[config.seq_id] = config.sampler;
         }
     }
@@ -435,9 +435,9 @@ llama_context::llama_context(
     {
         const llama_vocab * vocab = llama_model_get_vocab(&model);
         const int n_vocab = llama_vocab_n_tokens(vocab);
-        sampled_token_ids_full_vocab.resize(n_vocab);
+        sampling.token_ids_full_vocab.resize(n_vocab);
         for (int i = 0; i < n_vocab; ++i) {
-            sampled_token_ids_full_vocab[i] = i;
+            sampling.token_ids_full_vocab[i] = i;
         }
     }
 }
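A side note on this identity list: it lets get_backend_sampled_token_ids_ith (further down) always return a valid ids array, even when no filtering sampler ran for a sequence. The loop is equivalent to a std::iota call (illustrative alternative, not what the commit does):

    #include <numeric>

    // Equivalent construction of the identity candidate list.
    std::iota(sampling.token_ids_full_vocab.begin(),
              sampling.token_ids_full_vocab.end(), 0);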
@@ -445,7 +445,7 @@ llama_context::llama_context(
 llama_context::~llama_context() {
     ggml_opt_free(opt_ctx);
     // TODO: perhaps use a smart pointer for samplers
-    for (auto const& [seq_id, sampler] : samplers) {
+    for (auto const& [seq_id, sampler] : sampling.samplers) {
         llama_sampler_free(sampler);
     }
 }
@@ -635,7 +635,7 @@ float * llama_context::get_embeddings() {
 }

 llama_token * llama_context::get_backend_sampled_tokens() {
-    return sampled_tokens;
+    return sampling.sampled;
 }

 float * llama_context::get_embeddings_ith(int32_t i) {
@@ -691,15 +691,15 @@ llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
     // Handle special case where idx == -1 (single sequence exists) which is
     // a valid index when using common_sampler_sample.
     if (idx == -1) {
-        if (sampled_tokens_map.size() == 1) {
-            auto it = sampled_tokens_map.begin();
+        if (sampling.map_sampled.size() == 1) {
+            auto it = sampling.map_sampled.begin();
             return it->second;
         }
         return LLAMA_TOKEN_NULL;
     }

-    auto it = sampled_tokens_map.find(idx);
-    if (it == sampled_tokens_map.end()) {
+    auto it = sampling.map_sampled.find(idx);
+    if (it == sampling.map_sampled.end()) {
         return LLAMA_TOKEN_NULL;
     }
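A hypothetical caller illustrating the idx == -1 convention handled above (the method and LLAMA_TOKEN_NULL are from this diff; the surrounding code is illustrative only):

    // With exactly one sequence being sampled, -1 resolves to that
    // sequence's token; otherwise idx must be a valid batch index.
    llama_token tok = ctx.get_backend_sampled_token_ith(-1);
    if (tok == LLAMA_TOKEN_NULL) {
        // no backend-sampled token is available for this index
    }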
@@ -708,13 +708,13 @@ llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {

 float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
     if (idx == -1) {
-        if (sampled_probs_map.size() == 1) {
-            return sampled_probs_map.begin()->second.data();
+        if (sampling.map_probs.size() == 1) {
+            return sampling.map_probs.begin()->second.data();
         }
     }

-    auto it = sampled_probs_map.find(idx);
-    if (it == sampled_probs_map.end()) {
+    auto it = sampling.map_probs.find(idx);
+    if (it == sampling.map_probs.end()) {
         return nullptr;
     }
@@ -723,12 +723,12 @@ float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {

 float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
     if (idx == -1) {
-        if (sampled_logits_map.size() == 1) {
-            return sampled_logits_map.begin()->second.data();
+        if (sampling.map_logits.size() == 1) {
+            return sampling.map_logits.begin()->second.data();
         }
     }
-    auto it = sampled_logits_map.find(idx);
-    if (it == sampled_logits_map.end()) {
+    auto it = sampling.map_logits.find(idx);
+    if (it == sampling.map_logits.end()) {
         return nullptr;
     }
@@ -737,29 +737,29 @@ float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {

 const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
     if (idx == -1) {
-        if (sampled_token_ids_map.size() == 1) {
-            const auto & vec = sampled_token_ids_map.begin()->second;
+        if (sampling.map_candidates.size() == 1) {
+            const auto & vec = sampling.map_candidates.begin()->second;
             if (!vec.empty()) {
                 return vec.data();
             }
         }
     }
-    auto it = sampled_token_ids_map.find(idx);
-    if (it != sampled_token_ids_map.end() && !it->second.empty()) {
+    auto it = sampling.map_candidates.find(idx);
+    if (it != sampling.map_candidates.end() && !it->second.empty()) {
         return it->second.data();
     }

-    return sampled_token_ids_full_vocab.data();
+    return sampling.token_ids_full_vocab.data();
 }

 size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
     if (idx == -1) {
-        if (sampled_logits_map.size() == 1) {
-            return sampled_logits_map.begin()->second.size();
+        if (sampling.map_logits.size() == 1) {
+            return sampling.map_logits.begin()->second.size();
         }
     }
-    auto it = sampled_logits_map.find(idx);
-    if (it == sampled_logits_map.end()) {
+    auto it = sampling.map_logits.find(idx);
+    if (it == sampling.map_logits.end()) {
         return 0;
     }
@@ -768,14 +768,14 @@ size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {

 size_t llama_context::get_backend_sampled_probs_count(int32_t idx) const {
     if (idx == -1) {
-        if (sampled_probs_map.size() == 1) {
-            return sampled_probs_map.begin()->second.size();
+        if (sampling.map_probs.size() == 1) {
+            return sampling.map_probs.begin()->second.size();
         }
         return 0;
     }

-    auto it = sampled_probs_map.find(idx);
-    if (it == sampled_probs_map.end()) {
+    auto it = sampling.map_probs.find(idx);
+    if (it == sampling.map_probs.end()) {
         return 0;
     }
@@ -841,8 +841,8 @@ void llama_context::set_warmup(bool value) {
 void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
     LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);

-    auto it = samplers.find(seq_id);
-    if (it != samplers.end()) {
+    auto it = sampling.samplers.find(seq_id);
+    if (it != sampling.samplers.end()) {
         // If the sampler to be set is the same that is already set, do nothing.
         if (it->second == sampler) {
             return;
@@ -853,7 +853,7 @@ void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
         // If sampler is nullptr, we remove the sampler
         // chain for this seq_id.
         if (sampler == nullptr) {
-            samplers.erase(it);
+            sampling.samplers.erase(it);
             return;
         }
@@ -865,7 +865,7 @@ void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
     // If there is no sampler for this seq_id and the caller provides a non-null
     // sampler, we set it.
     if (sampler != nullptr) {
-        samplers[seq_id] = sampler;
+        sampling.samplers[seq_id] = sampler;
     }
 }
@@ -1202,7 +1202,7 @@ int llama_context::decode(const llama_batch & batch_inp) {

     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
-    const bool has_backend_samplers = !samplers.empty();
+    const bool has_backend_samplers = !sampling.samplers.empty();

     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
                       cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
@@ -1235,10 +1235,10 @@ int llama_context::decode(const llama_batch & batch_inp) {

     // TODO: this clear of the buffer can easily be forgotten - need something better
     embd_seq.clear();
-    sampled_probs_map.clear();
-    sampled_logits_map.clear();
-    sampled_tokens_map.clear();
-    sampled_token_ids_map.clear();
+    sampling.map_probs.clear();
+    sampling.map_logits.clear();
+    sampling.map_sampled.clear();
+    sampling.map_candidates.clear();
     output_swaps.clear();

     bool did_optimize = false;
@@ -1361,27 +1361,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
             // ggml_graph_dump_dot(gf, NULL, "llama.dot");
             //}

-            backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
+            backend_has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();

             if (has_backend_samplers && backend_has_sampled) {
                 const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);

                 // If a backend sampler has sampled a token we only want to copy the
                 // sampled tokens and avoid copying logits and probabilities.
-                if (!res->t_sampled_tokens.empty()) {
+                if (!res->t_sampled.empty()) {
                     // async copy the sampled tokens from the backend to the host.
-                    copy_tensor_async_int(res->t_sampled_tokens, sampled_tokens_map, seq_to_batch_idx, sched.get());
+                    copy_tensor_async_int(res->t_sampled, sampling.map_sampled, seq_to_batch_idx, sched.get());
                 } else {
                     // async copy the sampled logits/probs from the backend to the host.
-                    copy_tensor_async_floats(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());
-                    copy_tensor_async_floats(res->t_sampled_probs, sampled_probs_map, seq_to_batch_idx, sched.get());
+                    copy_tensor_async_floats(res->t_sampled_logits, sampling.map_logits, seq_to_batch_idx, sched.get());
+                    copy_tensor_async_floats(res->t_sampled_probs, sampling.map_probs, seq_to_batch_idx, sched.get());
                 }

-                // async copy the filtered token ids from the backend to the host.
+                // async copy the candidate token ids from the backend to the host.
                 // These are needed for:
                 // 1) Backend dist sampler to map indices to vocab token ids.
-                // 2) CPU samplers to associate filtered logits with their token ids.
-                copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
+                // 2) CPU samplers to associate candidate logits with their token ids.
+                copy_tensor_async_token_ids(res->t_candidates, sampling.map_candidates, seq_to_batch_idx, sched.get());

             }
@@ -1589,8 +1589,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {

     logits = has_logits ? output_base : nullptr;
     embd = has_embd ? output_base + logits_size : nullptr;
-    sampled_tokens = !samplers.empty() ? s_output_base : nullptr;
-    sampled_probs = !samplers.empty() ? embd : nullptr;
+
+    sampling.sampled = !sampling.samplers.empty() ? s_output_base : nullptr;
+    sampling.probs = !sampling.samplers.empty() ? embd : nullptr;

     // set all ids as invalid (negative)
     std::fill(output_ids.begin(), output_ids.end(), -1);
@@ -1700,7 +1701,7 @@ llm_graph_params llama_context::graph_params(
         /*.loras     =*/ &loras,
         /*.mctx      =*/ mctx,
         /*.cross     =*/ &cross,
-        /*.samplers  =*/ samplers,
+        /*.samplers  =*/ sampling.samplers,
         /*.n_outputs =*/ n_outputs,
         /*.cb        =*/ graph_get_cb(),
         /*.res       =*/ res,
@@ -254,16 +254,21 @@ private:
     size_t logits_size = 0; // capacity (of floats) for logits
     float * logits = nullptr;

-    std::unordered_map<llama_seq_id, llama_sampler*> samplers;
-    llama_token * sampled_tokens = nullptr;
-    std::unordered_map<int32_t, llama_token> sampled_tokens_map;
-
-    float * sampled_probs = nullptr;
-    std::unordered_map<int32_t, std::vector<float>> sampled_probs_map;
-
-    std::unordered_map<int32_t, std::vector<float>> sampled_logits_map;
-    std::unordered_map<int32_t, std::vector<llama_token>> sampled_token_ids_map;
-    std::vector<llama_token> sampled_token_ids_full_vocab;
+    struct sampling_info {
+        std::unordered_map<llama_seq_id, llama_sampler*> samplers;
+
+        llama_token * sampled = nullptr;
+        float * probs = nullptr;
+
+        std::unordered_map<int32_t, llama_token> map_sampled;
+        std::unordered_map<int32_t, std::vector<float>> map_probs;
+        std::unordered_map<int32_t, std::vector<float>> map_logits;
+        std::unordered_map<int32_t, std::vector<llama_token>> map_candidates;
+
+        std::vector<llama_token> token_ids_full_vocab;
+    };
+
+    sampling_info sampling;

     // embeddings output (2-dimensional array: [n_outputs][n_embd])
     // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
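For reference, the complete old → new mapping of the llama_context members moved into sampling_info:

    samplers                     -> sampling.samplers
    sampled_tokens               -> sampling.sampled
    sampled_probs                -> sampling.probs
    sampled_tokens_map           -> sampling.map_sampled
    sampled_probs_map            -> sampling.map_probs
    sampled_logits_map           -> sampling.map_logits
    sampled_token_ids_map        -> sampling.map_candidates
    sampled_token_ids_full_vocab -> sampling.token_ids_full_vocab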
@@ -504,10 +504,10 @@ void llm_graph_result::reset() {
     t_logits      = nullptr;
     t_embd        = nullptr;
     t_embd_pooled = nullptr;
-    t_sampled_tokens.clear();
+    t_sampled.clear();
     t_sampled_probs.clear();
     t_sampled_logits.clear();
-    t_sampled_token_ids.clear();
+    t_candidates.clear();

     params = {};
@@ -2098,17 +2098,17 @@ void llm_graph_context::build_sampling() const {
         ggml_format_name(logits_seq, "logits_seq_%d", seq_id);

         struct llama_sampler_ggml_data ggml_data = {
-            /*.logits        =*/ logits_seq,
-            /*.probs         =*/ nullptr,
-            /*.sampled_token =*/ nullptr,
-            /*.filtered_ids  =*/ nullptr,
+            /*.logits     =*/ logits_seq,
+            /*.probs      =*/ nullptr,
+            /*.sampled    =*/ nullptr,
+            /*.candidates =*/ nullptr,
         };

         llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data);

-        if (ggml_data.sampled_token != nullptr) {
-            res->t_sampled_tokens[seq_id] = ggml_data.sampled_token;
-            ggml_build_forward_expand(gf, ggml_data.sampled_token);
+        if (ggml_data.sampled != nullptr) {
+            res->t_sampled[seq_id] = ggml_data.sampled;
+            ggml_build_forward_expand(gf, ggml_data.sampled);
         }

         if (ggml_data.probs != nullptr) {
@@ -2121,16 +2121,16 @@ void llm_graph_context::build_sampling() const {
             ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
         }

-        if (ggml_data.filtered_ids != nullptr) {
-            res->t_sampled_token_ids[seq_id] = ggml_data.filtered_ids;
-            ggml_build_forward_expand(gf, ggml_data.filtered_ids);
+        if (ggml_data.candidates != nullptr) {
+            res->t_candidates[seq_id] = ggml_data.candidates;
+            ggml_build_forward_expand(gf, ggml_data.candidates);
         }
     }

     // TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
     /*
     for (const auto & [seq_id, sampler] : samplers) {
-        if (auto it = res->t_sampled_tokens.find(seq_id); it != res->t_sampled_tokens.end()) {
+        if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
             ggml_tensor * selected_token = it->second;
             if (selected_token != nullptr) {
                 llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
@@ -543,8 +543,8 @@ public:
     ggml_tensor * t_embd_pooled = nullptr;

     std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_logits;
-    std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_token_ids;
-    std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_tokens;
+    std::unordered_map<llama_seq_id, ggml_tensor*> t_candidates;
+    std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled;
     std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_probs;

     std::vector<llm_graph_input_ptr> inputs;