This commit is contained in:
Xuan Son Nguyen 2026-01-29 17:04:39 +01:00
parent 8f1c6c02d5
commit be101fc117
11 changed files with 132 additions and 1 deletions

View File

@ -10931,6 +10931,15 @@ class LongcatFlashModel(DeepseekV2Model):
self.hparams["num_key_value_heads"] = self.hparams["num_attention_heads"]
self.hparams["intermediate_size"] = self.hparams["ffn_hidden_size"]
self.hparams["moe_intermediate_size"] = self.hparams["expert_ffn_hidden_size"]
self.hparams["num_experts_per_tok"] = self.hparams["moe_topk"]
def set_gguf_parameters(self):
super().set_gguf_parameters()
zero_expert_num = self.hparams["zero_expert_num"]
zero_expert_type = self.hparams["zero_expert_type"]
assert(zero_expert_type == "identity")
self.gguf_writer.add_n_zero_experts(zero_expert_num)
def modify_tensors(self, data_torch, name, bid):
if bid is not None:

View File

@ -148,6 +148,7 @@ class Keys:
EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in"
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
N_ZERO_EXPERTS = "{arch}.n_zero_experts" # longcat-flash
class Attention:
HEAD_COUNT = "{arch}.attention.head_count"

View File

@ -1075,6 +1075,9 @@ class GGUFWriter:
def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)
def add_n_zero_experts(self, n: int) -> None:
self.add_uint32(Keys.LLM.N_ZERO_EXPERTS.format(arch=self.arch), n)
# for vision models
def add_clip_has_vision_encoder(self, value: bool) -> None:

View File

