From c389dc9e8a8651de347b9936a9276044b790f9df Mon Sep 17 00:00:00 2001
From: Molly Sophia
Date: Fri, 9 Jan 2026 21:02:18 +0800
Subject: [PATCH] models: rwkv7: use `build_ffn`

Signed-off-by: Molly Sophia
---
 src/models/rwkv7-base.cpp | 16 ++++++++++++----
 src/models/rwkv7.cpp      |  2 +-
 2 files changed, 13 insertions(+), 5 deletions(-)

diff --git a/src/models/rwkv7-base.cpp b/src/models/rwkv7-base.cpp
index 09bd944b7d..d975ced1f0 100644
--- a/src/models/rwkv7-base.cpp
+++ b/src/models/rwkv7-base.cpp
@@ -7,16 +7,24 @@ llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_
 ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer,
                                                             ggml_tensor * cur,
                                                             ggml_tensor * x_prev,
-                                                            llm_arch arch) const {
+                                                            llm_arch arch,
+                                                            int il) const {
     ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
     switch (arch) {
         case LLM_ARCH_RWKV7:
             {
                 ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
 
-                ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk)));
-
-                cur = build_lora_mm(layer->channel_mix_value, k);
+                cur = build_ffn(
+                        xk,
+                        layer->channel_mix_key,   nullptr, nullptr, // up
+                        nullptr,                  nullptr, nullptr, // gate
+                        layer->channel_mix_value, nullptr, nullptr, // down
+                        nullptr,
+                        LLM_FFN_RELU_SQR,
+                        LLM_FFN_SEQ,
+                        il
+                );
             } break;
         default:
diff --git a/src/models/rwkv7.cpp b/src/models/rwkv7.cpp
index 5caf6553df..4faafa9d64 100644
--- a/src/models/rwkv7.cpp
+++ b/src/models/rwkv7.cpp
@@ -66,7 +66,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para
             ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
             x_prev   = ggml_get_rows(ctx0, x_prev,   inp_out_ids);
         }
 
-        cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
+        cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7, il);
         cur = ggml_add(ctx0, cur, ffn_inp);
 
         cur = build_cvec(cur, il);