llama-fit: fix regex pattern for gate_up tensors (#20910)

* llama-fit: fix regex pattern for gate_up tensors

* Apply suggestions from code review

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>

---------

Co-authored-by: Johannes Gäßler <johannesg@5d6.de>
This commit is contained in:
Aman Gupta 2026-03-24 12:57:57 +08:00 committed by GitHub
parent 312d870a89
commit e852eb4901
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 4 deletions

View File

@ -365,14 +365,14 @@ static void llama_params_fit_impl(
case LAYER_FRACTION_ATTN: {
static std::array<std::string, n_strings> patterns;
if (patterns[il].empty()) {
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|gate|down).*";
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|up|gate_up|down).*";
}
return patterns[il].c_str();
}
case LAYER_FRACTION_UP: {
static std::array<std::string, n_strings> patterns;
if (patterns[il].empty()) {
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|down).*";
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(gate|gate_up|down).*";
}
return patterns[il].c_str();
}
@ -386,7 +386,7 @@ static void llama_params_fit_impl(
case LAYER_FRACTION_MOE: {
static std::array<std::string, n_strings> patterns;
if (patterns[il].empty()) {
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate)_(ch|)exps";
patterns[il] = "blk\\." + std::to_string(il) + "\\.ffn_(up|down|gate_up|gate)_(ch|)exps";
}
return patterns[il].c_str();
}
@ -480,7 +480,7 @@ static void llama_params_fit_impl(
int64_t global_surplus_cpu_moe = 0;
if (hp_nex > 0) {
const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate)_(ch|)exps"; // matches all MoE tensors
const static std::string pattern_moe_all = "blk\\.\\d+\\.ffn_(up|down|gate_up|gate)_(ch|)exps"; // matches all MoE tensors
ggml_backend_buffer_type_t cpu_buft = ggml_backend_cpu_buffer_type();
tensor_buft_overrides[0] = {pattern_moe_all.c_str(), cpu_buft};
tensor_buft_overrides[1] = {nullptr, nullptr};