models: rwkv7: use `build_ffn`

Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
Molly Sophia 2026-01-09 21:02:18 +08:00
parent f7b238d8ef
commit c389dc9e8a
2 changed files with 13 additions and 5 deletions

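Replace the hand-rolled channel-mix FFN (build_lora_mm key -> relu^2 -> build_lora_mm value) with the shared build_ffn graph helper, using LLM_FFN_RELU_SQR and no gate (LLM_FFN_SEQ). Since build_ffn takes the layer index, il is threaded through build_rwkv7_channel_mix.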

@@ -7,16 +7,24 @@ llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_
 ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer,
         ggml_tensor * cur,
         ggml_tensor * x_prev,
-        llm_arch arch) const {
+        llm_arch arch,
+        int il) const {
     ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
     switch (arch) {
         case LLM_ARCH_RWKV7:
             {
                 ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
-                ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk)));
-                cur = build_lora_mm(layer->channel_mix_value, k);
+                cur = build_ffn(
+                        xk,
+                        layer->channel_mix_key, nullptr, nullptr, // up
+                        nullptr, nullptr, nullptr, // gate
+                        layer->channel_mix_value, nullptr, nullptr, // down
+                        nullptr,
+                        LLM_FFN_RELU_SQR,
+                        LLM_FFN_SEQ,
+                        il
+                    );
             }
             break;
         default:

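For reference: with the gate slots left null and LLM_FFN_SEQ, the LLM_FFN_RELU_SQR path of build_ffn should reduce to exactly the sequence the removed lines spelled out by hand. A minimal sketch of that reduction (not part of the commit; build_lora_mm applies the weight plus any active LoRA adapter):

    // What the removed lines computed, and what the new build_ffn call
    // is expected to reduce to when no gate tensors are given:
    ggml_tensor * up  = build_lora_mm(layer->channel_mix_key, xk);  // up projection
    ggml_tensor * act = ggml_sqr(ctx0, ggml_relu(ctx0, up));        // LLM_FFN_RELU_SQR
    cur = build_lora_mm(layer->channel_mix_value, act);             // down projection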

@@ -66,7 +66,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para
             ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
             x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
         }
-        cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
+        cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7, il);
         cur = ggml_add(ctx0, cur, ffn_inp);
         cur = build_cvec(cur, il);
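
The extra il argument only forwards the layer index, which build_ffn uses to tag intermediate tensors via the graph callback. For orientation, the nullptr slots at the call site line up with the helper's parameter groups; a sketch of the declaration as assumed here (names approximate, based on recent llama.cpp):

    ggml_tensor * build_ffn(
            ggml_tensor * cur,                                              // input
            ggml_tensor * up,   ggml_tensor * up_b,   ggml_tensor * up_s,   // up proj, bias, scale
            ggml_tensor * gate, ggml_tensor * gate_b, ggml_tensor * gate_s, // optional gate
            ggml_tensor * down, ggml_tensor * down_b, ggml_tensor * down_s, // down proj
            ggml_tensor * act_scales,                                       // optional activation scales
            llm_ffn_op_type   type_op,                                      // e.g. LLM_FFN_RELU_SQR
            llm_ffn_gate_type type_gate,                                    // LLM_FFN_SEQ / LLM_FFN_PAR
            int il) const;                                                  // layer index, used in cb() names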