cont : naming
This commit is contained in:
parent c187003d81
commit 80742cbaeb
@@ -106,16 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
 }
 
 static llama_sampler_i llama_sampler_llg_i = {
     /* .name   = */ llama_sampler_llg_name,
     /* .accept = */ llama_sampler_llg_accept_impl,
     /* .apply  = */ llama_sampler_llg_apply,
     /* .reset  = */ llama_sampler_llg_reset,
     /* .clone  = */ llama_sampler_llg_clone,
     /* .free   = */ llama_sampler_llg_free,
-    /* .apply_ggml          = */ NULL,
-    /* .accept_ggml         = */ NULL,
-    /* .set_input_ggml      = */ NULL,
-    /* .set_backend_context = */ NULL,
+    /* .backend_init      = */ NULL,
+    /* .backend_accept    = */ NULL,
+    /* .backend_apply     = */ NULL,
+    /* .backend_set_input = */ NULL,
 };
 
 static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,
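Note (not part of the diff): the four trailing callbacks of llama_sampler_i are renamed from the ggml-suffixed forms to a uniform backend_ prefix. As an illustrative sketch only, with hypothetical my_sampler_* functions that are not part of this change, a CPU-only sampler on this branch would fill the table the same way the initializer above does and leave the backend slots NULL:

    // Hypothetical CPU-only sampler table, mirroring the initializer above.
    // The my_sampler_* functions are assumed for illustration.
    static llama_sampler_i my_sampler_i = {
        /* .name              = */ my_sampler_name,
        /* .accept            = */ my_sampler_accept,
        /* .apply             = */ my_sampler_apply,
        /* .reset             = */ my_sampler_reset,
        /* .clone             = */ my_sampler_clone,
        /* .free              = */ my_sampler_free,
        /* .backend_init      = */ NULL, // no in-graph (backend) implementation
        /* .backend_accept    = */ NULL,
        /* .backend_apply     = */ NULL,
        /* .backend_set_input = */ NULL,
    };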
@@ -1374,10 +1374,6 @@ extern "C" {
     //
     LLAMA_API struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab);
 
-    //
-    // Backend samplers
-    //
-
     // Returns the seed used by the sampler if applicable, LLAMA_DEFAULT_SEED otherwise
     LLAMA_API uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl);
 
@@ -29,7 +29,7 @@ bool llama_batch_allocr::init(
         uint32_t n_embd,
         uint32_t n_seq_max,
         bool output_all,
-        bool backend_sampling) {
+        bool sampling) {
     clear();
 
     batch = batch_inp;
@@ -146,7 +146,7 @@ bool llama_batch_allocr::init(
         }
     }
 
-    if (backend_sampling) {
+    if (sampling) {
         std::vector<int32_t> seq_output_count(n_seq_max, 0);
 
         for (int32_t i = 0; i < batch.n_tokens; ++i) {
@@ -157,7 +157,7 @@ bool llama_batch_allocr::init(
                 const llama_seq_id seq_id = batch.seq_id[i][s];
                 seq_output_count[seq_id]++;
                 if (seq_output_count[seq_id] > 1) {
-                    LLAMA_LOG_ERROR("%s: backend sampling allows at most one output token per sequence (%d)\n", __func__, seq_id);
+                    LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (%d)\n", __func__, seq_id);
                     return false;
                 }
             }
@@ -80,7 +80,7 @@ public:
             uint32_t n_embd,
             uint32_t n_seq_max,
             bool output_all,
-            bool backend_sampling = false);
+            bool sampling = false);
 
     const llama_batch & get_batch() const;
 
@@ -1322,12 +1322,12 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     // when computing embeddings, all tokens are output
     const bool output_all = cparams.embeddings;
-    const bool has_backend_samplers = !sampling.samplers.empty();
+    const bool has_samplers = !sampling.samplers.empty();
 
     if (!balloc->init(batch_inp, vocab, memory.get(), n_embd,
                 cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max,
                 output_all,
-                has_backend_samplers)) {
+                has_samplers)) {
         LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__);
         return -1;
     }
@@ -1415,10 +1415,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
 
     int64_t n_outputs_prev = 0;
 
-    // This flag indicates whether a backend sampler has actually sampled a specific
-    // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
-    bool backend_has_sampled = false;
-
     do {
         const auto & ubatch = mctx->get_ubatch();
 
@@ -1477,9 +1473,11 @@ int llama_context::decode(const llama_batch & batch_inp) {
         //    ggml_graph_dump_dot(gf, NULL, "llama.dot");
         //}
 
-        backend_has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
+        // This flag indicates whether a backend sampler has actually sampled a specific
+        // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings.
+        const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty();
 
-        if (has_backend_samplers && backend_has_sampled) {
+        if (has_samplers && has_sampled) {
             const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev);
             const auto stride = n_vocab;
 
@@ -1495,7 +1493,6 @@ int llama_context::decode(const llama_batch & batch_inp) {
             // async copy the candidate token ids from the backend to the host.
             // These are needed by CPU samplers to map probability/logit indices to vocab token ids.
             copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get());
-
         }
 
         auto * t_logits = res->get_logits();
@@ -1661,8 +1658,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
     }
 
     // Check which sampling modes are needed by sequences in the current batch.
-    bool batch_has_backend_sampling = false;
+    bool batch_has_sampling = false;
     bool batch_needs_cpu_logits = false;
 
     if (batch.logits) {
         for (int32_t i = 0; i < batch.n_tokens; i++) {
@@ -1672,7 +1669,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
             for (int32_t j = 0; j < batch.n_seq_id[i]; j++) {
                 llama_seq_id seq_id = batch.seq_id[i][j];
                 if (sampling.samplers.find(seq_id) != sampling.samplers.end()) {
-                    batch_has_backend_sampling = true;
+                    batch_has_sampling = true;
                 } else {
                     batch_needs_cpu_logits = true;
                 }
@@ -1691,7 +1688,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
     logits_size = (has_logits && batch_needs_cpu_logits) ? n_vocab*n_outputs_max : 0;
     embd_size = has_embd ? n_embd*n_outputs_max : 0;
 
-    if (!batch_has_backend_sampling) {
+    if (!batch_has_sampling) {
         sampling.logits_size = 0;
         sampling.probs_size = 0;
         sampling.sampled_size = 0;
@@ -1762,7 +1759,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & ba
     embd = has_embd ? (float *) (base + offset) : nullptr;
     offset += embd_size * sizeof(float);
 
-    if (batch_has_backend_sampling) {
+    if (batch_has_sampling) {
         sampling.logits = (float *) (base + offset);
         offset += sampling.logits_size * sizeof(float);
 
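Note (not part of the diff): output_reserve carves all host-visible output regions out of one backing allocation by advancing a byte offset past each region, and the sampling regions are skipped entirely when the batch needs no backend sampling (their sizes were zeroed above). A self-contained sketch of that sub-allocation pattern, with simplified names assumed for illustration:

    #include <cstddef>
    #include <cstdint>

    // Sketch: lay typed regions out back to back in a single buffer.
    struct output_layout {
        float * logits          = nullptr; // CPU-sampling logits, n_vocab * n_outputs_max floats
        float * embd            = nullptr; // embeddings, n_embd * n_outputs_max floats
        float * sampling_logits = nullptr; // backend-sampling logits, present only if needed
    };

    static output_layout carve(uint8_t * base, size_t logits_size, size_t embd_size,
                               size_t sampling_logits_size, bool batch_has_sampling) {
        output_layout out;
        size_t offset = 0;

        out.logits = logits_size ? (float *) (base + offset) : nullptr;
        offset    += logits_size * sizeof(float);

        out.embd   = embd_size ? (float *) (base + offset) : nullptr;
        offset    += embd_size * sizeof(float);

        if (batch_has_sampling) {
            out.sampling_logits = (float *) (base + offset);
            offset             += sampling_logits_size * sizeof(float);
        }

        return out;
    }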
@@ -66,17 +66,17 @@ struct llama_context {
     float * get_embeddings_ith(int32_t i);
     float * get_embeddings_seq(llama_seq_id seq_id);
 
     llama_token * get_sampled_tokens();
     llama_token get_sampled_token_ith(int32_t idx);
 
     float * get_sampled_logits_ith(int32_t idx);
     size_t get_sampled_logits_count(int32_t idx);
 
     float * get_sampled_probs_ith(int32_t idx);
     size_t get_sampled_probs_count(int32_t idx);
 
     const llama_token * get_sampled_candidates_ith(int32_t idx);
     size_t get_sampled_candidates_count(int32_t idx);
 
     void attach_threadpool(
             ggml_threadpool_t threadpool,