@ -120,6 +120,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_MIMO2, "mimo2" },
{ LLM_ARCH_LLAMA_EMBED, "llama-embed" },
{ LLM_ARCH_MAINCODER, "maincoder" },
{ LLM_ARCH_LONGCAT_FLASH, "longcat-flash" },
{ LLM_ARCH_UNKNOWN, "(unknown)" },
};
@ -191,6 +192,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
{ LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
{ LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
{ LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
{ LLM_KV_N_ZERO_EXPERTS, "%s.n_zero_experts" },
{ LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
{ LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@ -1475,6 +1477,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
LLM_TENSOR_FFN_UP_SHEXP,
};
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_LONGCAT_FLASH:
return {
LLM_TENSOR_TOKEN_EMBD,
LLM_TENSOR_OUTPUT_NORM,

View File

@ -124,6 +124,7 @@ enum llm_arch {
LLM_ARCH_MIMO2,
LLM_ARCH_LLAMA_EMBED,
LLM_ARCH_MAINCODER,
LLM_ARCH_LONGCAT_FLASH,
LLM_ARCH_UNKNOWN,
};
@ -195,6 +196,7 @@ enum llm_kv {
LLM_KV_EMBEDDING_SCALE,
LLM_KV_TOKEN_SHIFT_COUNT,
LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
LLM_KV_N_ZERO_EXPERTS,
LLM_KV_ATTENTION_HEAD_COUNT,
LLM_KV_ATTENTION_HEAD_COUNT_KV,

View File

@ -77,6 +77,7 @@ struct llama_hparams {
uint32_t n_expert_groups = 0;
uint32_t n_group_used = 0;
uint32_t n_group_experts = 0;
uint32_t n_zero_experts = 0;
float expert_group_scale = 0.05f;
float expert_weights_scale = 0.0f;

View File

@ -857,6 +857,8 @@ struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx
n_created++;
}
loaded_tensor_names.insert(name);
return tensor;
}
@ -886,11 +888,20 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte
n_created++;
loaded_tensor_names.insert(name);
return tensor;
}
void llama_model_loader::done_getting_tensors() const {
if (n_created != n_tensors) {
// for debugging
for (const auto & it : weights_map) {
const std::string & name = it.first;
if (loaded_tensor_names.find(name) == loaded_tensor_names.end()) {
LLAMA_LOG_DEBUG("%s: tensor '%s' was not created\n", __func__, name.c_str());
}
}
throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
}
}

View File

@ -10,6 +10,7 @@
#include <cstddef>
#include <map>
#include <set>
#include <stdexcept>
#include <unordered_map>
@ -94,6 +95,8 @@ struct llama_model_loader {
size_t size_data = 0;
std::vector<std::pair<size_t, size_t>> mmaps_used;
std::set<std::string> loaded_tensor_names; // for debugging
llama_model_loader(
const std::string & fname,
std::vector<std::string> & splits, // optional, only need if the split does not follow naming scheme

View File

@ -1695,6 +1695,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
}
} break;
case LLM_ARCH_DEEPSEEK2:
case LLM_ARCH_LONGCAT_FLASH:
{
// lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
@ -1733,6 +1734,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);
// (optional) n_zero_experts - used by longcat-flash
ml.get_key(LLM_KV_N_ZERO_EXPERTS, hparams.n_zero_experts, false);
hparams.f_attn_temp_offset = 0.0f;
switch (hparams.n_layer) {
@ -6971,6 +6975,83 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_LONGCAT_FLASH:
{
const bool is_mla = hparams.is_mla();
// note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();
const int64_t n_embd_head_qk_rope = hparams.n_rot;
const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;
const int64_t q_lora_rank = hparams.n_lora_q;
const int64_t kv_lora_rank = hparams.n_lora_kv;
const int64_t n_ff_exp = hparams.n_ff_exp;
const int64_t n_expert_full = n_expert + hparams.n_zero_experts;
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
// try to load output.weight, if not found, use token_embd (tied embeddings)
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
if (!output) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
}
if (!is_mla) { throw std::runtime_error("mla is required"); }
if (q_lora_rank <= 0) { throw std::runtime_error("q_lora_rank must be > 0"); }
if (n_expert == 0) { throw std::runtime_error("n_expert must be > 0"); }
if (n_expert_used == 0) { throw std::runtime_error("n_expert_used must be > 0"); }
// NOTE: large part of the code is copied from deepseek2
// main difference is that longcat has zero experts and not all layers are MoE
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);
layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);
layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0);
layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0);
layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0);
layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0);
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
// try to see if this is a dense or MoE layer
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert_full}, TENSOR_NOT_REQUIRED);
if (!layer.ffn_gate_inp) {
// dense
layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
} else {
// MoE
layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert_full}, TENSOR_NOT_REQUIRED);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
// shared experts
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -7311,7 +7392,7 @@ void llama_model::print_info() const {
LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
}
if (arch == LLM_ARCH_DEEPSEEK2) {
if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_LONGCAT_FLASH) {
LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@ -8268,6 +8349,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_MISTRAL3:
case LLM_ARCH_LLAMA_EMBED:
case LLM_ARCH_MAINCODER:
case LLM_ARCH_LONGCAT_FLASH:
return LLAMA_ROPE_TYPE_NORM;
// the pairs of head values are offset by n_rot/2

View File

@ -468,6 +468,17 @@ struct llm_tokenizer_bpe : llm_tokenizer {
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
};
break;
case LLAMA_VOCAB_PRE_TYPE_LONGCAT_FLASH:
regex_exprs = {
// original regex from tokenizer.json
// "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
"(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
// " ?[---‟ -。《》「」【】]+"
" ?[\uff01-\uff0f\uff1a-\uff5e'-\u201f\u3000-\u3002\u300a\u300b\u300c\u300d\u3010\u3011]+",
// "[一-龥ࠀ-一가-퟿]+"
"[\u4e00-\u9fa5\u0800-\u4e00\uac00-\ud7ff]+",
};
break;
default:
// default regex for BPE tokenization pre-processing
regex_exprs = {
@ -2041,6 +2052,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
tokenizer_pre == "solar-open") {
pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN;
clean_spaces = false;
} else if (
tokenizer_pre == "longcat-flash") {
pre_type = LLAMA_VOCAB_PRE_TYPE_LONGCAT_FLASH;
clean_spaces = false;
} else {
throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
}

View File

@ -54,6 +54,7 @@ enum llama_vocab_pre_type {
LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43,
LLAMA_VOCAB_PRE_TYPE_YOUTU = 44,
LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45,
LLAMA_VOCAB_PRE_TYPE_LONGCAT_FLASH = 46,
};
struct LLM_KV;