graph : make the compute graph constant with respect to active samplers
This commit is contained in:
parent 0ecee8be37
commit c02654eb7d
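The gist of the change: the sampling branches of the compute graph are now built for every configured sampler on every ubatch, instead of only for samplers whose sequences appear in the current batch. Samplers without an output row in the batch are routed to a fixed logits row, so the graph topology (and therefore its allocation) stays the same regardless of which samplers are active. A minimal sketch of that routing idea, assuming a 2-D F32 logits tensor; the helper name pick_logits_row is illustrative and not part of the commit:

#include <unordered_map>

#include "ggml.h"
#include "llama.h"

// Every sampler gets a logits view, active or not, so the graph shape never
// changes between ubatches. Inactive samplers fall back to a fixed row.
static ggml_tensor * pick_logits_row(
        ggml_context * ctx,
        ggml_tensor  * logits, // F32 [n_vocab, n_rows]
        const std::unordered_map<llama_seq_id, int32_t> & seq_to_row,
        llama_seq_id seq_id) {
    const auto it      = seq_to_row.find(seq_id);
    const int  row_idx = it != seq_to_row.end() ? it->second : 0;
    return ggml_view_1d(ctx, logits, logits->ne[0], row_idx*logits->nb[1]);
}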
@@ -1241,7 +1241,10 @@ static void copy_tensor_async_ints(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }
 
         const uint32_t row = it->second;
         GGML_ASSERT(row < sampled_size);
@@ -1265,7 +1268,10 @@ static void copy_tensor_async_floats(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }
 
         const uint32_t row = it->second;
         GGML_ASSERT(row < counts.size());
@@ -1293,7 +1299,10 @@ static void copy_tensor_async_candidates(
     for (const auto & [seq_id, tensor] : tensor_map) {
         auto it = seq_to_row.find(seq_id);
-        GGML_ASSERT(it != seq_to_row.end());
+        if (it == seq_to_row.end()) {
+            continue;
+        }
 
         const uint32_t row = it->second;
         GGML_ASSERT(row < counts.size());
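With the graph now carrying a branch per configured sampler, the copy-back helpers above can encounter map entries for sequences that produced no output in the current ubatch, so the hard asserts are relaxed into skips. A small self-contained sketch of that lookup-or-skip pattern (the function name and signature are illustrative):

#include <cstdint>
#include <unordered_map>

#include "ggml.h"
#include "llama.h"

// Visit only the rows that were actually sampled in this ubatch; skip the rest.
// seq_to_row maps a sequence id to its row in the sampled output buffer.
template <typename F>
static void visit_sampled_rows(
        const std::unordered_map<llama_seq_id, ggml_tensor *> & tensor_map,
        const std::unordered_map<llama_seq_id, uint32_t>      & seq_to_row,
        uint32_t n_rows, F && visit) {
    for (const auto & [seq_id, tensor] : tensor_map) {
        const auto it = seq_to_row.find(seq_id);
        if (it == seq_to_row.end()) {
            continue; // sampler is wired into the graph, but its sequence produced no output
        }
        const uint32_t row = it->second;
        if (row < n_rows) {
            visit(row, tensor);
        }
    }
}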
@@ -12,6 +12,7 @@
 #include <cassert>
 #include <cmath>
 #include <cstring>
+#include <unordered_set>
 
 void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     if (ubatch->token) {
@@ -466,8 +467,22 @@ void llm_graph_input_mem_hybrid::set_input(const llama_ubatch * ubatch) {
 }
 
 void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
-    GGML_UNUSED(ubatch);
-    for (const auto & [seq_id, sampler] : samplers) {
+    // set the inputs only for the active samplers in the current ubatch
+    std::unordered_set<llama_seq_id> active_samplers;
+    for (uint32_t i = 0; i < ubatch->n_tokens; i++) {
+        if (ubatch->output[i]) {
+            llama_seq_id seq_id = ubatch->seq_id[i][0];
+            active_samplers.insert(seq_id);
+        }
+    }
+
+    for (auto seq_id : active_samplers) {
+        if (samplers.find(seq_id) == samplers.end()) {
+            continue;
+        }
+
+        auto & sampler = samplers[seq_id];
+
         if (sampler->iface->backend_set_input) {
             sampler->iface->backend_set_input(sampler);
         }
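set_input now derives the active sampler set from the ubatch instead of unconditionally iterating all samplers. A condensed sketch of that derivation, assuming (as the hunk above does) that ubatch->output[i] marks tokens that produce logits and ubatch->seq_id[i][0] is the token's primary sequence:

#include <cstdint>
#include <unordered_set>

#include "llama.h"

// Collect the sequences that actually produce output in this ubatch.
// The parameters mirror the llama_ubatch fields used in the hunk above.
static std::unordered_set<llama_seq_id> collect_active_seqs(
        uint32_t n_tokens, const int8_t * output, llama_seq_id ** seq_id) {
    std::unordered_set<llama_seq_id> active;
    for (uint32_t i = 0; i < n_tokens; i++) {
        if (output[i]) {
            active.insert(seq_id[i][0]); // primary sequence of token i
        }
    }
    return active;
}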
@@ -475,11 +490,10 @@ void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) {
 }
 
 bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) {
-    if (params.samplers.empty()) {
-        return true;
+    if (samplers.size() != params.samplers.size()) {
+        return false;
     }
 
+    // TODO: this check is incorrect - it has to check against the last set of samplers that were used in the previous graph
     for (const auto & [seq_id, sampler] : params.samplers) {
         if (samplers[seq_id] != sampler) {
             return false;
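The reuse check is tightened: instead of treating an empty incoming sampler set as reusable, the stored samplers must now match the incoming ones in count and per-sequence identity; the TODO notes that comparing against the samplers used in the previous graph would be the stricter condition. A compact sketch of such an exact-match comparison (the free-function form is illustrative):

#include <unordered_map>

#include "llama.h"

// True only if both maps hold exactly the same (seq_id -> sampler) pairs.
static bool same_samplers(
        const std::unordered_map<llama_seq_id, llama_sampler *> & a,
        const std::unordered_map<llama_seq_id, llama_sampler *> & b) {
    if (a.size() != b.size()) {
        return false;
    }
    for (const auto & [seq_id, sampler] : b) {
        const auto it = a.find(seq_id);
        if (it == a.end() || it->second != sampler) {
            return false;
        }
    }
    return true;
}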
@@ -1830,8 +1844,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
 
         inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask);
+        ggml_set_name(inp->self_kq_mask, "self_kq_mask");
 
         inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
+        ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv");
     }
 
     {
@@ -1844,8 +1860,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
 
         inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens/n_stream, GGML_KQ_MASK_PAD), 1, n_stream);
         ggml_set_input(inp->self_kq_mask_swa);
+        ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
 
         inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa;
+        ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv");
     }
 
     return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp));
@@ -2084,6 +2102,9 @@ void llm_graph_context::build_sampling() const {
         return;
     }
 
+    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(samplers);
+    res->add_input(std::move(inp_sampling));
+
     std::unordered_map<llama_seq_id, int32_t> seq_to_logit_row;
     int32_t logit_row_idx = 0;
@@ -2095,30 +2116,21 @@ void llm_graph_context::build_sampling() const {
         }
     }
 
-    if (seq_to_logit_row.empty()) {
-        return;
-    }
-
     // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1)
-    ggml_tensor * logits_t = res->t_logits;
+    GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor");
 
-    const int64_t n_vocab = logits_t->ne[0];
-
-    std::unordered_map<llama_seq_id, llama_sampler *> active_samplers;
+    // add a dummy row of logits
+    // this trick makes the graph static, regardless of which samplers are activated
+    // this is important in order to minimize graph reallocations
+    ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0);
 
     for (const auto & [seq_id, sampler] : samplers) {
-        // Only process samplers for sequences that are in the current batch
-        auto it = seq_to_logit_row.find(seq_id);
-        if (it == seq_to_logit_row.end()) {
-            continue;
-        }
+        const auto it = seq_to_logit_row.find(seq_id);
 
-        active_samplers[seq_id] = sampler;
+        // inactive samplers always work on the first row
+        const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0;
 
-        const int32_t row_idx = it->second;
-
-        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, n_vocab, row_idx * logits_t->nb[1]);
+        ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]);
         ggml_format_name(logits_seq, "logits_seq_%d", seq_id);
 
         struct llama_sampler_data data = {
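The ggml_pad call above is the core of the trick: the logits tensor gets one extra (dummy) row, and every sampler branch is always built, with inactive samplers reading from a fixed row, so the graph does not change shape when the active set changes. A self-contained sketch of the shape effect of that pad, with toy sizes chosen only for illustration:

#include <cstdio>

#include "ggml.h"

int main(void) {
    // Toy context: 4 "vocab" entries, 3 logit rows.
    ggml_init_params params = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ nullptr, /*no_alloc*/ false };
    ggml_context * ctx = ggml_init(params);

    ggml_tensor * logits = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 3);

    // Pad dim 1 by one row -> shape becomes [4, 4]; graph nodes built on top of
    // this tensor keep the same shape regardless of how many rows are used.
    ggml_tensor * padded = ggml_pad(ctx, logits, 0, 1, 0, 0);

    // A per-sequence branch is just a 1-D view of one row of the padded tensor.
    ggml_tensor * row0 = ggml_view_1d(ctx, padded, padded->ne[0], 0*padded->nb[1]);

    printf("padded: %lld x %lld, row view: %lld\n",
            (long long) padded->ne[0], (long long) padded->ne[1], (long long) row0->ne[0]);

    ggml_free(ctx);
    return 0;
}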
@@ -2163,9 +2175,6 @@ void llm_graph_context::build_sampling() const {
         }
     }
     */
-
-    auto inp_sampling = std::make_unique<llm_graph_input_sampling>(n_vocab, false, active_samplers);
-    res->add_input(std::move(inp_sampling));
 }
 
 int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) {
@@ -385,19 +385,13 @@ public:
 
 class llm_graph_input_sampling : public llm_graph_input_i {
 public:
-    llm_graph_input_sampling(int32_t n_vocab, bool sorted,
-            std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
-        n_vocab(n_vocab), sorted_value(sorted), samplers(std::move(samplers)) { }
+    llm_graph_input_sampling(std::unordered_map<llama_seq_id, llama_sampler *> samplers) :
+        samplers(std::move(samplers)) { }
     virtual ~llm_graph_input_sampling() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
     bool can_reuse(const llm_graph_params & params) override;
 
-    int32_t n_vocab;
-    bool sorted_value;
-    ggml_tensor * size = nullptr; // I32 [1]
-    ggml_tensor * sorted = nullptr; // I32 [1]
-
     std::unordered_map<llama_seq_id, llama_sampler *> samplers;
 };