diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index 34dfcd4724..2bce5c3485 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -127,6 +127,7 @@ llama_context::llama_context(
     }
 
     cparams.flash_attn = params.flash_attn_type != LLAMA_FLASH_ATTN_TYPE_DISABLED;
+    cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -136,6 +137,9 @@ llama_context::llama_context(
     cparams.op_offload = params.op_offload;
     cparams.kv_unified = params.kv_unified;
 
+    // initialized later
+    cparams.pipeline_parallel = false;
+
     {
         const char * LLAMA_GRAPH_REUSE_DISABLE = getenv("LLAMA_GRAPH_REUSE_DISABLE");
         graph_reuse_disable = LLAMA_GRAPH_REUSE_DISABLE ? (atoi(LLAMA_GRAPH_REUSE_DISABLE) != 0) : graph_reuse_disable;
@@ -280,16 +284,6 @@ llama_context::llama_context(
 
         LLAMA_LOG_DEBUG("%s: backend_ptrs.size() = %zu\n", __func__, backend_ptrs.size());
 
-        const uint32_t n_seqs = cparams.n_seq_max;
-        const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
-
-        const size_t max_nodes = this->graph_max_nodes(n_tokens);
-
-        LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
-
-        gf_res_prev.reset(new llm_graph_result(max_nodes));
-        gf_res_reserve.reset(new llm_graph_result(max_nodes));
-
         // TODO: move these checks to ggml_backend_sched
         // enabling pipeline parallelism in the scheduler increases memory usage, so it is only done when necessary
         bool pipeline_parallel =
@@ -318,143 +312,19 @@ llama_context::llama_context(
             }
         }
 
-        sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, pipeline_parallel, cparams.op_offload));
+        cparams.pipeline_parallel = pipeline_parallel;
 
-        if (pipeline_parallel) {
+        if (cparams.pipeline_parallel) {
             LLAMA_LOG_INFO("%s: pipeline parallelism enabled (n_copies=%d)\n", __func__, ggml_backend_sched_get_n_copies(sched.get()));
         }
 
-        llama_memory_context_ptr mctx;
-        if (memory) {
-            LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
-            mctx = memory->init_full();
-            if (!mctx) {
-                throw std::runtime_error("failed to initialize memory module");
+        reserve();
+
+        if (!cparams.flash_attn) {
+            if (ggml_is_quantized(params.type_v)) {
+                throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
             }
         }
-
-        cross.v_embd.clear();
-
-        // avoid reserving graphs with zero outputs - assume one output per sequence
-        n_outputs = n_seqs;
-
-        LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
-
-        // resolve automatic Flash Attention use
-        if (params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO) {
-            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
-            if (!gf) {
-                throw std::runtime_error("failed to split graph for Flash Attention check");
-            }
-
-            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
-            bool fa_device_mismatch = false;
-            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-                ggml_tensor * n = ggml_graph_node(gf, i);
-                if (n->op != GGML_OP_FLASH_ATTN_EXT) {
-                    continue;
-                }
-                ggml_backend_dev_t device_fa = ggml_backend_get_device(
-                    ggml_backend_sched_get_tensor_backend(sched.get(), n));
-
-                // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
-                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
-                const int il = std::stoi(n->name + prefix_len);
-                ggml_backend_dev_t device_kv = model.dev_layer(il);
-                if (device_fa != device_kv) {
-                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
-                                   "is assigned to device %s (usually due to missing support)\n",
-                                   __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
-                    // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
-                    fa_device_mismatch = true;
-                    break;
-                }
-            }
-            if (fa_device_mismatch) {
-                cparams.flash_attn = false;
-                LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
-                if (ggml_is_quantized(params.type_v)) {
-                    throw std::runtime_error("quantized V cache was requested, but this requires Flash Attention");
-                }
-            } else {
-                cparams.flash_attn = true;
-                LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
-            }
-        }
-
-        // reserve worst-case graph
-        int n_splits_pp = -1;
-        int n_nodes_pp = -1;
-
-        int n_splits_tg = -1;
-        int n_nodes_tg = -1;
-
-        // reserve pp (prompt processing) graph first so that buffers are only allocated once
-        {
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
-                model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
-            if (!gf) {
-                if (pipeline_parallel) {
-                    LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
-                    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
-                    gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
-                }
-                if (!gf) {
-                    throw std::runtime_error("failed to allocate compute pp buffers");
-                }
-            }
-
-            n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
-            n_nodes_pp = ggml_graph_n_nodes(gf);
-        }
-
-        // reserve with tg (token generation) graph to get the number of splits and nodes
-        {
-            auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
-            if (!gf) {
-                throw std::runtime_error("failed to allocate compute tg buffers");
-            }
-
-            n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
-            n_nodes_tg = ggml_graph_n_nodes(gf);
-        }
-
-        // reserve again with pp graph to avoid ggml-alloc reallocations during inference
-        {
-            // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
-            //
-            // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
-            //
-            auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
-            if (!gf) {
-                throw std::runtime_error("failed to allocate compute pp buffers");
-            }
-        }
-
-        for (size_t i = 0; i < backend_ptrs.size(); ++i) {
-            ggml_backend_t backend = backend_ptrs[i];
-            ggml_backend_buffer_type_t buft = backend_buft[i];
-            if (!model.hparams.no_alloc) {
-                backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
-            }
-            if (backend_buf_exp_size[i] > 1) {
-                LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
-                        ggml_backend_buft_name(buft),
-                        backend_buf_exp_size[i] / 1024.0 / 1024.0);
-            }
-        }
-
-        if (n_nodes_pp == n_nodes_tg) {
-            LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
-        } else {
-            LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
-        }
-
-        if (n_splits_pp == n_splits_tg) {
-            LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
-        } else {
-            LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
-        }
     }
 }
@@ -478,6 +348,154 @@ llama_context::~llama_context() {
     ggml_opt_free(opt_ctx);
 }
 
+void llama_context::reserve() {
+    LLAMA_LOG_INFO("%s: reserving ...\n", __func__);
+
+    const uint32_t n_seqs = cparams.n_seq_max;
+    const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
+
+    const size_t max_nodes = this->graph_max_nodes(n_tokens);
+
+    LLAMA_LOG_DEBUG("%s: max_nodes = %zu\n", __func__, max_nodes);
+
+    gf_res_prev.reset(new llm_graph_result(max_nodes));
+    gf_res_reserve.reset(new llm_graph_result(max_nodes));
+
+    sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, cparams.pipeline_parallel, cparams.op_offload));
+
+    llama_memory_context_ptr mctx;
+    if (memory) {
+        LLAMA_LOG_DEBUG("%s: reserving full memory module\n", __func__);
+        mctx = memory->init_full();
+        if (!mctx) {
+            throw std::runtime_error("failed to initialize memory module");
+        }
+    }
+
+    cross.v_embd.clear();
+
+    // avoid reserving graphs with zero outputs - assume one output per sequence
+    n_outputs = n_seqs;
+
+    LLAMA_LOG_DEBUG("%s: worst-case: n_tokens = %d, n_seqs = %d, n_outputs = %d\n", __func__, n_tokens, n_seqs, n_outputs);
+
+    // resolve automatic Flash Attention use
+    if (cparams.auto_fa) {
+        auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
+        if (!gf) {
+            throw std::runtime_error("failed to split graph for Flash Attention check");
+        }
+
+        const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FATTN) + 1;
+        bool fa_device_mismatch = false;
+        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+            ggml_tensor * n = ggml_graph_node(gf, i);
+            if (n->op != GGML_OP_FLASH_ATTN_EXT) {
+                continue;
+            }
+            ggml_backend_dev_t device_fa = ggml_backend_get_device(
+                ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+            // TODO: instead of the tensor names, use a map to keep track of which (FA) tensors belong to which layer
+            GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FATTN "-", prefix_len) == 0);
+            const int il = std::stoi(n->name + prefix_len);
+            ggml_backend_dev_t device_kv = model.dev_layer(il);
+            if (device_fa != device_kv) {
+                LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the Flash Attention tensor "
+                               "is assigned to device %s (usually due to missing support)\n",
+                               __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_fa));
+                // FIXME: fa_device_mismatch logic is wrong for --no-kv-offload, but this is broken anyways
+                fa_device_mismatch = true;
+                break;
+            }
+        }
+        if (fa_device_mismatch) {
+            cparams.flash_attn = false;
+            LLAMA_LOG_WARN("%s: Flash Attention was auto, set to disabled\n", __func__);
+        } else {
+            cparams.flash_attn = true;
+            LLAMA_LOG_INFO("%s: Flash Attention was auto, set to enabled\n", __func__);
+        }
+
+        cparams.auto_fa = false;
+    }
+
+    // reserve worst-case graph
+    int n_splits_pp = -1;
+    int n_nodes_pp = -1;
+
+    int n_splits_tg = -1;
+    int n_nodes_tg = -1;
+
+    // reserve pp (prompt processing) graph first so that buffers are only allocated once
+    {
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(),
+            model.hparams.no_alloc, model.hparams.no_alloc ? backend_buf_exp_size.data() : nullptr);
+        if (!gf) {
+            if (cparams.pipeline_parallel) {
+                LLAMA_LOG_WARN("%s: compute buffer allocation failed, retrying without pipeline parallelism\n", __func__);
+                cparams.pipeline_parallel = false;
+                sched.reset(ggml_backend_sched_new(backend_ptrs.data(), backend_buft.data(), backend_ptrs.size(), max_nodes, false, cparams.op_offload));
+                gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
+            }
+            if (!gf) {
+                throw std::runtime_error("failed to allocate compute pp buffers");
+            }
+        }
+
+        n_splits_pp = ggml_backend_sched_get_n_splits(sched.get());
+        n_nodes_pp = ggml_graph_n_nodes(gf);
+    }
+
+    // reserve with tg (token generation) graph to get the number of splits and nodes
+    {
+        auto * gf = graph_reserve(n_seqs, n_seqs, n_seqs, mctx.get(), model.hparams.no_alloc);
+        if (!gf) {
+            throw std::runtime_error("failed to allocate compute tg buffers");
+        }
+
+        n_splits_tg = ggml_backend_sched_get_n_splits(sched.get());
+        n_nodes_tg = ggml_graph_n_nodes(gf);
+    }
+
+    // reserve again with pp graph to avoid ggml-alloc reallocations during inference
+    {
+        // TODO: not sure if the following graph would be worster case for multi-stream KV caches:
+        //
+        // auto * gf = graph_reserve(n_tokens, 1, n_tokens, mctx.get());
+        //
+        auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get(), model.hparams.no_alloc);
+        if (!gf) {
+            throw std::runtime_error("failed to allocate compute pp buffers");
+        }
+    }
+
+    for (size_t i = 0; i < backend_ptrs.size(); ++i) {
+        ggml_backend_t backend = backend_ptrs[i];
+        ggml_backend_buffer_type_t buft = backend_buft[i];
+        if (!model.hparams.no_alloc) {
+            backend_buf_exp_size[i] = ggml_backend_sched_get_buffer_size(sched.get(), backend);
+        }
+        if (backend_buf_exp_size[i] > 1) {
+            LLAMA_LOG_INFO("%s: %10s compute buffer size = %8.2f MiB\n", __func__,
+                    ggml_backend_buft_name(buft),
+                    backend_buf_exp_size[i] / 1024.0 / 1024.0);
+        }
+    }
+
+    if (n_nodes_pp == n_nodes_tg) {
+        LLAMA_LOG_INFO("%s: graph nodes = %d\n", __func__, n_nodes_pp);
+    } else {
+        LLAMA_LOG_INFO("%s: graph nodes = %d (with bs=%d), %d (with bs=1)\n", __func__, n_nodes_pp, n_tokens, n_nodes_tg);
+    }
+
+    if (n_splits_pp == n_splits_tg) {
+        LLAMA_LOG_INFO("%s: graph splits = %d\n", __func__, n_splits_pp);
+    } else {
+        LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg);
+    }
+}
+
 void llama_context::synchronize() {
     ggml_backend_sched_synchronize(sched.get());
 
@@ -753,12 +771,17 @@ void llama_context::set_embeddings(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
     cparams.embeddings = value;
+
+    // TODO: not sure yet if we want to reserve here
+    //reserve();
 }
 
 void llama_context::set_causal_attn(bool value) {
     LLAMA_LOG_DEBUG("%s: value = %d\n", __func__, value);
 
     cparams.causal_attn = value;
+
+    reserve();
 }
 
 void llama_context::set_warmup(bool value) {
@@ -773,6 +796,8 @@ void llama_context::set_adapter_lora(
     LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale);
 
     loras[adapter] = scale;
+
+    reserve();
 }
 
 bool llama_context::rm_adapter_lora(
@@ -782,6 +807,9 @@ bool llama_context::rm_adapter_lora(
     auto pos = loras.find(adapter);
     if (pos != loras.end()) {
         loras.erase(pos);
+
+        reserve();
+
         return true;
     }
 
@@ -792,6 +820,8 @@ void llama_context::clear_adapter_lora() {
     LLAMA_LOG_DEBUG("%s: call\n", __func__);
 
     loras.clear();
+
+    reserve();
 }
 
 bool llama_context::apply_adapter_cvec(
@@ -802,7 +832,13 @@ bool llama_context::apply_adapter_cvec(
         int32_t il_end) {
     LLAMA_LOG_DEBUG("%s: il_start = %d, il_end = %d\n", __func__, il_start, il_end);
 
-    return cvec.apply(model, data, len, n_embd, il_start, il_end);
+    bool res = cvec.apply(model, data, len, n_embd, il_start, il_end);
+
+    if (res) {
+        reserve();
+    }
+
+    return res;
 }
 
 llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) {
diff --git a/src/llama-context.h b/src/llama-context.h
index c31101330e..10be4ebee6 100644
--- a/src/llama-context.h
+++ b/src/llama-context.h
@@ -40,6 +40,7 @@ struct llama_context {
 
     ~llama_context();
 
+    void reserve();
     void synchronize();
 
     const llama_model & get_model() const;
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index fcef8fa976..2da3bbd6f9 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -30,10 +30,12 @@ struct llama_cparams {
     bool causal_attn;
     bool offload_kqv;
     bool flash_attn;
+    bool auto_fa;
    bool no_perf;
    bool warmup;
    bool op_offload;
    bool kv_unified;
+    bool pipeline_parallel;
 
     enum llama_pooling_type pooling_type;
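For orientation, the patch above moves the worst-case graph reservation out of the constructor into `llama_context::reserve()` and re-runs it from the setters that change the shape of the compute graph (causal attention, LoRA adapters, control vectors). Below is a minimal standalone sketch of that pattern only; `toy_context` and everything in it (member names, the buffer-sizing formula) is illustrative and does not mirror llama.cpp's real types, scheduler, or sizing logic.

```cpp
// Minimal sketch of the "setters re-run reserve()" pattern, assuming a toy context type.
// This is not llama.cpp code; all names below are hypothetical.
#include <cstdio>
#include <vector>

struct toy_context {
    bool causal_attn = true;
    int  n_lora      = 0;
    std::vector<char> compute_buf; // stands in for the scheduler's compute buffers

    // re-plan the worst-case graph and (re)allocate buffers to match the current config
    void reserve() {
        // pretend the worst case grows with the number of adapters and the attention mode
        const size_t worst_case = 1024 * (1 + n_lora) * (causal_attn ? 1 : 2);
        compute_buf.assign(worst_case, 0);
        std::printf("reserve: compute buffer = %zu bytes\n", compute_buf.size());
    }

    // setters that change the graph shape re-reserve, mirroring the calls the patch
    // adds to set_causal_attn(), set_adapter_lora(), and apply_adapter_cvec()
    void set_causal_attn(bool value) { causal_attn = value; reserve(); }
    void add_adapter()               { n_lora++;            reserve(); }
};

int main() {
    toy_context ctx;
    ctx.reserve();              // constructor path: initial worst-case reservation
    ctx.set_causal_attn(false); // changing the attention mode re-reserves
    ctx.add_adapter();          // adding an adapter re-reserves
}
```

The trade-off the sketch illustrates is the one the TODO in `set_embeddings()` hints at: re-reserving on every setter keeps the buffers sized for the current worst case, but each call pays for scheduler re-creation and graph re-planning.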