diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index a4d332b114..32a18e0534 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -1211,8 +1211,8 @@ int llama_context::encode(const llama_batch & batch_inp) {
     return 0;
 }
 
-static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
-    std::unordered_map<llama_seq_id, uint32_t> seq_to_row;
+static std::map<llama_seq_id, uint32_t> build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) {
+    std::map<llama_seq_id, uint32_t> seq_to_row;
     // how many output tokens we have seen so far for this ubatch.
     uint32_t local = 0;
     for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
@@ -1230,10 +1230,10 @@ static std::unordered_map<llama_seq_id, uint32_t> build_seq_to_output_row(const
 }
 
 static void copy_tensor_async_ints(
-        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        const std::map<llama_seq_id, ggml_tensor *> & tensor_map,
         llama_token * sampled,
         size_t sampled_size,
-        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
     if (sampled == nullptr) {
         return;
@@ -1256,11 +1256,11 @@ static void copy_tensor_async_ints(
 }
 
 static void copy_tensor_async_floats(
-        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        const std::map<llama_seq_id, ggml_tensor *> & tensor_map,
         float * dst,
         size_t stride,
         std::vector<uint32_t> & counts,
-        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
     if (dst == nullptr) {
         return;
@@ -1287,11 +1287,11 @@ static void copy_tensor_async_floats(
 }
 
 static void copy_tensor_async_candidates(
-        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
+        const std::map<llama_seq_id, ggml_tensor *> & tensor_map,
         llama_token * dst,
         size_t stride,
         std::vector<uint32_t> & counts,
-        const std::unordered_map<llama_seq_id, uint32_t> & seq_to_row,
+        const std::map<llama_seq_id, uint32_t> & seq_to_row,
         ggml_backend_sched_t sched) {
     if (dst == nullptr) {
         return;
diff --git a/src/llama-context.h b/src/llama-context.h
index e14367d55d..62c5ce5502 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -258,7 +258,7 @@ private:
     float * logits = nullptr;
 
     struct sampling_info {
-        std::unordered_map<llama_seq_id, llama_sampler *> samplers;
+        std::map<llama_seq_id, llama_sampler *> samplers;
 
         float * logits = nullptr;
         size_t logits_size = 0;
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 0261b65bc4..80a562f3d8 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -2105,7 +2105,7 @@ void llm_graph_context::build_sampling() const {
     auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
     res->add_input(std::move(inp_sampling));
 
-    std::unordered_map<llama_seq_id, int32_t> seq_to_logit_row;
+    std::map<llama_seq_id, int32_t> seq_to_logit_row;
     int32_t logit_row_idx = 0;
 
     for (uint32_t i = 0; i < ubatch.n_tokens; i++) {
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 490d4fb00c..6c36ee6d05 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -10,6 +10,7 @@
 #include <memory>
 #include <set>
 #include <vector>
+#include <map>
 
 struct ggml_cgraph;
 struct ggml_context;
@@ -385,14 +386,14 @@ public:
 
 class llm_graph_input_sampling : public llm_graph_input_i {
 public:
-    llm_graph_input_sampling(std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
+    llm_graph_input_sampling(std::map<llama_seq_id, llama_sampler *> samplers) :
         samplers(std::move(samplers)) { }
 
     virtual ~llm_graph_input_sampling() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
     bool can_reuse(const llm_graph_params & params) override;
 
-    std::unordered_map<llama_seq_id, llama_sampler *> samplers;
+    std::map<llama_seq_id, llama_sampler *> samplers;
 };
 
 //
@@ -428,11 +429,11 @@ struct llm_graph_params {
     const llama_memory_context_i * mctx;
     const llama_cross * cross;
 
-    std::unordered_map<llama_seq_id, llama_sampler *> samplers;
+    std::map<llama_seq_id, llama_sampler *> samplers;
 
     static bool samplers_equal(
-            const std::unordered_map<llama_seq_id, llama_sampler *> & lhs,
-            const std::unordered_map<llama_seq_id, llama_sampler *> & rhs) {
+            const std::map<llama_seq_id, llama_sampler *> & lhs,
+            const std::map<llama_seq_id, llama_sampler *> & rhs) {
         if (lhs.size() != rhs.size()) {
             return false;
         }
@@ -484,17 +485,36 @@ struct llm_graph_params {
             return false;
         }
 
+        if (n_outputs != other.n_outputs) {
+            return false;
+        }
+
+        if (!samplers_equal(samplers, other.samplers)) {
+            return false;
+        }
+
+        if (samplers.size() > 0) {
+            if (!ubatch.data || !other.ubatch.data) {
+                return false;
+            }
+
+            // check that the outputs are the same for all samplers
+            for (uint32_t i = 0; i < ubatch.n_tokens; ++i) {
+                if (ubatch.output[i]    != other.ubatch.output[i] ||
+                    ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) {
+                    return false;
+                }
+            }
+        }
+
         return cparams.embeddings == other.cparams.embeddings &&
                cparams.causal_attn == other.cparams.causal_attn &&
-               arch == other.arch &&
-               gtype == other.gtype &&
-               cvec == other.cvec &&
-               loras == other.loras &&
-               cross == other.cross &&
-               n_outputs == other.n_outputs &&
-               samplers_equal(samplers, other.samplers);
-
+               arch == other.arch &&
+               gtype == other.gtype &&
+               cvec == other.cvec &&
+               loras == other.loras &&
+               cross == other.cross;
     }
 };
 
@@ -536,10 +556,10 @@ public:
     ggml_tensor * t_embd        = nullptr;
     ggml_tensor * t_embd_pooled = nullptr;
 
-    std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled_logits;
-    std::unordered_map<llama_seq_id, ggml_tensor *> t_candidates;
-    std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled;
-    std::unordered_map<llama_seq_id, ggml_tensor *> t_sampled_probs;
+    std::map<llama_seq_id, ggml_tensor *> t_sampled_logits;
+    std::map<llama_seq_id, ggml_tensor *> t_candidates;
+    std::map<llama_seq_id, ggml_tensor *> t_sampled;
+    std::map<llama_seq_id, ggml_tensor *> t_sampled_probs;
 
     std::vector<llm_graph_input_ptr> inputs;
 
@@ -616,7 +636,7 @@ struct llm_graph_context {
     const llama_memory_context_i * mctx;
     const llama_cross * cross;
 
-    std::unordered_map<llama_seq_id, llama_sampler *> samplers;
+    std::map<llama_seq_id, llama_sampler *> samplers;
 
     const llm_graph_cb & cb_func;