refactor softplus fn

Xuan Son Nguyen 2025-09-20 12:17:15 +07:00
parent 46110e0630
commit f643b957f4
1 changed file with 10 additions and 8 deletions


@@ -19181,16 +19181,10 @@ private:
         GGML_ASSERT(ggml_nelements(beta) + ggml_nelements(alpha) == ggml_nelements(mixed_ba));
 
-        // Softplus would be nice...
-        ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); // a + dt_bias
-        ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
-        ggml_tensor * one_tensor = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1); // Create scalar tensor
-        ggml_exp(ctx0, one_tensor); // make it a 1
-        ggml_tensor * one_plus_exp = ggml_add1(ctx0, alpha_exp, one_tensor); // 1 + exp(a + dt_bias)
-        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
+        ggml_tensor * alpha_softplus = softplus(alpha, model.layers[il].ssm_dt);
         ggml_tensor * A_log_exp = ggml_exp(ctx0, model.layers[il].ssm_a); // A_log.exp()
         ggml_tensor * gate_scaled = ggml_mul(ctx0, alpha_softplus, A_log_exp); // A_log.exp() * softplus
-        ggml_tensor * gate = ggml_neg(ctx0, gate_scaled); // - (A_log.exp() * softplus)
+        ggml_tensor * gate = ggml_scale(ctx0, gate_scaled, -1.0f); // - (A_log.exp() * softplus)
 
         // Get convolution weights and bias
         ggml_tensor * conv_weight = model.layers[il].ssm_conv1d;
@@ -19324,6 +19318,14 @@ private:
         return cur;
     }
 
+    ggml_tensor * softplus(ggml_tensor * alpha, ggml_tensor * dt_bias) {
+        ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, dt_bias); // a + dt_bias
+        ggml_tensor * alpha_exp = ggml_exp(ctx0, alpha_biased); // exp(a + dt_bias)
+        ggml_tensor * one_plus_exp = ggml_scale_bias(ctx0, alpha_exp, 1.0f, 1.0f); // 1 + exp(a + dt_bias)
+        ggml_tensor * alpha_softplus = ggml_log(ctx0, one_plus_exp); // log(1 + exp(...))
+        return alpha_softplus;
+    }
+
 };
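
For reference, a minimal scalar sketch (plain C++, not part of the commit, with illustrative values) of the math the refactored graph computes: the new softplus helper builds log(1 + exp(alpha + dt_bias)) out of ggml_add, ggml_exp, ggml_scale_bias, and ggml_log, and the caller then forms the gate as -(exp(A_log) * softplus(...)), now via ggml_scale with -1.0f instead of ggml_neg.

    #include <cmath>
    #include <cstdio>

    // Scalar reference for the softplus helper in this commit:
    // softplus(x) = log(1 + exp(x)); std::log1p improves accuracy for small exp(x).
    static float softplus_ref(float alpha, float dt_bias) {
        return std::log1p(std::exp(alpha + dt_bias));
    }

    int main() {
        // Illustrative values, not taken from any model checkpoint.
        const float alpha = 0.5f, dt_bias = -0.25f, A_log = 0.1f;
        const float sp = softplus_ref(alpha, dt_bias);
        // Mirrors the graph: gate = -(A_log.exp() * softplus)
        const float gate = -(std::exp(A_log) * sp);
        std::printf("softplus = %f, gate = %f\n", sp, gate);
        return 0;
    }

Note that the naive form log(1 + exp(x)) overflows for large positive x; numerically robust implementations typically compute max(x, 0) + log1p(exp(-|x|)) instead.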