cuda : fix nkvo, offload and cuda graph node properties matching (#19165)
* cuda : fix nkvo
* cont : more robust cuda graph node property matching
* cont : restore pre-leafs implementation
* cont : comments + static_assert
commit 4fdbc1e4db (parent 7b7ae857f6)
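The core idea behind the property-matching change: keep a trivially-copyable snapshot of each graph node, zero it with memset() so unused source slots compare equal, fill it field by field, and compare it against the node on the next evaluation. Below is a minimal standalone sketch of that pattern; node_snapshot, fake_node and the helper names are simplified stand-ins invented for this illustration, not the real ggml structs.

#include <cstdint>
#include <cstring>
#include <type_traits>

// Simplified stand-in for ggml_cuda_graph_node_properties (illustrative only).
struct node_snapshot {
    void *  data;        // where the node's result lives
    int32_t op;          // operation id
    int64_t ne[4];       // tensor shape
    void *  src_data[4]; // data pointers of the source tensors
};

// memset() below is only well-defined because the type is trivial.
static_assert(std::is_trivial<node_snapshot>::value, "node_snapshot must be trivial");

// Hypothetical minimal "node" type used only for this sketch.
struct fake_node {
    void *      data;
    int32_t     op;
    int64_t     ne[4];
    fake_node * src[4];
};

static void snapshot_set(node_snapshot * s, const fake_node * n) {
    // Zero everything first so that unused src slots compare equal later.
    std::memset(s, 0, sizeof(*s));
    s->data = n->data;
    s->op   = n->op;
    for (int i = 0; i < 4; i++) {
        s->ne[i] = n->ne[i];
        if (n->src[i]) {
            s->src_data[i] = n->src[i]->data;
        }
    }
}

static bool snapshot_match(const fake_node * n, const node_snapshot * s) {
    if (n->data != s->data || n->op != s->op) {
        return false;
    }
    for (int i = 0; i < 4; i++) {
        if (n->ne[i] != s->ne[i]) {
            return false;
        }
        // A missing source must correspond to a zeroed slot in the snapshot.
        void * src_data = n->src[i] ? n->src[i]->data : nullptr;
        if (src_data != s->src_data[i]) {
            return false;
        }
    }
    return true;
}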
@@ -1122,15 +1122,17 @@ struct ggml_tensor_extra_gpu {
 #endif
 
 struct ggml_cuda_graph_node_properties {
-    void * node_address;
+    void * node_data;
     ggml_op node_op;
     int32_t flags;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
-    void * src_address[GGML_MAX_SRC];
+    void * src_data[GGML_MAX_SRC];
     int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
 };
 
+static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
+
 struct ggml_cuda_graph {
 #ifdef USE_CUDA_GRAPH
     ~ggml_cuda_graph() {
@@ -1150,6 +1152,12 @@ struct ggml_cuda_graph {
     int number_consecutive_updates = 0;
     std::vector<ggml_cuda_graph_node_properties> props;
 
+    // these are extra tensors (inputs) that participate in the ggml graph but are not nodes
+    // their properties also have to match in order to be able to safely reuse a CUDA graph
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18583
+    // ref: https://github.com/ggml-org/llama.cpp/pull/19165
+    std::vector<ggml_cuda_graph_node_properties> extra;
+
     void record_update(bool use_graph, bool update_required) {
         if (use_graph && update_required) {
            number_consecutive_updates++;
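The new extra vector above tracks tensors that feed the graph but are not themselves nodes (graph inputs); their data pointers must also be unchanged for a captured CUDA graph to be replayed safely. A rough sketch of one way to discover such inputs, using simplified toy types rather than the real ggml_cgraph/ggml_tensor; the patch itself does this inline in ggml_cuda_graph_update_required further below.

#include <unordered_set>
#include <vector>

// Simplified graph shape for illustration; ggml's real types differ.
struct toy_tensor {
    toy_tensor * src[4] = {nullptr, nullptr, nullptr, nullptr};
};

struct toy_graph {
    std::vector<toy_tensor *> nodes; // tensors computed by the graph
};

// Collect every source tensor that is not itself a node: these are the
// graph's external inputs, whose data pointers must also stay stable
// for a previously captured CUDA graph to be replayed safely.
static std::vector<toy_tensor *> collect_extra_inputs(const toy_graph & g) {
    std::unordered_set<toy_tensor *> is_node(g.nodes.begin(), g.nodes.end());
    std::unordered_set<toy_tensor *> seen_extra;
    std::vector<toy_tensor *> extra;
    for (toy_tensor * node : g.nodes) {
        for (toy_tensor * src : node->src) {
            if (src && !is_node.count(src) && seen_extra.insert(src).second) {
                extra.push_back(src);
            }
        }
    }
    return extra;
}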
@@ -310,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
         }
     }
 
-    const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
-
     const int cc = ggml_cuda_info().devices[device].cc;
 
     switch (K->ne[0]) {
@@ -334,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
             if (!gqa_opt_applies) {
                 return BEST_FATTN_KERNEL_NONE;
             }
-            if (!V_is_K_view) {
-                return BEST_FATTN_KERNEL_NONE;
-            }
             break;
         default:
             return BEST_FATTN_KERNEL_NONE;
@@ -70,17 +70,18 @@
 #include <condition_variable>
 #include <cstddef>
 #include <cstdint>
-#include <float.h>
+#include <cfloat>
 #include <initializer_list>
 #include <limits>
 #include <map>
 #include <memory>
 #include <mutex>
-#include <stdarg.h>
-#include <stdio.h>
-#include <stdlib.h>
+#include <cstdarg>
+#include <cstdio>
+#include <cstdlib>
 #include <string>
 #include <vector>
+#include <unordered_set>
 
 static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
@@ -2916,7 +2917,8 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 }
 
 static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
-    props->node_address = node->data;
+    memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
+    props->node_data = node->data;
     props->node_op = node->op;
     props->flags = node->flags;
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
@@ -2924,14 +2926,17 @@ static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties
         props->nb[i] = node->nb[i];
     }
     for (int i = 0; i < GGML_MAX_SRC; i++) {
-        props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
+        if (!node->src[i]) {
+            continue;
+        }
+
+        props->src_data[i] = node->src[i]->data;
     }
     memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
 }
 
 static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
-    if (node->data != props->node_address &&
-        node->op != GGML_OP_VIEW) {
+    if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
         return false;
     }
@@ -2948,12 +2953,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
         }
     }
 
-    for (int i = 0; i < GGML_MAX_SRC; i++) {
-        if (node->src[i] &&
-            node->src[i]->data != props->src_address[i] &&
-            node->op != GGML_OP_VIEW
-        ) {
-            return false;
+    if (node->op != GGML_OP_VIEW) {
+        for (int i = 0; i < GGML_MAX_SRC; i++) {
+            if (!node->src[i]) {
+                if (props->src_data[i] != nullptr) {
+                    return false;
+                }
+                continue;
+            }
+
+            if (node->src[i]->data != props->src_data[i]) {
+                return false;
+            }
         }
     }
@@ -2974,7 +2985,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
 }
 
 static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
-
     bool res = false;
 
     const void * graph_key = ggml_cuda_graph_get_key(cgraph);
@@ -2985,15 +2995,20 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
     }
 
     // Check if the graph size has changed
-    if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) {
+    if (graph->props.size() != (size_t)cgraph->n_nodes) {
         res = true;
-        graph->props.resize(cgraph->n_nodes + cgraph->n_leafs);
+        graph->props.resize(cgraph->n_nodes);
     }
 
     // Loop over nodes in GGML graph to determine if CUDA graph update is required
     // and store properties to allow this comparison for the next token
+    std::unordered_set<ggml_tensor *> seen_node;
+    std::vector<ggml_tensor *> srcs_extra;
     for (int i = 0; i < cgraph->n_nodes; i++) {
         bool props_match = true;
+
+        seen_node.insert(cgraph->nodes[i]);
+
         if (!res) {
             props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
         }
@@ -3001,17 +3016,31 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
             res = true;
         }
         ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
+
+        for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
+            ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
+            if (src && seen_node.find(src) == seen_node.end()) {
+                srcs_extra.push_back(src);
+            }
+        }
     }
 
-    for (int i = 0; i < cgraph->n_leafs; i++) {
-        bool props_match = true;
-        if (!res) {
-            props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]);
-        }
-        if (!props_match) {
-            res = true;
-        }
-        ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
+    if (graph->extra.size() != (size_t) srcs_extra.size()) {
+        res = true;
+        graph->extra.resize(srcs_extra.size());
+    }
+
+    for (size_t i = 0; i < srcs_extra.size(); ++i) {
+        bool props_match = true;
+
+        if (!res) {
+            props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
+        }
+
+        if (!props_match) {
+            res = true;
+        }
+        ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
     }
 
     return res;
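For context on how the outcome of this check is used: when every node property and every extra-input property matches, the previously instantiated CUDA graph can simply be replayed; otherwise it has to be rebuilt before it can be launched again. The sketch below is illustrative only: cudaGraphLaunch() is the real CUDA runtime call, while eval_with_optional_graph_reuse and recapture_graph are hypothetical stand-ins, not functions from this patch.

#include <cuda_runtime.h>

// Hypothetical stand-in: in a real backend this would re-capture the ggml graph
// into a new CUDA graph; here it only drops the stale executable graph so the
// sketch stays self-contained.
static void recapture_graph(cudaStream_t /*stream*/, cudaGraphExec_t & graph_exec) {
    graph_exec = nullptr;
}

// Gate CUDA graph replay on the result of the property check.
static void eval_with_optional_graph_reuse(cudaStream_t stream,
                                           cudaGraphExec_t & graph_exec,
                                           bool update_required) {
    if (!update_required && graph_exec != nullptr) {
        // Node and extra-input properties all matched: safe to replay.
        cudaGraphLaunch(graph_exec, stream);
        return;
    }
    // Something changed (addresses, shapes, ops): rebuild before the next replay.
    recapture_graph(stream, graph_exec);
}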
@@ -1630,11 +1630,6 @@ ggml_tensor * llm_graph_context::build_attn_mha(
                 hparams.attn_soft_cap ? hparams.f_attn_logit_softcapping : 0.0f);
         cb(cur, LLAMA_TENSOR_NAME_FATTN, il);
 
-        if (!cparams.offload_kqv) {
-            // all nodes between the KV store and the attention output are run on the CPU
-            ggml_backend_sched_set_tensor_backend(sched, cur, backend_cpu);
-        }
-
         ggml_flash_attn_ext_add_sinks(cur, sinks);
         ggml_flash_attn_ext_set_prec (cur, GGML_PREC_F32);