sampling : introduce sampling_info struct
This commit introduces a sampling_info struct to encapsulate all backend sampling related data within the llama_context class. It also updates to use more descriptive names for sampled tokens and candidates in the backend sampler ggml data structure.
This commit is contained in:
parent
ed4345bdd9
commit
0d28b16bdc
|
|
@ -213,8 +213,8 @@ extern "C" {
|
||||||
struct llama_sampler_ggml_data {
|
struct llama_sampler_ggml_data {
|
||||||
struct ggml_tensor * logits;
|
struct ggml_tensor * logits;
|
||||||
struct ggml_tensor * probs;
|
struct ggml_tensor * probs;
|
||||||
struct ggml_tensor * sampled_token;
|
struct ggml_tensor * sampled;
|
||||||
struct ggml_tensor * filtered_ids;
|
struct ggml_tensor * candidates;
|
||||||
};
|
};
|
||||||
|
|
||||||
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
typedef bool (*llama_progress_callback)(float progress, void * user_data);
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ static void llama_sampler_backend_greedy_apply_ggml(
|
||||||
GGML_UNUSED(smpl);
|
GGML_UNUSED(smpl);
|
||||||
struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits);
|
struct ggml_tensor * argmax_result = ggml_argmax(ctx, ggml_data->logits);
|
||||||
ggml_set_name(argmax_result, "argmax_result");
|
ggml_set_name(argmax_result, "argmax_result");
|
||||||
ggml_data->sampled_token = argmax_result;
|
ggml_data->sampled = argmax_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) {
|
static const char * llama_sampler_backend_greedy_sampler_name(const struct llama_sampler *) {
|
||||||
|
|
@ -149,7 +149,7 @@ static void llama_sampler_backend_top_k_apply_ggml(
|
||||||
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
|
fprintf(stderr, "CPU backend will be used instead which defeats the purpose of having backend samplers\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_data->filtered_ids = top_k;
|
ggml_data->candidates = top_k;
|
||||||
|
|
||||||
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
|
struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, ggml_data->logits, 1, ggml_data->logits->ne[0]);
|
||||||
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k);
|
||||||
|
|
@ -303,19 +303,19 @@ static void llama_sampler_backend_dist_apply_ggml(
|
||||||
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
|
struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32);
|
||||||
ggml_set_name(idx, "dist_index_i32");
|
ggml_set_name(idx, "dist_index_i32");
|
||||||
|
|
||||||
// Map back to original vocab ids if a filtered id tensor is available.
|
// Map back to original vocab ids if a candidates tensor is available.
|
||||||
struct ggml_tensor * sampled_token = idx;
|
struct ggml_tensor * sampled_token = idx;
|
||||||
if (ggml_data->filtered_ids != nullptr) {
|
if (ggml_data->candidates != nullptr) {
|
||||||
struct ggml_tensor * filtered_ids = ggml_data->filtered_ids;
|
struct ggml_tensor * candidates = ggml_data->candidates;
|
||||||
struct ggml_tensor * filtered_ids_reshaped = ggml_view_2d(ctx, filtered_ids, 1, ggml_nelements(filtered_ids),
|
struct ggml_tensor * candidates_reshaped = ggml_view_2d(ctx, candidates, 1, ggml_nelements(candidates),
|
||||||
ggml_type_size(filtered_ids->type), 0);
|
ggml_type_size(candidates->type), 0);
|
||||||
|
|
||||||
sampled_token = ggml_get_rows(ctx, filtered_ids_reshaped, idx);
|
sampled_token = ggml_get_rows(ctx, candidates_reshaped, idx);
|
||||||
ggml_set_name(sampled_token, "dist_sampled_token");
|
ggml_set_name(sampled_token, "dist_sampled_token");
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_set_output(sampled_token);
|
ggml_set_output(sampled_token);
|
||||||
ggml_data->sampled_token = sampled_token;
|
ggml_data->sampled = sampled_token;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) {
|
static const char * llama_sampler_backend_dist_name(const struct llama_sampler *) {
|
||||||
|
|
|
||||||
|
|
@ -60,11 +60,11 @@ llama_context::llama_context(
|
||||||
|
|
||||||
// backend samplers
|
// backend samplers
|
||||||
if (params.samplers != nullptr && params.n_samplers > 0) {
|
if (params.samplers != nullptr && params.n_samplers > 0) {
|
||||||
samplers.reserve(params.n_samplers);
|
sampling.samplers.reserve(params.n_samplers);
|
||||||
|
|
||||||
for (size_t i = 0; i < params.n_samplers; ++i) {
|
for (size_t i = 0; i < params.n_samplers; ++i) {
|
||||||
const auto & config = params.samplers[i];
|
const auto & config = params.samplers[i];
|
||||||
samplers[config.seq_id] = config.sampler;
|
sampling.samplers[config.seq_id] = config.sampler;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -435,9 +435,9 @@ llama_context::llama_context(
|
||||||
{
|
{
|
||||||
const llama_vocab * vocab = llama_model_get_vocab(&model);
|
const llama_vocab * vocab = llama_model_get_vocab(&model);
|
||||||
const int n_vocab = llama_vocab_n_tokens(vocab);
|
const int n_vocab = llama_vocab_n_tokens(vocab);
|
||||||
sampled_token_ids_full_vocab.resize(n_vocab);
|
sampling.token_ids_full_vocab.resize(n_vocab);
|
||||||
for (int i = 0; i < n_vocab; ++i) {
|
for (int i = 0; i < n_vocab; ++i) {
|
||||||
sampled_token_ids_full_vocab[i] = i;
|
sampling.token_ids_full_vocab[i] = i;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -445,7 +445,7 @@ llama_context::llama_context(
|
||||||
llama_context::~llama_context() {
|
llama_context::~llama_context() {
|
||||||
ggml_opt_free(opt_ctx);
|
ggml_opt_free(opt_ctx);
|
||||||
// TODO: perhaps use a smart pointer for samplers
|
// TODO: perhaps use a smart pointer for samplers
|
||||||
for (auto const& [seq_id, sampler] : samplers) {
|
for (auto const& [seq_id, sampler] : sampling.samplers) {
|
||||||
llama_sampler_free(sampler);
|
llama_sampler_free(sampler);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -635,7 +635,7 @@ float * llama_context::get_embeddings() {
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token * llama_context::get_backend_sampled_tokens() {
|
llama_token * llama_context::get_backend_sampled_tokens() {
|
||||||
return sampled_tokens;
|
return sampling.sampled;
|
||||||
}
|
}
|
||||||
|
|
||||||
float * llama_context::get_embeddings_ith(int32_t i) {
|
float * llama_context::get_embeddings_ith(int32_t i) {
|
||||||
|
|
@ -691,15 +691,15 @@ llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
|
||||||
// Handle special case where idx == -1 (single sequence exists) which is
|
// Handle special case where idx == -1 (single sequence exists) which is
|
||||||
// a valid index when using common_sampler_sample.
|
// a valid index when using common_sampler_sample.
|
||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
if (sampled_tokens_map.size() == 1) {
|
if (sampling.map_sampled.size() == 1) {
|
||||||
auto it = sampled_tokens_map.begin();
|
auto it = sampling.map_sampled.begin();
|
||||||
return it->second;
|
return it->second;
|
||||||
}
|
}
|
||||||
return LLAMA_TOKEN_NULL;
|
return LLAMA_TOKEN_NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto it = sampled_tokens_map.find(idx);
|
auto it = sampling.map_sampled.find(idx);
|
||||||
if (it == sampled_tokens_map.end()) {
|
if (it == sampling.map_sampled.end()) {
|
||||||
return LLAMA_TOKEN_NULL;
|
return LLAMA_TOKEN_NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -708,13 +708,13 @@ llama_token llama_context::get_backend_sampled_token_ith(int32_t idx) {
|
||||||
|
|
||||||
float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
|
float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
|
||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
if (sampled_probs_map.size() == 1) {
|
if (sampling.map_probs.size() == 1) {
|
||||||
return sampled_probs_map.begin()->second.data();
|
return sampling.map_probs.begin()->second.data();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto it = sampled_probs_map.find(idx);
|
auto it = sampling.map_probs.find(idx);
|
||||||
if (it == sampled_probs_map.end()) {
|
if (it == sampling.map_probs.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -723,12 +723,12 @@ float * llama_context::get_backend_sampled_probs_ith(int32_t idx) {
|
||||||
|
|
||||||
float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
|
float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
|
||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
if (sampled_logits_map.size() == 1) {
|
if (sampling.map_logits.size() == 1) {
|
||||||
return sampled_logits_map.begin()->second.data();
|
return sampling.map_logits.begin()->second.data();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto it = sampled_logits_map.find(idx);
|
auto it = sampling.map_logits.find(idx);
|
||||||
if (it == sampled_logits_map.end()) {
|
if (it == sampling.map_logits.end()) {
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -737,29 +737,29 @@ float * llama_context::get_backend_sampled_logits_ith(int32_t idx) {
|
||||||
|
|
||||||
const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
|
const llama_token * llama_context::get_backend_sampled_token_ids_ith(int32_t idx) {
|
||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
if (sampled_token_ids_map.size() == 1) {
|
if (sampling.map_cadidates.size() == 1) {
|
||||||
const auto & vec = sampled_token_ids_map.begin()->second;
|
const auto & vec = sampling.map_cadidates.begin()->second;
|
||||||
if (!vec.empty()) {
|
if (!vec.empty()) {
|
||||||
return vec.data();
|
return vec.data();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto it = sampled_token_ids_map.find(idx);
|
auto it = sampling.map_cadidates.find(idx);
|
||||||
if (it != sampled_token_ids_map.end() && !it->second.empty()) {
|
if (it != sampling.map_cadidates.end() && !it->second.empty()) {
|
||||||
return it->second.data();
|
return it->second.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
return sampled_token_ids_full_vocab.data();
|
return sampling.token_ids_full_vocab.data();
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
|
size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
|
||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
if (sampled_logits_map.size() == 1) {
|
if (sampling.map_logits.size() == 1) {
|
||||||
return sampled_logits_map.begin()->second.size();
|
return sampling.map_logits.begin()->second.size();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
auto it = sampled_logits_map.find(idx);
|
auto it = sampling.map_logits.find(idx);
|
||||||
if (it == sampled_logits_map.end()) {
|
if (it == sampling.map_logits.end()) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -768,14 +768,14 @@ size_t llama_context::get_backend_sampled_logits_count(int32_t idx) const {
|
||||||
|
|
||||||
size_t llama_context::get_backend_sampled_probs_count(int32_t idx) const {
|
size_t llama_context::get_backend_sampled_probs_count(int32_t idx) const {
|
||||||
if (idx == -1) {
|
if (idx == -1) {
|
||||||
if (sampled_probs_map.size() == 1) {
|
if (sampling.map_probs.size() == 1) {
|
||||||
return sampled_probs_map.begin()->second.size();
|
return sampling.map_probs.begin()->second.size();
|
||||||
}
|
}
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto it = sampled_probs_map.find(idx);
|
auto it = sampling.map_probs.find(idx);
|
||||||
if (it == sampled_probs_map.end()) {
|
if (it == sampling.map_probs.end()) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -841,8 +841,8 @@ void llama_context::set_warmup(bool value) {
|
||||||
void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sampler) {
|
||||||
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
|
LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler);
|
||||||
|
|
||||||
auto it = samplers.find(seq_id);
|
auto it = sampling.samplers.find(seq_id);
|
||||||
if (it != samplers.end()) {
|
if (it != sampling.samplers.end()) {
|
||||||
// If the sampler to be set is the same that is already set, do nothing.
|
// If the sampler to be set is the same that is already set, do nothing.
|
||||||
if (it->second == sampler) {
|
if (it->second == sampler) {
|
||||||
return;
|
return;
|
||||||
|
|
@ -853,7 +853,7 @@ void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sam
|
||||||
// If sampler is nullptr, we remove the samppler chain for this seq_id.
|
// If sampler is nullptr, we remove the samppler chain for this seq_id.
|
||||||
// chain for this seq_id.
|
// chain for this seq_id.
|
||||||
if (sampler == nullptr) {
|
if (sampler == nullptr) {
|
||||||
samplers.erase(it);
|
sampling.samplers.erase(it);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -865,7 +865,7 @@ void llama_context::set_backend_sampler(llama_seq_id seq_id, llama_sampler * sam
|
||||||
// If there is no sampler for this seq_id and the caller provides a non-null
|
// If there is no sampler for this seq_id and the caller provides a non-null
|
||||||
// sampler, we set it.
|
// sampler, we set it.
|
||||||
if (sampler != nullptr) {
|
if (sampler != nullptr) {
|
||||||
samplers[seq_id] = sampler;
|
sampling.samplers[seq_id] = sampler;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1202,7 +1202,7 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
|
|
||||||
// when computing embeddings, all tokens are output
|
// when computing embeddings, all tokens are output
|
||||||
const bool output_all = cparams.embeddings;
|
const bool output_all = cparams.embeddings;
|
||||||
const bool has_backend_samplers = !samplers.empty();
|
const bool has_backend_samplers = !sampling.samplers.empty();
|
||||||
|
|
||||||
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
|
if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
|
||||||
cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
|
cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
|
||||||
|
|
@ -1235,10 +1235,10 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
|
|
||||||
// TODO: this clear of the buffer can easily be forgotten - need something better
|
// TODO: this clear of the buffer can easily be forgotten - need something better
|
||||||
embd_seq.clear();
|
embd_seq.clear();
|
||||||
sampled_probs_map.clear();
|
sampling.map_probs.clear();
|
||||||
sampled_logits_map.clear();
|
sampling.map_logits.clear();
|
||||||
sampled_tokens_map.clear();
|
sampling.map_sampled.clear();
|
||||||
sampled_token_ids_map.clear();
|
sampling.map_cadidates.clear();
|
||||||
output_swaps.clear();
|
output_swaps.clear();
|
||||||
|
|
||||||
bool did_optimize = false;
|
bool did_optimize = false;
|
||||||
|
|
@ -1361,27 +1361,27 @@ int llama_context::decode(const llama_batch & batch_inp) {
|
||||||
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
// ggml_graph_dump_dot(gf, NULL, "llama.dot");
|
||||||
//}
|
//}
|
||||||
|
|
||||||
backend_has_sampled = !res->t_sampled_tokens.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
|
backend_has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
|
||||||
|
|
||||||
if (has_backend_samplers && backend_has_sampled) {
|
if (has_backend_samplers && backend_has_sampled) {
|
||||||
const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
|
const auto seq_to_batch_idx = build_seq_to_batch_idx(ubatch);
|
||||||
|
|
||||||
// If a backend sampler has sampled a token we only want to copy the
|
// If a backend sampler has sampled a token we only want to copy the
|
||||||
// sampled tokens and avoid copying logits and probabilites.
|
// sampled tokens and avoid copying logits and probabilites.
|
||||||
if (!res->t_sampled_tokens.empty()) {
|
if (!res->t_sampled.empty()) {
|
||||||
// async copy the sampled tokens from the backend to the host.
|
// async copy the sampled tokens from the backend to the host.
|
||||||
copy_tensor_async_int(res->t_sampled_tokens, sampled_tokens_map, seq_to_batch_idx, sched.get());
|
copy_tensor_async_int(res->t_sampled, sampling.map_sampled, seq_to_batch_idx, sched.get());
|
||||||
} else {
|
} else {
|
||||||
// async copy the sampled logits/probs from the backend to the host.
|
// async copy the sampled logits/probs from the backend to the host.
|
||||||
copy_tensor_async_floats(res->t_sampled_logits, sampled_logits_map, seq_to_batch_idx, sched.get());
|
copy_tensor_async_floats(res->t_sampled_logits, sampling.map_logits, seq_to_batch_idx, sched.get());
|
||||||
copy_tensor_async_floats(res->t_sampled_probs, sampled_probs_map, seq_to_batch_idx, sched.get());
|
copy_tensor_async_floats(res->t_sampled_probs, sampling.map_probs, seq_to_batch_idx, sched.get());
|
||||||
}
|
}
|
||||||
|
|
||||||
// async copy the filtered token ids from the backend to the host.
|
// async copy the candidate token ids from the backend to the host.
|
||||||
// These are needed for:
|
// These are needed for:
|
||||||
// 1) Backend dist sampler to map indices to vocab token ids.
|
// 1) Backend dist sampler to map indices to vocab token ids.
|
||||||
// 2) CPU samplers to associate filtered logits with their token ids.
|
// 2) CPU samplers to associate candidate logits with their token ids.
|
||||||
copy_tensor_async_token_ids(res->t_sampled_token_ids, sampled_token_ids_map, seq_to_batch_idx, sched.get());
|
copy_tensor_async_token_ids(res->t_candidates, sampling.map_cadidates, seq_to_batch_idx, sched.get());
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1589,8 +1589,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) {
|
||||||
|
|
||||||
logits = has_logits ? output_base : nullptr;
|
logits = has_logits ? output_base : nullptr;
|
||||||
embd = has_embd ? output_base + logits_size : nullptr;
|
embd = has_embd ? output_base + logits_size : nullptr;
|
||||||
sampled_tokens = !samplers.empty() ? s_output_base : nullptr;
|
|
||||||
sampled_probs = !samplers.empty() ? embd : nullptr;
|
sampling.sampled = !sampling.samplers.empty() ? s_output_base : nullptr;
|
||||||
|
sampling.probs = !sampling.samplers.empty() ? embd : nullptr;
|
||||||
|
|
||||||
// set all ids as invalid (negative)
|
// set all ids as invalid (negative)
|
||||||
std::fill(output_ids.begin(), output_ids.end(), -1);
|
std::fill(output_ids.begin(), output_ids.end(), -1);
|
||||||
|
|
@ -1700,7 +1701,7 @@ llm_graph_params llama_context::graph_params(
|
||||||
/*.loras =*/ &loras,
|
/*.loras =*/ &loras,
|
||||||
/*.mctx =*/ mctx,
|
/*.mctx =*/ mctx,
|
||||||
/*.cross =*/ &cross,
|
/*.cross =*/ &cross,
|
||||||
/*.samplers =*/ samplers,
|
/*.samplers =*/ sampling.samplers,
|
||||||
/*.n_outputs =*/ n_outputs,
|
/*.n_outputs =*/ n_outputs,
|
||||||
/*.cb =*/ graph_get_cb(),
|
/*.cb =*/ graph_get_cb(),
|
||||||
/*.res =*/ res,
|
/*.res =*/ res,
|
||||||
|
|
|
||||||
|
|
@ -254,16 +254,21 @@ private:
|
||||||
size_t logits_size = 0; // capacity (of floats) for logits
|
size_t logits_size = 0; // capacity (of floats) for logits
|
||||||
float * logits = nullptr;
|
float * logits = nullptr;
|
||||||
|
|
||||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
struct sampling_info {
|
||||||
llama_token * sampled_tokens = nullptr;
|
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
||||||
std::unordered_map<int32_t, llama_token> sampled_tokens_map;
|
|
||||||
|
|
||||||
float * sampled_probs = nullptr;
|
llama_token * sampled = nullptr;
|
||||||
std::unordered_map<int32_t, std::vector<float>> sampled_probs_map;
|
float * probs = nullptr;
|
||||||
|
|
||||||
std::unordered_map<int32_t, std::vector<float>> sampled_logits_map;
|
std::unordered_map<int32_t, llama_token> map_sampled;
|
||||||
std::unordered_map<int32_t, std::vector<llama_token>> sampled_token_ids_map;
|
std::unordered_map<int32_t, std::vector<float>> map_probs;
|
||||||
std::vector<llama_token> sampled_token_ids_full_vocab;
|
std::unordered_map<int32_t, std::vector<float>> map_logits;
|
||||||
|
std::unordered_map<int32_t, std::vector<llama_token>> map_cadidates;
|
||||||
|
|
||||||
|
std::vector<llama_token> token_ids_full_vocab;
|
||||||
|
};
|
||||||
|
|
||||||
|
sampling_info sampling;
|
||||||
|
|
||||||
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
// embeddings output (2-dimensional array: [n_outputs][n_embd])
|
||||||
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
// populated only when pooling_type == LLAMA_POOLING_TYPE_NONE
|
||||||
|
|
|
||||||
|
|
@ -504,10 +504,10 @@ void llm_graph_result::reset() {
|
||||||
t_logits = nullptr;
|
t_logits = nullptr;
|
||||||
t_embd = nullptr;
|
t_embd = nullptr;
|
||||||
t_embd_pooled = nullptr;
|
t_embd_pooled = nullptr;
|
||||||
t_sampled_tokens.clear();
|
t_sampled.clear();
|
||||||
t_sampled_probs.clear();
|
t_sampled_probs.clear();
|
||||||
t_sampled_logits.clear();
|
t_sampled_logits.clear();
|
||||||
t_sampled_token_ids.clear();
|
t_candidates.clear();
|
||||||
|
|
||||||
params = {};
|
params = {};
|
||||||
|
|
||||||
|
|
@ -2098,17 +2098,17 @@ void llm_graph_context::build_sampling() const {
|
||||||
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
|
||||||
|
|
||||||
struct llama_sampler_ggml_data ggml_data = {
|
struct llama_sampler_ggml_data ggml_data = {
|
||||||
/*.logits =*/ logits_seq,
|
/*.logits =*/ logits_seq,
|
||||||
/*.probs =*/ nullptr,
|
/*.probs =*/ nullptr,
|
||||||
/*.sampled_token =*/ nullptr,
|
/*.sampled =*/ nullptr,
|
||||||
/*.filtered_ids =*/ nullptr,
|
/*.candidates =*/ nullptr,
|
||||||
};
|
};
|
||||||
|
|
||||||
llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data);
|
llama_sampler_apply_ggml(sampler, ctx0, gf, &ggml_data);
|
||||||
|
|
||||||
if (ggml_data.sampled_token != nullptr) {
|
if (ggml_data.sampled != nullptr) {
|
||||||
res->t_sampled_tokens[seq_id] = ggml_data.sampled_token;
|
res->t_sampled[seq_id] = ggml_data.sampled;
|
||||||
ggml_build_forward_expand(gf, ggml_data.sampled_token);
|
ggml_build_forward_expand(gf, ggml_data.sampled);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ggml_data.probs != nullptr) {
|
if (ggml_data.probs != nullptr) {
|
||||||
|
|
@ -2121,16 +2121,16 @@ void llm_graph_context::build_sampling() const {
|
||||||
ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
|
ggml_build_forward_expand(gf, res->t_sampled_logits[seq_id]);
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ggml_data.filtered_ids != nullptr) {
|
if (ggml_data.candidates != nullptr) {
|
||||||
res->t_sampled_token_ids[seq_id] = ggml_data.filtered_ids;
|
res->t_candidates[seq_id] = ggml_data.candidates;
|
||||||
ggml_build_forward_expand(gf, ggml_data.filtered_ids);
|
ggml_build_forward_expand(gf, ggml_data.candidates);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
|
// TODO: Call llama_sampler_accept_ggml after all samplers have been applied.
|
||||||
/*
|
/*
|
||||||
for (const auto & [seq_id, sampler] : samplers) {
|
for (const auto & [seq_id, sampler] : samplers) {
|
||||||
if (auto it = res->t_sampled_tokens.find(seq_id); it != res->t_sampled_tokens.end()) {
|
if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) {
|
||||||
ggml_tensor * selected_token = it->second;
|
ggml_tensor * selected_token = it->second;
|
||||||
if (selected_token != nullptr) {
|
if (selected_token != nullptr) {
|
||||||
llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
|
llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token);
|
||||||
|
|
|
||||||
|
|
@ -543,8 +543,8 @@ public:
|
||||||
ggml_tensor * t_embd_pooled = nullptr;
|
ggml_tensor * t_embd_pooled = nullptr;
|
||||||
|
|
||||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_token_ids;
|
std::unordered_map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_tokens;
|
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||||
|
|
||||||
std::vector<llm_graph_input_ptr> inputs;
|
std::vector<llm_graph_input_ptr> inputs;
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue