models: rwkv7: use `build_ffn`
Signed-off-by: Molly Sophia <mollysophia379@gmail.com>
This commit is contained in:
parent
f7b238d8ef
commit
c389dc9e8a
|
|
@ -7,16 +7,24 @@ llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_
|
|||
ggml_tensor * llm_build_rwkv7_base::build_rwkv7_channel_mix(const llama_layer * layer,
|
||||
ggml_tensor * cur,
|
||||
ggml_tensor * x_prev,
|
||||
llm_arch arch) const {
|
||||
llm_arch arch,
|
||||
int il) const {
|
||||
ggml_tensor * sx = ggml_sub(ctx0, x_prev, cur);
|
||||
switch (arch) {
|
||||
case LLM_ARCH_RWKV7:
|
||||
{
|
||||
ggml_tensor * xk = ggml_add(ctx0, ggml_mul(ctx0, sx, layer->channel_mix_lerp_k), cur);
|
||||
|
||||
ggml_tensor * k = ggml_sqr(ctx0, ggml_relu(ctx0, build_lora_mm(layer->channel_mix_key, xk)));
|
||||
|
||||
cur = build_lora_mm(layer->channel_mix_value, k);
|
||||
cur = build_ffn(
|
||||
xk,
|
||||
layer->channel_mix_key, nullptr, nullptr, // up
|
||||
nullptr, nullptr, nullptr, // gate
|
||||
layer->channel_mix_value, nullptr, nullptr, // down
|
||||
nullptr,
|
||||
LLM_FFN_RELU_SQR,
|
||||
LLM_FFN_SEQ,
|
||||
il
|
||||
);
|
||||
}
|
||||
break;
|
||||
default:
|
||||
|
|
|
|||
|
|
@ -66,7 +66,7 @@ llm_build_rwkv7::llm_build_rwkv7(const llama_model & model, const llm_graph_para
|
|||
ffn_norm = ggml_get_rows(ctx0, ffn_norm, inp_out_ids);
|
||||
x_prev = ggml_get_rows(ctx0, x_prev, inp_out_ids);
|
||||
}
|
||||
cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7);
|
||||
cur = build_rwkv7_channel_mix(layer, ffn_norm, x_prev, LLM_ARCH_RWKV7, il);
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
|
|
|
|||
Loading…
Reference in New Issue