Compare commits

...

3 Commits

Author SHA1 Message Date
Alan Gray c2691d968a disable for multi-gpu and batch size > 1 2024-04-22 09:17:54 -07:00
Alan Gray 800f4fe48e Tidied to now only use CUDA runtime (not mixed with driver calls) 2024-04-22 05:11:05 -07:00
Alan Gray c8dd0e7c1c FIx issues raised in comments 2024-04-22 03:47:24 -07:00
1 changed files with 48 additions and 22 deletions

View File

@ -2405,28 +2405,42 @@ GGML_CALL static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
GGML_UNUSED(backend);
}
#if (CUDART_VERSION >= 12000)
#define USE_CUDA_GRAPH
#endif
#ifdef USE_CUDA_GRAPH
#define MAX_NODES_IN_CUDA_GRAPH 10000
struct ggml_cudaGraph {
int count=0;
cudaGraph_t graph = nullptr;
cudaGraphExec_t instance = nullptr;
size_t numNodes = 0;
int softmax_ne0 = 0;
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
cudaKernelNodeParams params[MAX_NODES_IN_CUDA_GRAPH];
};
#endif
GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
#ifdef USE_CUDA_GRAPH
// Objects required for CUDA Graph
#define MAX_NODES_IN_CUDA_GRAPH 10000
static ggml_cudaGraph cudaGraph; //TO DO move this to a suitable persistant location (and avoid use of static memory)
bool useCudaGraph = (cudaGraph.count>=2); //avoid CUDA graphs on first 2 steps due to incompatible initialisations.
static ggml_cudaGraph cudaGraph;
bool useCudaGraph = (cudaGraph.count>=7); //avoid CUDA graphs on first few steps due to incompatible initialisations.
char** updatedKernelArg[MAX_NODES_IN_CUDA_GRAPH];
bool cudaGraphUpdateRequired = false;
// pointer to CUDA cpy kernel, which is required to identify
// kernel parameters which need updated in the graph for each token
void* ggmlCudaCpyFn = nullptr;
if(ggml_backend_cuda_get_device_count() > 1){
useCudaGraph = false; // disable CUDA graphs for multi-gpu for now. TO DO investigate
}
if(useCudaGraph) {
if(cudaGraph.instance == nullptr) cudaGraphUpdateRequired=true;
@ -2438,6 +2452,9 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
// Identify if the graph needs updated for this token due to the number of elements changing
// (identified by inspecting soft max op parameters)
if(node->op == GGML_OP_SOFT_MAX) {
if(node->src[1]->ne[1] > 1){
useCudaGraph = false; // disable CUDA graphs for batch size > 1 for now. TO DO investigate
}
if(node->src[0]->ne[0] != cudaGraph.softmax_ne0) {
cudaGraphUpdateRequired = true;
cudaGraph.softmax_ne0 = node->src[0]->ne[0];
@ -2458,6 +2475,11 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeGlobal));
}
#else
bool useCudaGraph = false;
bool cudaGraphUpdateRequired = false;
#endif
// Only perfom the graph exection if CUDA graphs are not enebled, or we are capturing the graph.
// With use of CUDA graphs, the execution will be performed by the graph launch.
if(!useCudaGraph || cudaGraphUpdateRequired) {
@ -2486,6 +2508,7 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
}
}
#ifdef USE_CUDA_GRAPH
if(useCudaGraph && (cudaGraphUpdateRequired)) { // End CUDA graph capture
CUDA_CHECK(cudaStreamEndCapture(cuda_ctx->stream(), &cudaGraph.graph));
}
@ -2498,53 +2521,56 @@ GGML_CALL static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t
// Perform update to graph (if required for this token), and change copy parameter (required for every token)
cudaGraphNode_t nodes[MAX_NODES_IN_CUDA_GRAPH];
CUDA_KERNEL_NODE_PARAMS_v2 paramsDriver[MAX_NODES_IN_CUDA_GRAPH];
cudaKernelNodeParams paramsRuntime[MAX_NODES_IN_CUDA_GRAPH];
if(cudaGraphUpdateRequired) {
// Extract nodes from graph
if(cudaGraph.numNodes == 0) {
CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nullptr, &cudaGraph.numNodes));
}
CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, nodes, &cudaGraph.numNodes));
CUDA_CHECK(cudaGraphGetNodes(cudaGraph.graph, cudaGraph.nodes, &cudaGraph.numNodes));
// Loop over nodes, and extract kernel parameters fro each node
for(size_t i=0; i<cudaGraph.numNodes; i++) {
// We currently get a set of params using both driver and runtime, to work around an issue (see below)
CU_CHECK(cuGraphKernelNodeGetParams(nodes[i], &paramsDriver[i])); // Get params using driver
cudaError_t statRT = cudaGraphKernelNodeGetParams(nodes[i], &paramsRuntime[i]); // Get params using runtime
if(statRT == 98) {
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
// We don't need to update blas nodes, so clear error and move on.
cudaGetLastError();
cudaGraphNodeType nodeType;
CUDA_CHECK(cudaGraphNodeGetType(cudaGraph.nodes[i], &nodeType));
if (nodeType == cudaGraphNodeTypeKernel) {
auto statRT = cudaGraphKernelNodeGetParams(cudaGraph.nodes[i], &cudaGraph.params[i]); // Get params using runtime
if(statRT == cudaErrorInvalidDeviceFunction) {
// Fails due to incorrect handling by CUDA runtime of CUDA BLAS node.
// We don't need to update blas nodes, so clear error and move on.
cudaGetLastError();
}
}
}
}
// Update copy kernel param (required every token)
// Currently uses runtime copy of params to identify copy function node,
// and driver copy of params to perform the update
// TO DO work out how to do it only using runtime copy.
if(!cudaGraphUpdateRequired) { // on update steps, the live parameters will already be captured
int k=0;
for(size_t i=0; i<cudaGraph.numNodes; i++) {
if(paramsRuntime[i].func == ggmlCudaCpyFn) {
if(cudaGraph.params[i].func == ggmlCudaCpyFn) {
char** updatedKernelArgPointer = updatedKernelArg[k++];
paramsDriver[i].kernelParams[1] = updatedKernelArgPointer;
CU_CHECK(cuGraphKernelNodeSetParams(nodes[i], &paramsDriver[i]));
cudaGraph.params[i].kernelParams[1] = updatedKernelArgPointer;
CUDA_CHECK(cudaGraphKernelNodeSetParams(cudaGraph.nodes[i], &cudaGraph.params[i]));
}
}
}
// Update graph executable
cudaGraphExecUpdateResultInfo resultInfo;
CUDA_CHECK(cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo));
auto stat = cudaGraphExecUpdate(cudaGraph.instance, cudaGraph.graph, &resultInfo);
if(stat == cudaErrorGraphExecUpdateFailure)
{
// The pre-existing graph exec cannot be updated due to violated constraints
// so instead clar error and re-instantiate
cudaGetLastError();
CUDA_CHECK(cudaGraphInstantiate(&cudaGraph.instance, cudaGraph.graph, NULL, NULL, 0));
}
// Launch graph
CUDA_CHECK(cudaGraphLaunch(cudaGraph.instance, cuda_ctx->stream()));
}
cudaGraph.count++;
#endif
return GGML_STATUS_SUCCESS;
}