From ec2443a94abbd0f7b85fcdd1efbfdae1b32a87da Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 10 Mar 2026 10:50:41 +0200
Subject: [PATCH] llama : enable chunked fused GDN path

Enable the chunked fused Gated Delta Net path: turn on
cparams.fused_gdn_ch, implement the chunked output/state views in
build_delta_net, and resolve both fused GDN variants automatically at
reserve time via the new cparams.auto_fgdn flag. The CUDA backend
advertises the fused op only for the single-token (autoregressive) case
until chunked support is added.
---
 ggml/src/ggml-cuda/ggml-cuda.cu |  3 +-
 src/llama-context.cpp           | 95 ++++++++++++++++++++++++---------
 src/llama-cparams.h             |  1 +
 src/llama-impl.h                |  6 +--
 src/models/delta-net-base.cpp   | 21 ++++++--
 5 files changed, 93 insertions(+), 33 deletions(-)

diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index cda275b8c5..3c36398647 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -4999,7 +4999,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
 #ifdef GGML_USE_MUSA
             return false;
 #else
-            return true;
+            // TODO: add chunked support
+            return op->src[0]->ne[2] == 1;
 #endif // GGML_USE_MUSA
         case GGML_OP_FLASH_ATTN_EXT:
             return ggml_cuda_flash_attn_ext_supported(dev_ctx->device, op);
diff --git a/src/llama-context.cpp b/src/llama-context.cpp
index ee2669c154..6a7df97bec 100644
--- a/src/llama-context.cpp
+++ b/src/llama-context.cpp
@@ -151,7 +151,8 @@ llama_context::llama_context(
     cparams.auto_fa = params.flash_attn_type == LLAMA_FLASH_ATTN_TYPE_AUTO;
 
     cparams.fused_gdn_ar = true;
-    cparams.fused_gdn_ch = false; // TODO: implement
+    cparams.fused_gdn_ch = true;
+    cparams.auto_fgdn = true;
 
     // with causal attention, the batch size is limited by the context size
     cparams.n_batch = cparams.causal_attn ? std::min(cparams.n_ctx, params.n_batch) : params.n_batch;
@@ -462,37 +463,81 @@ void llama_context::sched_reserve() {
         cparams.auto_fa = false;
     }
 
-    if (cparams.fused_gdn_ar) {
-        auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
-        if (!gf) {
-            throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check");
-        }
+    if (cparams.auto_fgdn) {
+        LLAMA_LOG_INFO("%s: resolving fused Gated Delta Net\n", __func__);
 
-        const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDNAR) + 1;
-        bool gdn_device_mismatch = false;
-        for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
-            ggml_tensor * n = ggml_graph_node(gf, i);
-            if (n->op != GGML_OP_GATED_DELTA_NET) {
-                continue;
+        if (cparams.fused_gdn_ar) {
+            auto * gf = graph_reserve(1, n_seqs, n_outputs, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (autoregressive)");
             }
-            ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
 
-            GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDNAR "-", prefix_len) == 0);
-            const int il = std::stoi(n->name + prefix_len);
-            ggml_backend_dev_t device_kv = model.dev_layer(il);
-            if (device_gdn != device_kv) {
-                LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
-                        "is assigned to device %s (usually due to missing support)\n",
-                        __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
-                gdn_device_mismatch = true;
-                break;
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_AR) + 1;
+            bool gdn_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_GATED_DELTA_NET) {
+                    continue;
+                }
+                ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_AR "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_gdn != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
+                            "is assigned to device %s (usually due to missing support)\n",
+                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
+                    gdn_device_mismatch = true;
+                    break;
+                }
+            }
+
+            if (gdn_device_mismatch) {
+                cparams.fused_gdn_ar = false;
+                LLAMA_LOG_WARN("%s: fused Gated Delta Net (autoregressive) not supported, set to disabled\n", __func__);
+            } else {
+                LLAMA_LOG_INFO("%s: fused Gated Delta Net (autoregressive) enabled\n", __func__);
             }
         }
 
-        if (gdn_device_mismatch) {
-            cparams.fused_gdn_ar = false;
-            LLAMA_LOG_WARN("%s: fused Gated Delta Net not supported, set to disabled\n", __func__);
+        if (cparams.fused_gdn_ch) {
+            // use more than one token per sequence in the batch so that the chunked path is taken
+            auto * gf = graph_reserve(16*n_seqs, n_seqs, n_outputs, mctx.get(), true);
+            if (!gf) {
+                throw std::runtime_error("failed to reserve graph for fused Gated Delta Net check (chunked)");
+            }
+
+            const size_t prefix_len = strlen(LLAMA_TENSOR_NAME_FGDN_CH) + 1;
+            bool gdn_device_mismatch = false;
+            for (int i = 0; i < ggml_graph_n_nodes(gf); i++) {
+                ggml_tensor * n = ggml_graph_node(gf, i);
+                if (n->op != GGML_OP_GATED_DELTA_NET) {
+                    continue;
+                }
+                ggml_backend_dev_t device_gdn = ggml_backend_get_device(ggml_backend_sched_get_tensor_backend(sched.get(), n));
+
+                GGML_ASSERT(strncmp(n->name, LLAMA_TENSOR_NAME_FGDN_CH "-", prefix_len) == 0);
+                const int il = std::stoi(n->name + prefix_len);
+                ggml_backend_dev_t device_kv = model.dev_layer(il);
+                if (device_gdn != device_kv) {
+                    LLAMA_LOG_WARN("%s: layer %d is assigned to device %s but the fused Gated Delta Net tensor "
+                            "is assigned to device %s (usually due to missing support)\n",
+                            __func__, il, ggml_backend_dev_name(device_kv), ggml_backend_dev_name(device_gdn));
+                    gdn_device_mismatch = true;
+                    break;
+                }
+            }
+
+            if (gdn_device_mismatch) {
+                cparams.fused_gdn_ch = false;
+                LLAMA_LOG_WARN("%s: fused Gated Delta Net (chunked) not supported, set to disabled\n", __func__);
+            } else {
+                LLAMA_LOG_INFO("%s: fused Gated Delta Net (chunked) enabled\n", __func__);
+            }
         }
+
+        cparams.auto_fgdn = false;
     }
 
     // reserve worst-case graph
diff --git a/src/llama-cparams.h b/src/llama-cparams.h
index 333922468c..9d35947413 100644
--- a/src/llama-cparams.h
+++ b/src/llama-cparams.h
@@ -33,6 +33,7 @@ struct llama_cparams {
     bool auto_fa;
     bool fused_gdn_ar; // use fused gated delta net (autoregressive)
    bool fused_gdn_ch; // use fused gated delta net (chunked)
+    bool auto_fgdn;    // automatically resolve the fused gated delta net flags during reserve
     bool no_perf;
     bool warmup;
     bool op_offload;
diff --git a/src/llama-impl.h b/src/llama-impl.h
index ee27ac1bea..e4f35c8e53 100644
--- a/src/llama-impl.h
+++ b/src/llama-impl.h
@@ -70,6 +70,6 @@ std::string llama_format_tensor_shape(const struct ggml_tensor * t);
 
 std::string gguf_kv_to_str(const struct gguf_context * ctx_gguf, int i);
 
-#define LLAMA_TENSOR_NAME_FATTN  "__fattn__"
-#define LLAMA_TENSOR_NAME_FGDNAR "__fgdnar__"
-#define LLAMA_TENSOR_NAME_FGDNCH "__fgdnch__"
+#define LLAMA_TENSOR_NAME_FATTN   "__fattn__"
+#define LLAMA_TENSOR_NAME_FGDN_AR "__fgdn_ar__"
+#define LLAMA_TENSOR_NAME_FGDN_CH "__fgdn_ch__"
diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp
index b0be62fc68..940799df04 100644
--- a/src/models/delta-net-base.cpp
+++ b/src/models/delta-net-base.cpp
@@ -42,10 +42,23 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
     GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs);
 
     if (cparams.fused_gdn_ch) {
-        //ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
-        //cb(result, LLAMA_TENSOR_NAME_FGDNCH, il);
+        ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
+        cb(result, LLAMA_TENSOR_NAME_FGDN_CH, il);
 
-        GGML_ABORT("not implemented yet");
+        ggml_tensor * output = ggml_view_4d(ctx0, result,
+                S_v, H_v, n_tokens, n_seqs,
+                ggml_row_size(result->type, S_v),
+                ggml_row_size(result->type, S_v * H_v),
+                ggml_row_size(result->type, S_v * H_v * n_tokens), 0);
+
+        ggml_tensor * new_state = ggml_view_4d(ctx0, result,
+                S_v, S_v, H_v, n_seqs,
+                ggml_row_size(result->type, S_v),
+                ggml_row_size(result->type, S_v * S_v),
+                ggml_row_size(result->type, S_v * S_v * H_v),
+                ggml_row_size(result->type, S_v * H_v * n_tokens * n_seqs));
+
+        return {output, new_state};
     }
 
     const float scale = 1.0f / sqrtf(S_k);
@@ -327,7 +340,7 @@ std::pair<ggml_tensor *, ggml_tensor *> llm_build_delta_net_base::build_delta_ne
 
     if (cparams.fused_gdn_ar) {
         ggml_tensor * result = ggml_gated_delta_net(ctx0, q, k, v, g, b, s);
-        cb(result, LLAMA_TENSOR_NAME_FGDNAR, il);
+        cb(result, LLAMA_TENSOR_NAME_FGDN_AR, il);
 
         ggml_tensor * output = ggml_view_4d(ctx0, result,
                 S_v, H_v, n_tokens, n_seqs,
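
Note on the layout assumed by the two ggml_view_4d calls in the chunked
branch (the same packing the autoregressive branch already uses): the fused
GGML_OP_GATED_DELTA_NET result holds the per-token outputs first and the
updated recurrent state immediately after, which is why the state view
starts at byte offset ggml_row_size(result->type, S_v*H_v*n_tokens*n_seqs).
Below is a minimal standalone C++ sketch of that offset arithmetic; the
example dimensions and the F32 element size are illustrative assumptions,
not values taken from the patch:

    // layout_sketch.cpp - mirrors the packing of the fused GDN result
    #include <cstddef>
    #include <cstdio>

    int main() {
        // assumed example dimensions: value head size, number of value
        // heads, tokens in the batch, sequences
        const size_t S_v = 128, H_v = 8, n_tokens = 16, n_seqs = 2;
        const size_t elt = sizeof(float); // assuming an F32 result tensor

        // region 1: per-token outputs, viewed as [S_v, H_v, n_tokens, n_seqs]
        const size_t n_out = S_v * H_v * n_tokens * n_seqs;

        // region 2: updated state, viewed as [S_v, S_v, H_v, n_seqs]; it
        // starts right after the outputs - this byte offset is the last
        // argument of the second ggml_view_4d call in the patch
        const size_t state_offs = n_out * elt;
        const size_t n_state    = S_v * S_v * H_v * n_seqs;

        printf("outputs: %zu elements at byte offset 0\n", n_out);
        printf("state  : %zu elements at byte offset %zu\n", n_state, state_offs);
        return 0;
    }

For non-quantized types such as F32, ggml_row_size(type, n) is simply
n*ggml_type_size(type), so the printed state offset matches the view offset
used above.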