graph : respect sampler order for graph reuse
This commit is contained in:
parent
44d5c4b592
commit
804e7e3795
|
|
@ -1211,8 +1211,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
|||
return 0;
|
||||
}
|
||||
|
||||
static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
||||
std::unordered_map<llama_seq_id, uint32_t> seq_to_row;
|
||||
static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
||||
std::map<llama_seq_id, uint32_t> seq_to_row;
|
||||
// how many output tokens we have seen so far for this ubatch.
|
||||
uint32_t local = 0;
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
|
|
@ -1230,10 +1230,10 @@ static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const
|
|||
}
|
||||
|
||||
static void copy_tensor_async_ints(
|
||||
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
llama_token * sampled,
|
||||
size_t sampled_size,
|
||||
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (sampled == nullptr) {
|
||||
return;
|
||||
|
|
@ -1256,11 +1256,11 @@ static void copy_tensor_async_ints(
|
|||
}
|
||||
|
||||
static void copy_tensor_async_floats(
|
||||
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
float * dst,
|
||||
size_t stride,
|
||||
std::vector<uint32_t> & counts,
|
||||
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (dst == nullptr) {
|
||||
return;
|
||||
|
|
@ -1287,11 +1287,11 @@ static void copy_tensor_async_floats(
|
|||
}
|
||||
|
||||
static void copy_tensor_async_candidates(
|
||||
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||
llama_token * dst,
|
||||
size_t stride,
|
||||
std::vector<uint32_t> & counts,
|
||||
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||
ggml_backend_sched_t sched) {
|
||||
if (dst == nullptr) {
|
||||
return;
|
||||
|
|
|
|||
|
|
@ -258,7 +258,7 @@ private:
|
|||
float * logits = nullptr;
|
||||
|
||||
struct sampling_info {
|
||||
std::unordered_map<llama_seq_id, llama_sampler *> samplers;
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
float * logits = nullptr;
|
||||
size_t logits_size = 0;
|
||||
|
|
|
|||
|
|
@ -2105,7 +2105,7 @@ void llm_graph_context::build_sampling() const {
|
|||
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
||||
res->add_input(std::move(inp_sampling));
|
||||
|
||||
std::unordered_map<llama_seq_id, int32_t> seq_to_logit_row;
|
||||
std::map<llama_seq_id, int32_t> seq_to_logit_row;
|
||||
int32_t logit_row_idx = 0;
|
||||
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
#include <memory>
|
||||
#include <set>
|
||||
#include <functional>
|
||||
#include <map>
|
||||
|
||||
struct ggml_cgraph;
|
||||
struct ggml_context;
|
||||
|
|
@ -385,14 +386,14 @@ public:
|
|||
|
||||
// Graph input that stores a per-sequence map of samplers (llama_seq_id -> llama_sampler *).
// NOTE(review): this is a rendered diff without +/- markers, so the constructor and the
// `samplers` member each appear twice below. The hunk header (-385,14 +386,14) and the
// `#include <map>` line in the include hunk suggest the std::unordered_map lines are the
// removed versions and the std::map lines are their replacements - confirm against the file.
class llm_graph_input_sampling : public llm_graph_input_i {
|
||||
public:
|
||||
// takes the sampler map by value and moves it into the `samplers` member
llm_graph_input_sampling(std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
|
||||
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
|
||||
samplers(std::move(samplers)) { }
|
||||
virtual ~llm_graph_input_sampling() = default;
|
||||
|
||||
// overrides of the llm_graph_input_i interface; their definitions are not visible here
void set_input(const llama_ubatch * ubatch) override;
|
||||
bool can_reuse(const llm_graph_params & params) override;
|
||||
|
||||
// samplers keyed by sequence id; std::map iterates in ascending key order, which is
// presumably the point of the change (commit title: "respect sampler order for graph
// reuse") - TODO confirm how iteration order is relied upon by the callers
std::unordered_map<llama_seq_id, llama_sampler *> samplers;
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
};
|
||||
|
||||
//
|
||||
|
|
@ -428,11 +429,11 @@ struct llm_graph_params {
|
|||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::unordered_map<llama_seq_id, llama_sampler *> samplers;
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
static bool samplers_equal(
|
||||
const std::unordered_map<llama_seq_id, llama_sampler *> & lhs,
|
||||
const std::unordered_map<llama_seq_id, llama_sampler *> & rhs) {
|
||||
const std::map<llama_seq_id, llama_sampler *> & lhs,
|
||||
const std::map<llama_seq_id, llama_sampler *> & rhs) {
|
||||
if (lhs.size() != rhs.size()) {
|
||||
return false;
|
||||
}
|
||||
|
|
@ -484,6 +485,28 @@ struct llm_graph_params {
|
|||
return false;
|
||||
}
|
||||
|
||||
if (n_outputs != other.n_outputs) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (!samplers_equal(samplers, other.samplers)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (samplers.size() > 0) {
|
||||
if (!ubatch.data || !other.ubatch.data) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// check that the outputs are the same for all samplers
|
||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||
if (ubatch.output[i] != other.ubatch.output[i] ||
|
||||
ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
cparams.embeddings == other.cparams.embeddings &&
|
||||
cparams.causal_attn == other.cparams.causal_attn &&
|
||||
|
|
@ -491,10 +514,7 @@ struct llm_graph_params {
|
|||
gtype == other.gtype &&
|
||||
cvec == other.cvec &&
|
||||
loras == other.loras &&
|
||||
cross == other.cross &&
|
||||
n_outputs == other.n_outputs &&
|
||||
samplers_equal(samplers, other.samplers);
|
||||
|
||||
cross == other.cross;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -536,10 +556,10 @@ public:
|
|||
ggml_tensor * t_embd = nullptr;
|
||||
ggml_tensor * t_embd_pooled = nullptr;
|
||||
|
||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||
|
||||
std::vector<llm_graph_input_ptr> inputs;
|
||||
|
||||
|
|
@ -616,7 +636,7 @@ struct llm_graph_context {
|
|||
const llama_memory_context_i * mctx;
|
||||
const llama_cross * cross;
|
||||
|
||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
||||
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||
|
||||
const llm_graph_cb & cb_func;
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue