graph : fix KQ mask, lora, cvec reuse checks (#19644)
* graph : fix KQ mask reuse condition * cont : dedup KQ mask build and can_reuse * cont : fix build * graph : fix adapter check for reuse
This commit is contained in:
parent
267ba5a1d9
commit
d5dfc33027
|
|
@ -39,6 +39,8 @@ private:
|
|||
std::vector<ggml_tensor *> tensors; // per layer
|
||||
};
|
||||
|
||||
using llama_adapter_cvec_ptr = std::shared_ptr<llama_adapter_cvec>;
|
||||
|
||||
//
|
||||
// llama_adapter_lora
|
||||
//
|
||||
|
|
@ -84,3 +86,4 @@ struct llama_adapter_lora {
|
|||
};
|
||||
|
||||
using llama_adapter_loras = std::unordered_map<llama_adapter_lora *, float>;
|
||||
using llama_adapter_loras_ptr = std::unique_ptr<llama_adapter_loras>;
|
||||
|
|
|
|||
|
|
@ -22,6 +22,8 @@ llama_context::llama_context(
|
|||
const llama_model & model,
|
||||
llama_context_params params) :
|
||||
model(model),
|
||||
cvec(std::make_unique<llama_adapter_cvec>()),
|
||||
loras(std::make_unique<llama_adapter_loras>()),
|
||||
balloc(std::make_unique<llama_batch_allocr>(model.hparams.n_pos_per_embd())) {
|
||||
// TODO warning when creating llama_context with awkward ctx size that is not a power of 2,
|
||||
// may need to be backend-dependent
|
||||
|
|
@ -1065,11 +1067,11 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
|
|||
return;
|
||||
}
|
||||
|
||||
loras.clear();
|
||||
loras.reset(new llama_adapter_loras());
|
||||
|
||||
for (size_t i = 0; i < n_adapters; i ++) {
|
||||
if (scales[i] != 0.0f) {
|
||||
loras[adapters[i]] = scales[i];
|
||||
loras->insert({adapters[i], scales[i]});
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1079,14 +1081,14 @@ void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_a
|
|||
bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) {
|
||||
LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters);
|
||||
|
||||
if (n_adapters != loras.size()) {
|
||||
if (n_adapters != loras->size()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < n_adapters; i ++) {
|
||||
auto it = loras.find(adapters[i]);
|
||||
auto it = loras->find(adapters[i]);
|
||||
|
||||
if (it == loras.end() || it->second != scales[i]) {
|
||||
if (it == loras->end() || it->second != scales[i]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
|
@ -1104,7 +1106,7 @@ bool llama_context::set_adapter_cvec(
|
|||
|
||||
// TODO: should we reserve?
|
||||
|
||||
return cvec.apply(model, data, len, n_embd, il_start, il_end);
|
||||
return cvec->apply(model, data, len, n_embd, il_start, il_end);
|
||||
}
|
||||
|
||||
llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
|
||||
|
|
@ -2081,8 +2083,8 @@ llm_graph_params llama_context::graph_params(
|
|||
/*.gtype =*/ gtype,
|
||||
/*.sched =*/ sched.get(),
|
||||
/*.backend_cpu =*/ backend_cpu,
|
||||
/*.cvec =*/ &cvec,
|
||||
/*.loras =*/ &loras,
|
||||
/*.cvec =*/ cvec.get(),
|
||||
/*.loras =*/ loras.get(),
|
||||
/*.mctx =*/ mctx,
|
||||
/*.cross =*/ &cross,
|
||||
/*.samplers =*/ sampling.samplers,
|
||||
|
|
|
|||
|
|
@ -256,9 +256,10 @@ private:
|
|||
|
||||
const llama_model & model;
|
||||
|
||||
llama_cparams cparams;
|
||||
llama_adapter_cvec cvec;
|
||||
llama_adapter_loras loras;
|
||||
llama_cparams cparams;
|
||||
|
||||
llama_adapter_cvec_ptr cvec;
|
||||
llama_adapter_loras_ptr loras;
|
||||
|
||||
llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,41 @@
|
|||
#include <sstream>
|
||||
#include <unordered_set>
|
||||
|
||||
// dedup helpers
|
||||
|
||||
static ggml_tensor * build_kq_mask(
|
||||
ggml_context * ctx,
|
||||
const llama_kv_cache_context * mctx,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_cparams & cparams) {
|
||||
const auto n_kv = mctx->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
}
|
||||
|
||||
static bool can_reuse_kq_mask(
|
||||
ggml_tensor * kq_mask,
|
||||
const llama_kv_cache_context * mctx,
|
||||
const llama_ubatch & ubatch,
|
||||
const llama_cparams & cparams) {
|
||||
const auto n_kv = mctx->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
bool res = true;
|
||||
|
||||
res &= (kq_mask->ne[0] == n_kv);
|
||||
res &= (kq_mask->ne[1] == n_tokens/n_stream);
|
||||
res &= (kq_mask->ne[2] == 1);
|
||||
res &= (kq_mask->ne[3] == n_stream);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
// impl
|
||||
|
||||
void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
|
||||
if (ubatch->token) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
|
@ -403,8 +438,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) {
|
|||
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
@ -424,8 +458,7 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) {
|
|||
|
||||
res &= self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
@ -455,11 +488,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) {
|
|||
res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv();
|
||||
res &= self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
|
||||
res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv();
|
||||
res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams);
|
||||
res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams);
|
||||
|
||||
return res;
|
||||
}
|
||||
|
|
@ -521,8 +551,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) {
|
|||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
|
|
@ -565,8 +594,7 @@ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) {
|
|||
|
||||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams);
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
||||
|
|
@ -625,8 +653,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
|
|||
res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
// swa tensors may not be allocated if there are no SWA attention layers
|
||||
|
|
@ -634,8 +661,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params)
|
|||
res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens;
|
||||
//res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there
|
||||
|
||||
res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv();
|
||||
res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens;
|
||||
res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams);
|
||||
}
|
||||
|
||||
res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs();
|
||||
|
|
@ -1891,14 +1917,11 @@ static std::unique_ptr<llm_graph_input_attn_kv> build_attn_inp_kv_impl(
|
|||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
|
@ -1983,13 +2006,9 @@ static std::unique_ptr<llm_graph_input_attn_k> build_attn_inp_k_impl(
|
|||
{
|
||||
GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_n_kv();
|
||||
const auto n_tokens = ubatch.n_tokens;
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
|
||||
inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask;
|
||||
|
|
@ -2188,15 +2207,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
|
||||
auto inp = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, mctx_cur);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = mctx_cur->get_base()->get_n_kv();
|
||||
|
||||
inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask);
|
||||
ggml_set_name(inp->self_kq_mask, "self_kq_mask");
|
||||
|
||||
|
|
@ -2207,12 +2222,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const
|
|||
{
|
||||
GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA");
|
||||
|
||||
const auto n_kv = mctx_cur->get_swa()->get_n_kv();
|
||||
|
||||
inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams);
|
||||
ggml_set_input(inp->self_kq_mask_swa);
|
||||
ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa");
|
||||
|
||||
|
|
@ -2374,27 +2387,21 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa()
|
|||
|
||||
auto inp_attn = std::make_unique<llm_graph_input_attn_kv_iswa>(hparams, cparams, attn_ctx);
|
||||
|
||||
const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq;
|
||||
|
||||
{
|
||||
const auto n_kv = attn_ctx->get_base()->get_n_kv();
|
||||
|
||||
inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams);
|
||||
ggml_set_input(inp_attn->self_kq_mask);
|
||||
|
||||
inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask;
|
||||
}
|
||||
|
||||
{
|
||||
const auto n_kv = attn_ctx->get_swa()->get_n_kv();
|
||||
|
||||
inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch);
|
||||
inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch);
|
||||
|
||||
inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream);
|
||||
inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams);
|
||||
ggml_set_input(inp_attn->self_kq_mask_swa);
|
||||
|
||||
inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa;
|
||||
|
|
|
|||
Loading…
Reference in New Issue