From 4e010b4d7b374490239ccaae98e493d34c2f3743 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Fri, 2 Jan 2026 18:28:24 +0200
Subject: [PATCH] ggml : add ggml_build_forward_select

---
 ggml/include/ggml.h                    | 22 +++++---
 ggml/src/ggml-blas/ggml-blas.cpp       |  4 ++
 ggml/src/ggml-cann/ggml-cann.cpp       |  4 ++
 ggml/src/ggml-cpu/ggml-cpu.c           |  4 ++
 ggml/src/ggml-cuda/ggml-cuda.cu        |  3 ++
 ggml/src/ggml-hexagon/ggml-hexagon.cpp |  4 ++
 ggml/src/ggml-metal/ggml-metal-ops.cpp |  4 ++
 ggml/src/ggml-opencl/ggml-opencl.cpp   |  4 ++
 ggml/src/ggml-sycl/ggml-sycl.cpp       |  3 ++
 ggml/src/ggml-vulkan/ggml-vulkan.cpp   |  4 ++
 ggml/src/ggml-webgpu/ggml-webgpu.cpp   |  3 ++
 ggml/src/ggml-zdnn/ggml-zdnn.cpp       |  4 ++
 ggml/src/ggml-zendnn/ggml-zendnn.cpp   |  4 ++
 ggml/src/ggml.c                        | 71 ++++++++++++++++++++------
 src/llama-graph.cpp                    | 32 +++++++-----
 src/llama-graph.h                      |  4 +-
 src/models/gemma3n-iswa.cpp            |  2 +-
 17 files changed, 139 insertions(+), 37 deletions(-)

diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h
index 20c912d0e9..5350178620 100644
--- a/ggml/include/ggml.h
+++ b/ggml/include/ggml.h
@@ -625,10 +625,11 @@ extern "C" {
 
     // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input for the GGML compute graph
-        GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
-        GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
-        GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_INPUT   =  1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT  =  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM   =  4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS    =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_COMPUTE = 16, // ...must be computed
     };
 
     enum ggml_tri_type {
@@ -2576,7 +2577,16 @@ extern "C" {
     // automatic differentiation
     //
 
-    GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
+    GGML_API struct ggml_tensor * ggml_build_forward_select(
+            struct ggml_cgraph  * cgraph,
+            struct ggml_tensor ** tensors,
+            int                   n_tensors,
+            int                   idx);
+
+    GGML_API void ggml_build_forward_expand(
+            struct ggml_cgraph * cgraph,
+            struct ggml_tensor * tensor);
+
     GGML_API void ggml_build_backward_expand(
         struct ggml_context  * ctx,    // context for gradient computation
         struct ggml_cgraph   * cgraph,
@@ -2608,7 +2618,7 @@ extern "C" {
     GGML_API void ggml_graph_print(const struct ggml_cgraph * cgraph);
 
     // dump the graph into a file using the dot format
-    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename);
+    GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename);
 
     // TODO these functions were sandwiched in the old optimization interface, is there a better place for them?
     typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data);
diff --git a/ggml/src/ggml-blas/ggml-blas.cpp b/ggml/src/ggml-blas/ggml-blas.cpp
index 5b888cdd8c..a32d68ea87 100644
--- a/ggml/src/ggml-blas/ggml-blas.cpp
+++ b/ggml/src/ggml-blas/ggml-blas.cpp
@@ -230,6 +230,10 @@ static enum ggml_status ggml_backend_blas_graph_compute(ggml_backend_t backend,
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 ggml_backend_blas_mul_mat(ctx, node);
diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp
index ef23ec78da..0cc0351a17 100644
--- a/ggml/src/ggml-cann/ggml-cann.cpp
+++ b/ggml/src/ggml-cann/ggml-cann.cpp
@@ -2110,6 +2110,10 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         bool ok = ggml_cann_compute_forward(*cann_ctx, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: op not supported %s (%s)\n", __func__, node->name, ggml_op_name(node->op));
diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c
index f7ba1fe317..4c7a75e768 100644
--- a/ggml/src/ggml-cpu/ggml-cpu.c
+++ b/ggml/src/ggml-cpu/ggml-cpu.c
@@ -2943,6 +2943,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         ggml_compute_forward(&params, node);
 
         if (state->ith == 0 && cplan->abort_callback &&
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index 84eccea3f7..ddb168a5e9 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -3350,6 +3350,9 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
 
         // start of fusion operations
         static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
index 13b96d61f8..168ae97730 100644
--- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp
+++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp
@@ -2379,6 +2379,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         uint32_t flags = 0;
 
         // skip quantizer if src1 is reused
diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp
index a50b12b6f3..0a7cd3a36a 100644
--- a/ggml/src/ggml-metal/ggml-metal-ops.cpp
+++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp
@@ -203,6 +203,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
         GGML_ABORT("unsupported op");
     }
 
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return 1;
+    }
+
     int n_fuse = 1;
 
     // check if the current node can run concurrently with other nodes before it
diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp
index 353f6a4b46..53dfbc58f0 100644
--- a/ggml/src/ggml-opencl/ggml-opencl.cpp
+++ b/ggml/src/ggml-opencl/ggml-opencl.cpp
@@ -2967,6 +2967,10 @@ static ggml_status ggml_backend_opencl_graph_compute(ggml_backend_t backend, ggm
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         if (!backend_ctx->disable_fusion && ggml_opencl_can_fuse(cgraph, i, { GGML_OP_NORM, GGML_OP_MUL, GGML_OP_ADD })) {
             ggml_opencl_op_norm_fused(backend, node, cgraph->nodes[i+1], cgraph->nodes[i+2]);
             i += 2;
diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp
index e996d98be8..ef27bdd5c4 100644
--- a/ggml/src/ggml-sycl/ggml-sycl.cpp
+++ b/ggml/src/ggml-sycl/ggml-sycl.cpp
@@ -4109,6 +4109,9 @@ static void ggml_backend_sycl_graph_compute_impl(ggml_backend_sycl_context * syc
         if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
             continue;
         }
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
 #ifndef NDEBUG
         assert(node->buffer->buft == ggml_backend_sycl_buffer_type(sycl_ctx->device));
         for (int j = 0; j < GGML_MAX_SRC; j++) {
diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
index 541e4a50b7..69ccf539aa 100644
--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp
+++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -13443,6 +13443,10 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
             submit_node_idx = i;
         }
 
+        if ((cgraph->nodes[i]->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) {
             auto bytes = ggml_nbytes(cgraph->nodes[i]->src[0]);
             mul_mat_bytes += bytes;
diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
index d0e99b6fe2..b4c3a4d851 100644
--- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp
+++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp
@@ -1374,6 +1374,9 @@ static std::optional ggml_webgpu_encode_node(webgpu_context ctx,
     if (ggml_is_empty(node)) {
         return std::nullopt;
     }
+    if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+        return std::nullopt;
+    }
 
     WEBGPU_LOG_DEBUG("ggml_webgpu_encode_node(" << node << ", " << ggml_op_name(node->op) << ")");
 
     ggml_tensor * src0 = node->src[0];
diff --git a/ggml/src/ggml-zdnn/ggml-zdnn.cpp b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
index edbeb8eef2..906d25417e 100644
--- a/ggml/src/ggml-zdnn/ggml-zdnn.cpp
+++ b/ggml/src/ggml-zdnn/ggml-zdnn.cpp
@@ -58,6 +58,10 @@ static enum ggml_status ggml_zdnn_graph_compute(ggml_backend_t backend, ggml_cgr
             continue;
         }
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         bool ok = ggml_zdnn_compute_forward(ctx, node);
         if (!ok) {
             GGML_LOG_ERROR("%s: unsupported op %s (%s)\n",
diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
index fd07f983da..afbecde7a5 100644
--- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp
+++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp
@@ -211,6 +211,10 @@ static ggml_status ggml_backend_zendnn_graph_compute(ggml_backend_t backend, ggm
     for (int i = 0; i < cgraph->n_nodes; i++) {
         struct ggml_tensor * node = cgraph->nodes[i];
 
+        if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
+            continue;
+        }
+
         switch (node->op) {
             case GGML_OP_MUL_MAT:
                 ggml_zendnn_compute_forward_mul_mat(ctx, node);
diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c
index eb3ae72eaa..f57c38f799 100644
--- a/ggml/src/ggml.c
+++ b/ggml/src/ggml.c
@@ -6718,20 +6718,39 @@ static void ggml_compute_backward(
     GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2]));
 }
 
-static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
-    // check if already visited
-    size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
+static void ggml_visit_parents_compute(struct ggml_tensor * node) {
+    if (node->flags & GGML_TENSOR_FLAG_COMPUTE) {
+        return;
+    }
+
+    node->flags |= GGML_TENSOR_FLAG_COMPUTE;
+
+    for (int i = 0; i < GGML_MAX_SRC; ++i) {
+        struct ggml_tensor * src = node->src[i];
+        if (src) {
+            ggml_visit_parents_compute(src);
+        }
+    }
+}
+
+static size_t ggml_visit_parents_graph(struct ggml_cgraph * cgraph, struct ggml_tensor * node, bool compute) {
+    if (compute) {
+        node->flags |= GGML_TENSOR_FLAG_COMPUTE;
+    }
+
+    const size_t node_hash_pos = ggml_hash_find(&cgraph->visited_hash_set, node);
     GGML_ASSERT(node_hash_pos != GGML_HASHSET_FULL);
-    if (!ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
-        // This is the first time we see this node in the current graph.
-        cgraph->visited_hash_set.keys[node_hash_pos] = node;
-        ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
-        cgraph->use_counts[node_hash_pos] = 0;
-    } else {
+
+    if (ggml_bitset_get(cgraph->visited_hash_set.used, node_hash_pos)) {
         // already visited
         return node_hash_pos;
     }
 
+    // This is the first time we see this node in the current graph.
+    cgraph->visited_hash_set.keys[node_hash_pos] = node;
+    ggml_bitset_set(cgraph->visited_hash_set.used, node_hash_pos);
+    cgraph->use_counts[node_hash_pos] = 0;
+
     for (int i = 0; i < GGML_MAX_SRC; ++i) {
         const int k =
             (cgraph->order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? i :
@@ -6740,7 +6759,7 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
 
         struct ggml_tensor * src = node->src[k];
         if (src) {
-            size_t src_hash_pos = ggml_visit_parents(cgraph, src);
+            const size_t src_hash_pos = ggml_visit_parents_graph(cgraph, src, compute);
 
             // Update the use count for this operand.
             cgraph->use_counts[src_hash_pos]++;
@@ -6771,17 +6790,21 @@ static size_t ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor
     return node_hash_pos;
 }
 
-static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand) {
+static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor, bool expand, bool compute) {
     if (!expand) {
         // TODO: this branch isn't accessible anymore, maybe move this to ggml_build_forward_expand
        ggml_graph_clear(cgraph);
     }
 
-    const int n0 = cgraph->n_nodes;
+    const int n_old = cgraph->n_nodes;
 
-    ggml_visit_parents(cgraph, tensor);
+    ggml_visit_parents_graph(cgraph, tensor, compute);
 
-    const int n_new = cgraph->n_nodes - n0;
+    if (compute) {
+        ggml_visit_parents_compute(tensor);
+    }
+
+    const int n_new = cgraph->n_nodes - n_old;
     GGML_PRINT_DEBUG("%s: visited %d new nodes\n", __func__, n_new);
 
     if (n_new > 0) {
@@ -6790,8 +6813,22 @@ static void ggml_build_forward_impl(struct ggml_cgraph * cgraph, struct ggml_ten
     }
 }
 
+struct ggml_tensor * ggml_build_forward_select(
+        struct ggml_cgraph  * cgraph,
+        struct ggml_tensor ** tensors,
+        int                   n_tensors,
+        int                   idx) {
+    GGML_ASSERT(idx >= 0 && idx < n_tensors);
+
+    for (int i = 0; i < n_tensors; i++) {
+        ggml_build_forward_impl(cgraph, tensors[i], true, i == idx ? true : false);
+    }
+
+    return tensors[idx];
+}
+
 void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor) {
-    ggml_build_forward_impl(cgraph, tensor, true);
+    ggml_build_forward_impl(cgraph, tensor, true, true);
 }
 
 void ggml_build_backward_expand(
@@ -7303,7 +7340,7 @@ static void ggml_graph_dump_dot_leaf_edge(FILE * fp, struct ggml_tensor * node,
             label);
 }
 
-void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename) {
+void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * cgraph, const char * filename) {
     char color[16];
 
     FILE * fp = ggml_fopen(filename, "w");
@@ -7324,7 +7361,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph
         if (node->flags & GGML_TENSOR_FLAG_PARAM) {
             snprintf(color, sizeof(color), "yellow");
         } else if (grad) {
-            if (ggml_graph_find(gf, node)) {
+            if (ggml_graph_find(cgraph, node)) {
                 snprintf(color, sizeof(color), "green");
             } else {
                 snprintf(color, sizeof(color), "lightblue");
diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp
index 1d0d7197e1..f3c09ab02f 100644
--- a/src/llama-graph.cpp
+++ b/src/llama-graph.cpp
@@ -21,7 +21,8 @@ void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) {
     }
 
     if (ubatch->embd) {
-        const int64_t n_embd = embd->ne[0];
+        GGML_ASSERT(n_embd == embd->ne[0]);
+
         const int64_t n_tokens = ubatch->n_tokens;
 
         ggml_backend_tensor_set(embd, ubatch->embd, 0, n_tokens*n_embd*ggml_element_size(embd));
@@ -1206,16 +1207,24 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
 ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
     const int64_t n_embd = hparams.n_embd_inp();
 
-    auto inp = std::make_unique<llm_graph_input_embd>();
+    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
 
-    ggml_tensor * cur = nullptr;
+    inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
+    cb(inp->tokens, "inp_tokens", -1);
+    ggml_set_input(inp->tokens);
 
     if (ubatch.token) {
-        inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
-        //cb(inp->tokens, "inp_tokens", -1);
-        ggml_set_input(inp->tokens);
         res->t_tokens = inp->tokens;
+    }
+
+    inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
+    cb(inp->embd, "inp_embd", -1);
+    ggml_set_input(inp->embd);
+
+    ggml_tensor * cur;
 
+    // token embeddings
+    {
         cur = ggml_get_rows(ctx0, tok_embd, inp->tokens);
 
         // apply lora for embedding tokens if needed
@@ -1235,19 +1244,18 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const {
 
             cur = ggml_add(ctx0, cur, inpL_delta);
         }
-    } else {
-        inp->embd = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, ubatch.n_tokens);
-        ggml_set_input(inp->embd);
-
-        cur = inp->embd;
     }
 
+    std::array inps = { cur, inp->embd };
+
+    cur = ggml_build_forward_select(gf, inps.data(), inps.size(), ubatch.token ? 0 : 1);
+
     // For Granite architecture
     if (hparams.f_embedding_scale != 0.0f) {
         cur = ggml_scale(ctx0, cur, hparams.f_embedding_scale);
     }
 
-    cb(cur, "inp_embd", -1);
+    cb(cur, "embd", -1);
 
     res->add_input(std::move(inp));
diff --git a/src/llama-graph.h b/src/llama-graph.h
index 81ac329cc3..07c81e79f9 100644
--- a/src/llama-graph.h
+++ b/src/llama-graph.h
@@ -104,7 +104,7 @@ using llm_graph_input_ptr = std::unique_ptr<llm_graph_input_i>;
 
 class llm_graph_input_embd : public llm_graph_input_i {
 public:
-    llm_graph_input_embd() = default;
+    llm_graph_input_embd(int64_t n_embd) : n_embd(n_embd) {}
     virtual ~llm_graph_input_embd() = default;
 
     void set_input(const llama_ubatch * ubatch) override;
@@ -113,6 +113,8 @@ public:
 
     ggml_tensor * tokens = nullptr; // I32 [n_batch]
     ggml_tensor * embd   = nullptr; // F32 [n_embd, n_batch]
+
+    const int64_t n_embd = 0;
 };
 
 class llm_graph_input_pos : public llm_graph_input_i {
diff --git a/src/models/gemma3n-iswa.cpp b/src/models/gemma3n-iswa.cpp
index 9c7b3ba0bb..15054cf1ce 100644
--- a/src/models/gemma3n-iswa.cpp
+++ b/src/models/gemma3n-iswa.cpp
@@ -245,7 +245,7 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) {
 // equivalent to get_per_layer_inputs() in python code
 // output shape: [n_embd_altup, n_layer, n_tokens]
 ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() {
-    auto inp = std::make_unique<llm_graph_input_embd>();
+    auto inp = std::make_unique<llm_graph_input_embd>(n_embd);
     ggml_tensor * inp_per_layer;
     if (ubatch.token) {
         inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens);
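
Editor's note (not part of the patch): the sketch below illustrates how the new ggml_build_forward_select() API is meant to be used, inferred from the declaration in ggml.h and the build_inp_embd() call site in src/llama-graph.cpp above. All candidate tensors are added to the graph, but only the branch rooted at tensors[idx] is marked with GGML_TENSOR_FLAG_COMPUTE, and the patched backends skip unmarked nodes. The sketch assumes the patched ggml tree with the CPU backend and the classic ggml_context allocation path (no_alloc = false); the two branches and their names are purely illustrative.

// minimal usage sketch (assumes "ggml.h" and "ggml-cpu.h" from the patched tree)
#include "ggml.h"
#include "ggml-cpu.h"

#include <stdio.h>

int main(void) {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    ggml_set_input(x);

    // two alternative branches over the same input (illustrative)
    struct ggml_tensor * branch_scale = ggml_scale(ctx, x, 2.0f); // idx == 0
    struct ggml_tensor * branch_sqr   = ggml_mul  (ctx, x, x);    // idx == 1

    struct ggml_cgraph * gf = ggml_new_graph(ctx);

    // both branches end up in the graph, but only tensors[idx] and its
    // ancestors receive GGML_TENSOR_FLAG_COMPUTE; the other branch is skipped
    struct ggml_tensor * branches[2] = { branch_scale, branch_sqr };
    struct ggml_tensor * out = ggml_build_forward_select(gf, branches, 2, /*idx =*/ 0);
    ggml_set_output(out);

    // fill the input (tensor data lives in the context because no_alloc == false)
    float * xd = (float *) x->data;
    for (int i = 0; i < 4; ++i) {
        xd[i] = (float) i;
    }

    ggml_graph_compute_with_ctx(ctx, gf, /*n_threads =*/ 1);

    for (int i = 0; i < 4; ++i) {
        printf("out[%d] = %.1f\n", i, ((float *) out->data)[i]); // 0.0 2.0 4.0 6.0
    }

    ggml_free(ctx);
    return 0;
}

Because every candidate branch stays in the graph regardless of idx, the graph topology does not depend on which branch is selected; the build_inp_embd() change uses this to keep a single graph covering both the token-input and the embedding-input paths while computing only the one the current ubatch needs.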