diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index b69583dd3f..1988d16dc4 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -630,10 +630,11 @@ extern "C" {
     // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input  for the GGML compute graph
-        GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
-        GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
-        GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_INPUT   =  1, // ...is an input  for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT  =  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM   =  4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS    =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
     };

     enum ggml_tri_type {
@@ -2577,11 +2578,42 @@ extern "C" {
             struct ggml_tensor  * grad,
             struct ggml_tensor  * sgd_params); // alpha, weight decay

+    // build the forward graph for multiple tensors and select one of them for computing
+    // this is useful for creating graphs that have constant topology but compute different things based on the input
+    // ref: https://github.com/ggml-org/llama.cpp/pull/18550
     //
-    // automatic differentiation
     //
+    // nodes:
+    //   | - build forward into the graph but do not compute
+    //   c - build forward into the graph and compute
     //
+    //    |    |   ...    c   ...    |
+    //    |    |   ...    c   ...    |
+    //    |    |   ...    c   ...    |
+    //   [0    1   ...   idx  ...   n-1]  <-- ggml_build_forward_select(..., n, idx)
+    //                    c
+    //                    c
+    //
+    // example:
+    //   struct ggml_tensor * curs[3];
+    //
+    //   curs[0] = compute0(...);
+    //   curs[1] = compute1(...);
+    //   curs[2] = compute2(...);
+    //
+    //   int idx = select_branch(some_input);
+    //
+    //   struct ggml_tensor * out = ggml_build_forward_select(cgraph, curs, 3, idx);
+    //
+    GGML_API struct ggml_tensor * ggml_build_forward_select(
+            struct ggml_cgraph   * cgraph,
+            struct ggml_tensor  ** tensors,
+            int                    n_tensors,
+            int                    idx);
+
+    GGML_API void ggml_build_forward_expand(
+            struct ggml_cgraph * cgraph,
+            struct ggml_tensor * tensor);
-    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
     GGML_API void ggml_build_backward_expand(
             struct ggml_context  * ctx,    // context for gradient computation
             struct ggml_cgraph   * cgraph,
@@ -2613,7 +2645,7 @@ extern "C" {
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);

     // dump the graph into a file using the dot format
-    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);

     // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
     typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
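Usage sketch (not part of the patch; the helper name, tensor shapes and branch ops are illustrative): both branches are always built into the graph, so the topology is identical for every idx, while only the selected branch carries GGML_TENSOR_FLAG_COMPUTE and is executed by the backends.

// sketch only: a fleshed-out version of the header example above
static struct ggml_tensor * build_two_branches(
        struct ggml_context * ctx,
        struct ggml_cgraph  * gf,
        struct ggml_tensor  * x,
        int                   idx) {
    struct ggml_tensor * branches[2];

    branches[0] = ggml_sqr  (ctx, x);        // branch 0: x^2
    branches[1] = ggml_scale(ctx, x, 2.0f);  // branch 1: 2*x

    // both branches become part of gf regardless of idx (constant topology),
    // but only branches[idx] and its ancestors are flagged for computation
    return ggml_build_forward_select(gf, branches, 2, idx);
}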
diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp
index 1b59924b8c..354876574a 100644
--- a/ggml/src/ggml-backend.cpp
+++ b/ggml/src/ggml-backend.cpp
@@ -874,9 +874,9 @@ static void ggml_backend_sched_print_assignments(ggml_backend_sched_t sched, str
         }
         if (sched->debug > 1) {
             ggml_backend_t tensor_backend = ggml_backend_sched_get_tensor_backend(sched, node);
-            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d:", i, ggml_op_name(node->op), node->name,
+            GGML_LOG_DEBUG("node #%3d (%10.10s): %20.20s (%5.5s) [%5.5s %8.8s] use=%d,c=%d:", i, ggml_op_name(node->op), node->name,
                 fmt_size(ggml_nbytes(node)), tensor_backend ? ggml_backend_name(tensor_backend) : "NULL", GET_CAUSE(node),
-                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)]);
+                graph->use_counts[ggml_hash_find(&graph->visited_hash_set, node)], node->flags & GGML_TENSOR_FLAG_COMPUTE ? 1 : 0);
             for (int j = 0; j < GGML_MAX_SRC; j++) {
                 struct ggml_tensor * src = node->src[j];
                 if (src == NULL) {
@@ -1922,6 +1922,7 @@ static struct ggml_tensor * graph_copy_dup_tensor(struct ggml_hash_set hash_set,
         dst->view_offs = src->view_offs;
     }
     dst->op = src->op;
+    dst->flags = src->flags;
     memcpy(dst->op_params, src->op_params, sizeof(dst->op_params));
     ggml_set_name(dst, src->name);
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index 84956cbb9c..2e9ddf2240 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -226,6 +226,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 ggml_backend_blas_mul_mat(ctx, node);
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index eba83327f1..42c6c67a40 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -2146,6 +2146,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
             continue;
         }

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         bool ok = ggml_cann_compute_forward(*cann_ctx, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f7ba1fe317..4c7a75e768 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2943,6 +2943,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             continue;
         }

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         ggml_compute_forward(&params, node);

         if (state->ith == 0 && cplan->abort_callback &&
diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index eaaf87612d..179522d835 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -1123,6 +1123,7 @@ struct ggml_tensor_extra_gpu {
 struct ggml_cuda_graph_node_properties {
     void * node_address;
     ggml_op node_op;
+    int32_t flags;
     int64_t ne[GGML_MAX_DIMS];
     size_t nb[GGML_MAX_DIMS];
     void * src_address[GGML_MAX_SRC];
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index ed1021469a..cda422defb 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -2918,6 +2918,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
 static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
     props->node_address = node->data;
     props->node_op = node->op;
+    props->flags = node->flags;
     for (int i = 0; i < GGML_MAX_DIMS; i++) {
         props->ne[i] = node->ne[i];
         props->nb[i] = node->nb[i];
@@ -2961,6 +2962,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
         return false;
     }

+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) != (props->flags & GGML_TENSOR_FLAG_COMPUTE)) {
+        return false;
+    }
+
     return true;
 }
@@ -3378,6 +3383,9 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
             continue;
         }

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
         // start of fusion operations
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index cf1eb994c3..5b835c11c7 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -2497,6 +2497,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             continue;
         }

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         uint32_t flags = 0;

         // skip quantizer if src1 is reused
diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h
index 80e0fd2ff8..baadfe9a7b 100644
--- a/ggml/src/ggml-impl.h
+++ b/ggml/src/ggml-impl.h
@@ -611,6 +611,9 @@ static inline bool ggml_can_fuse_ext(const struct ggml_cgraph * cgraph, const in
         if (node->op != ops[i]) {
             return false;
         }
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            return false;
+        }
         if (i < num_ops - 1 && !ggml_node_has_n_uses(cgraph, node_idxs[i], 1)) {
             return false;
         }
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index 680ad794de..3d97d3dfdc 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             GGML_ABORT("unsupported op");
     }

+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return 1;
+    }
+
     int n_fuse = 1;

     // check if the current node can run concurrently with other nodes before it
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index d89d5e7242..8059240b1c 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -3058,6 +3058,10 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
             ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
             i += 2;
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index 8f8176b678..bb8acc922b 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4109,6 +4109,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
             continue;
         }
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 0fabbcec31..08fd044ca0 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -12191,6 +12191,9 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
     if (ggml_is_empty(node) || ggml_op_is_empty(node->op) || !node->buffer) {
         return false;
     }
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return false;
+    }

     VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")");
     ctx->semaphore_idx = 0;
@@ -13645,7 +13648,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
     int last_node = cgraph->n_nodes - 1;

     // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly
-    while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) {
+    while (last_node > 0 && (ggml_vk_is_empty(cgraph->nodes[last_node]) || ((cgraph->nodes[last_node]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0))) {
         last_node -= 1;
     }
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index 1470378af0..584cea7698 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -1982,6 +1982,9 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx,
     if (ggml_is_empty(node)) {
         return std::nullopt;
     }
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return std::nullopt;
+    }

     WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");

     ggml_tensor * src0 = node->src[0];
diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
index edbeb8eef2..906d25417e 100644
--- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp
+++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
@@ -58,6 +58,10 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr
             continue;
         }

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         bool ok = ggml_zdnn_compute_forward(ctx, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: unsupported op %s (%s)\n",
diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
index fd07f983da..afbecde7a5 100644
--- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp
+++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
@@ -211,6 +211,10 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 ggml_zendnn_compute_forward_mul_mat(ctx, node);
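In sketch form, the guard added to every backend above sits next to the existing empty/no-op skip in the node loop (an illustrative loop, not one of the real backends; it assumes the internal struct ggml_cgraph definition from ggml-impl.h is visible, as it is inside the backends):

static enum ggml_status example_backend_graph_compute(struct ggml_cgraph * cgraph) {
    for (int i = 0; i < cgraph->n_nodes; i++) {
        struct ggml_tensor * node = cgraph->nodes[i];

        if (ggml_is_empty(node) || node->op == GGML_OP_NONE) {
            continue; // nothing to execute (unchanged behavior)
        }

        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
            continue; // kept only to preserve the graph topology; skip it
        }

        // ... dispatch node->op to this backend's kernels ...
    }

    return GGML_STATUS_SUCCESS;
}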
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index c75fe7d271..1725ad1654 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -3441,7 +3441,8 @@ struct ggml_tensor * ggml_cast(
     result->op     = GGML_OP_CPY;
     result->src[0] = a;
-    result->src[1] = result;
+    result->src[1] = result; // note: this self-reference might seem redundant, but it's actually needed by some
+                             // backends for consistency with ggml_cpy_impl() above

     return result;
 }
@@ -6725,20 +6726,35 @@ static void ggml_compute_backward(
     GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }

-static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
-    // check if already visited
-    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) {
+    if (node->op != GGML_OP_NONE && compute) {
+        node->flags |= GGML_TENSOR_FLAG_COMPUTE;
+    }
+
+    const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
     GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
-    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
-        // This is the first time we see this node in the current graph.
-        cgraph->visited_hash_set.keys[node_hash_pos] = node;
-        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
-        cgraph->use_counts[node_hash_pos] = 0;
-    } else {
+
+    if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) { // already visited
+
+        if (compute) {
+            // update the compute flag regardless
+            for (int i = 0; i < GGML_MAX_SRC; ++i) {
+                struct ggml_tensor * src = node->src[i];
+                if (src && ((src->flags & GGML_TENSOR_FLAG_COMPUTE) == 0)) {
+                    ggml_visit_parents_graph(cgraph, src, true);
+                }
+            }
+        }
+
         return node_hash_pos;
     }

+    // This is the first time we see this node in the current graph.
+    cgraph->visited_hash_set.keys[node_hash_pos] = node;
+    ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+    cgraph->use_counts[node_hash_pos] = 0;
+
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
@@ -6747,7 +6763,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
         struct ggml_tensor * src = node->src[k];
         if (src) {
-            size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+            const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute);

             // Update the use count for this operand.
             cgraph->use_counts[src_hash_pos]++;
@@ -6778,17 +6794,17 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
     return node_hash_pos;
 }

-static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) {
     if (!expand) {
         // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
         ggml_graph_clear(cgraph);
     }

-    const int n0 = cgraph->n_nodes;
+    const int n_old = cgraph->n_nodes;

-    ggml_visit_parents(cgraph, tensor);
+    ggml_visit_parents_graph(cgraph, tensor, compute);

-    const int n_new = cgraph->n_nodes - n0;
+    const int n_new = cgraph->n_nodes - n_old;
     GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);

     if (n_new > 0) {
@@ -6797,8 +6813,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
     }
 }

+struct ggml_tensor * ggml_build_forward_select(
+        struct ggml_cgraph   * cgraph,
+        struct ggml_tensor  ** tensors,
+        int                    n_tensors,
+        int                    idx) {
+    GGML_ASSERT(idx >= 0 && idx < n_tensors);
+
+    for (int i = 0; i < n_tensors; i++) {
+        ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false);
+    }
+
+    return tensors[idx];
+}
+
 void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
-    ggml_build_forward_impl(cgraph, tensor, true);
+    ggml_build_forward_impl(cgraph, tensor, true, true);
 }

 void ggml_build_backward_expand(
@@ -7229,6 +7259,10 @@ bool ggml_can_fuse_subgraph_ext(const struct ggml_cgraph * cgraph,
             return false;
         }

+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            return false;
+        }
+
         if (ggml_node_list_find_tensor(cgraph, outputs, num_outputs, node) != -1) {
             continue;
         }
@@ -7310,7 +7344,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node,
         label);
 }

-void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) {
     char color[16];

     FILE * fp = ggml_fopen(filename, "w");
@@ -7331,7 +7365,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             snprintf(color, sizeof(color), "yellow");
         } else if (grad) {
-            if (ggml_graph_find(gf, node)) {
+            if (ggml_graph_find(cgraph, node)) {
                 snprintf(color, sizeof(color), "green");
             } else {
                 snprintf(color, sizeof(color), "lightblue");
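A sketch of the flag propagation implemented by ggml_visit_parents_graph() above (illustrative names; it assumes ctx, gf and two leaf tensors a and b already exist): a subexpression shared with the selected branch gets flagged because it is an ancestor of the selected output, while the non-selected branch stays in the graph unflagged and is skipped at compute time.

static void example_compute_flags(
        struct ggml_context * ctx,
        struct ggml_cgraph  * gf,
        struct ggml_tensor  * a,
        struct ggml_tensor  * b) {
    struct ggml_tensor * h  = ggml_mul(ctx, a, b); // shared by both branches
    struct ggml_tensor * c0 = ggml_add(ctx, h, a); // branch 0
    struct ggml_tensor * c1 = ggml_sub(ctx, h, b); // branch 1

    struct ggml_tensor * curs[2] = { c0, c1 };
    struct ggml_tensor * out = ggml_build_forward_select(gf, curs, 2, /*idx =*/ 1);

    GGML_ASSERT(out == c1);                                    // the selected tensor is returned
    GGML_ASSERT((c1->flags & GGML_TENSOR_FLAG_COMPUTE) != 0);  // selected branch: computed
    GGML_ASSERT((h ->flags & GGML_TENSOR_FLAG_COMPUTE) != 0);  // shared ancestor: computed
    GGML_ASSERT((c0->flags & GGML_TENSOR_FLAG_COMPUTE) == 0);  // other branch: built, not computed
}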