graph : respect sampler order for graph reuse
This commit is contained in:
parent
44d5c4b592
commit
804e7e3795
|
|
@ -1211,8 +1211,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
|
||||||
std::unordered_map<llama_seq_id, uint32_t> seq_to_row;
|
std::map<llama_seq_id, uint32_t> seq_to_row;
|
||||||
// how many output tokens we have seen so far for this ubatch.
|
// how many output tokens we have seen so far for this ubatch.
|
||||||
uint32_t local = 0;
|
uint32_t local = 0;
|
||||||
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||||
|
|
@ -1230,10 +1230,10 @@ static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const
|
||||||
}
|
}
|
||||||
|
|
||||||
static void copy_tensor_async_ints(
|
static void copy_tensor_async_ints(
|
||||||
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
|
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||||
llama_token * sampled,
|
llama_token * sampled,
|
||||||
size_t sampled_size,
|
size_t sampled_size,
|
||||||
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
|
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||||
ggml_backend_sched_t sched) {
|
ggml_backend_sched_t sched) {
|
||||||
if (sampled == nullptr) {
|
if (sampled == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
|
@ -1256,11 +1256,11 @@ static void copy_tensor_async_ints(
|
||||||
}
|
}
|
||||||
|
|
||||||
static void copy_tensor_async_floats(
|
static void copy_tensor_async_floats(
|
||||||
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
|
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||||
float * dst,
|
float * dst,
|
||||||
size_t stride,
|
size_t stride,
|
||||||
std::vector<uint32_t> & counts,
|
std::vector<uint32_t> & counts,
|
||||||
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
|
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||||
ggml_backend_sched_t sched) {
|
ggml_backend_sched_t sched) {
|
||||||
if (dst == nullptr) {
|
if (dst == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
|
@ -1287,11 +1287,11 @@ static void copy_tensor_async_floats(
|
||||||
}
|
}
|
||||||
|
|
||||||
static void copy_tensor_async_candidates(
|
static void copy_tensor_async_candidates(
|
||||||
const std::unordered_map<llama_seq_id, ggml_tensor*> & tensor_map,
|
const std::map<llama_seq_id, ggml_tensor*> & tensor_map,
|
||||||
llama_token * dst,
|
llama_token * dst,
|
||||||
size_t stride,
|
size_t stride,
|
||||||
std::vector<uint32_t> & counts,
|
std::vector<uint32_t> & counts,
|
||||||
const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
|
const std::map<llama_seq_id, uint32_t> & seq_to_row,
|
||||||
ggml_backend_sched_t sched) {
|
ggml_backend_sched_t sched) {
|
||||||
if (dst == nullptr) {
|
if (dst == nullptr) {
|
||||||
return;
|
return;
|
||||||
|
|
|
||||||
|
|
@ -258,7 +258,7 @@ private:
|
||||||
float * logits = nullptr;
|
float * logits = nullptr;
|
||||||
|
|
||||||
struct sampling_info {
|
struct sampling_info {
|
||||||
std::unordered_map<llama_seq_id, llama_sampler *> samplers;
|
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||||
|
|
||||||
float * logits = nullptr;
|
float * logits = nullptr;
|
||||||
size_t logits_size = 0;
|
size_t logits_size = 0;
|
||||||
|
|
|
||||||
|
|
@ -2105,7 +2105,7 @@ void llm_graph_context::build_sampling() const {
|
||||||
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
|
||||||
res->add_input(std::move(inp_sampling));
|
res->add_input(std::move(inp_sampling));
|
||||||
|
|
||||||
std::unordered_map<llama_seq_id, int32_t> seq_to_logit_row;
|
std::map<llama_seq_id, int32_t> seq_to_logit_row;
|
||||||
int32_t logit_row_idx = 0;
|
int32_t logit_row_idx = 0;
|
||||||
|
|
||||||
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@
|
||||||
#include <memory>
|
#include <memory>
|
||||||
#include <set>
|
#include <set>
|
||||||
#include <functional>
|
#include <functional>
|
||||||
|
#include <map>
|
||||||
|
|
||||||
struct ggml_cgraph;
|
struct ggml_cgraph;
|
||||||
struct ggml_context;
|
struct ggml_context;
|
||||||
|
|
@ -385,14 +386,14 @@ public:
|
||||||
|
|
||||||
class llm_graph_input_sampling : public llm_graph_input_i {
|
class llm_graph_input_sampling : public llm_graph_input_i {
|
||||||
public:
|
public:
|
||||||
llm_graph_input_sampling(std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
|
llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
|
||||||
samplers(std::move(samplers)) { }
|
samplers(std::move(samplers)) { }
|
||||||
virtual ~llm_graph_input_sampling() = default;
|
virtual ~llm_graph_input_sampling() = default;
|
||||||
|
|
||||||
void set_input(const llama_ubatch * ubatch) override;
|
void set_input(const llama_ubatch * ubatch) override;
|
||||||
bool can_reuse(const llm_graph_params & params) override;
|
bool can_reuse(const llm_graph_params & params) override;
|
||||||
|
|
||||||
std::unordered_map<llama_seq_id, llama_sampler *> samplers;
|
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||||
};
|
};
|
||||||
|
|
||||||
//
|
//
|
||||||
|
|
@ -428,11 +429,11 @@ struct llm_graph_params {
|
||||||
const llama_memory_context_i * mctx;
|
const llama_memory_context_i * mctx;
|
||||||
const llama_cross * cross;
|
const llama_cross * cross;
|
||||||
|
|
||||||
std::unordered_map<llama_seq_id, llama_sampler *> samplers;
|
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||||
|
|
||||||
static bool samplers_equal(
|
static bool samplers_equal(
|
||||||
const std::unordered_map<llama_seq_id, llama_sampler *> & lhs,
|
const std::map<llama_seq_id, llama_sampler *> & lhs,
|
||||||
const std::unordered_map<llama_seq_id, llama_sampler *> & rhs) {
|
const std::map<llama_seq_id, llama_sampler *> & rhs) {
|
||||||
if (lhs.size() != rhs.size()) {
|
if (lhs.size() != rhs.size()) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
@ -484,17 +485,36 @@ struct llm_graph_params {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (n_outputs != other.n_outputs) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!samplers_equal(samplers, other.samplers)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (samplers.size() > 0) {
|
||||||
|
if (!ubatch.data || !other.ubatch.data) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// check that the outputs are the same for all samplers
|
||||||
|
for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
|
||||||
|
if (ubatch.output[i] != other.ubatch.output[i] ||
|
||||||
|
ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
cparams.embeddings == other.cparams.embeddings &&
|
cparams.embeddings == other.cparams.embeddings &&
|
||||||
cparams.causal_attn == other.cparams.causal_attn &&
|
cparams.causal_attn == other.cparams.causal_attn &&
|
||||||
arch == other.arch &&
|
arch == other.arch &&
|
||||||
gtype == other.gtype &&
|
gtype == other.gtype &&
|
||||||
cvec == other.cvec &&
|
cvec == other.cvec &&
|
||||||
loras == other.loras &&
|
loras == other.loras &&
|
||||||
cross == other.cross &&
|
cross == other.cross;
|
||||||
n_outputs == other.n_outputs &&
|
|
||||||
samplers_equal(samplers, other.samplers);
|
|
||||||
|
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -536,10 +556,10 @@ public:
|
||||||
ggml_tensor * t_embd = nullptr;
|
ggml_tensor * t_embd = nullptr;
|
||||||
ggml_tensor * t_embd_pooled = nullptr;
|
ggml_tensor * t_embd_pooled = nullptr;
|
||||||
|
|
||||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
std::map<llama_seq_id, ggml_tensor*> t_sampled_logits;
|
||||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_candidates;
|
std::map<llama_seq_id, ggml_tensor*> t_candidates;
|
||||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled;
|
std::map<llama_seq_id, ggml_tensor*> t_sampled;
|
||||||
std::unordered_map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
std::map<llama_seq_id, ggml_tensor*> t_sampled_probs;
|
||||||
|
|
||||||
std::vector<llm_graph_input_ptr> inputs;
|
std::vector<llm_graph_input_ptr> inputs;
|
||||||
|
|
||||||
|
|
@ -616,7 +636,7 @@ struct llm_graph_context {
|
||||||
const llama_memory_context_i * mctx;
|
const llama_memory_context_i * mctx;
|
||||||
const llama_cross * cross;
|
const llama_cross * cross;
|
||||||
|
|
||||||
std::unordered_map<llama_seq_id, llama_sampler*> samplers;
|
std::map<llama_seq_id, llama_sampler *> samplers;
|
||||||
|
|
||||||
const llm_graph_cb & cb_func;
|
const llm_graph_cb & cb_func;
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue