Merge 1fa084e733 into 2634ed207a
commit 6c6b0f079a
@@ -1257,6 +1257,9 @@ class TextModel(ModelBase):
        if chkhsh == "6c81ce329e0802883b22eabab0d3fa48357337ef1ecb45443828bf1f6254833f":
            # ref: https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B
            res = "exaone-moe"
        if chkhsh == "27d87c17bcffe5262a1e80b2ceb9a5e002c4f8a17d796fd5afac9180dd8bd96e":
            # ref: https://huggingface.co/meituan-longcat/LongCat-Flash-Chat
            res = "longcat-flash"

        if res is None:
            logger.warning("\n")
@@ -10915,6 +10918,61 @@ class SolarOpenModel(Glm4MoeModel):
        special_vocab.add_to_gguf(self.gguf_writer)


@ModelBase.register("LongcatFlashForCausalLM")
class LongcatFlashModel(DeepseekV2Model):
    model_arch = gguf.MODEL_ARCH.LONGCAT_FLASH

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # the model uses double blocks, so we need to adjust the block count
        self.block_count = self.hparams["num_layers"] * 2
        self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
        # compat with deepseek2 base class hparams
        self.hparams["num_hidden_layers"] = self.block_count
        self.hparams["num_key_value_heads"] = self.hparams["num_attention_heads"]
        self.hparams["intermediate_size"] = self.hparams["ffn_hidden_size"]
        self.hparams["moe_intermediate_size"] = self.hparams["expert_ffn_hidden_size"]
        self.hparams["num_experts_per_tok"] = self.hparams["moe_topk"]

    def set_gguf_parameters(self):
        super().set_gguf_parameters()

        zero_expert_num = self.hparams["zero_expert_num"]
        zero_expert_type = self.hparams["zero_expert_type"]
        assert zero_expert_type == "identity", "cpp implementation only supports 'identity' type"
        self.gguf_writer.add_n_zero_experts(zero_expert_num)

    def modify_tensors(self, data_torch, name, bid):
        if bid is not None:
            bid = bid * 2  # double block id

        # Rename rule examples:
        # model.layers.1.input_layernorm.0.weight --> model.layers.2.input_layernorm.weight
        # model.layers.1.input_layernorm.1.weight --> model.layers.3.input_layernorm.weight
        # model.layers.1.mlp.experts.0 --> model.layers.2.mlp.experts.0 (special case for experts)

        name = name.replace('.mlps.', '.mlp.')
        name = name.replace('.router.classifier.', '.gate.')
        name = name.replace('.router.e_score_correction_bias', '.e_score_correction_bias')

        # handle sub-block remapping
        match = re.match(r'.*\.(\d+)\.([a-z_\.]+)\.(\d+)\..*', name)
        if match and ".mlp.experts." not in name:
            # convert block id from N.(name).M to (2N+M).(name)
            N = int(match.group(1))
            middle = match.group(2)
            M = int(match.group(3))
            assert N * 2 == bid
            new_bid = N * 2 + M
            new_name = re.sub(r'\.(\d+)\.([a-z_\.]+)\.(\d+)\.', f'.{new_bid}.{middle}.', name)
            yield from super().modify_tensors(data_torch, new_name, new_bid)
        else:
            # correct the block id inside the name (fix for expert tensors)
            if bid is not None:
                name = name.replace(f'.{bid // 2}.', f'.{bid}.', 1)
            yield from super().modify_tensors(data_torch, name, bid)


###### CONVERSION LOGIC ######
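A standalone sketch of the sub-block remapping that modify_tensors performs above (not part of the patch): the tensor names and bid values below are hypothetical inputs chosen for illustration, while the regex and the N * 2 + M arithmetic mirror the converter code.

import re

def remap(name: str, bid: int) -> tuple[str, int]:
    # bid is the block id after the converter has already doubled it (original layer * 2)
    m = re.match(r'.*\.(\d+)\.([a-z_\.]+)\.(\d+)\..*', name)
    if m and ".mlp.experts." not in name:
        N, middle, M = int(m.group(1)), m.group(2), int(m.group(3))
        new_bid = N * 2 + M  # sub-block 0 stays at the even slot, sub-block 1 moves to the odd slot
        return re.sub(r'\.(\d+)\.([a-z_\.]+)\.(\d+)\.', f'.{new_bid}.{middle}.', name), new_bid
    # expert tensors keep their expert index; only the layer number is rewritten
    return name.replace(f'.{bid // 2}.', f'.{bid}.', 1), bid

print(remap("model.layers.0.input_layernorm.0.weight", 0))      # ('model.layers.0.input_layernorm.weight', 0)
print(remap("model.layers.0.input_layernorm.1.weight", 0))      # ('model.layers.1.input_layernorm.weight', 1)
print(remap("model.layers.1.mlp.experts.0.up_proj.weight", 2))  # ('model.layers.2.mlp.experts.0.up_proj.weight', 2)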
@@ -148,6 +148,7 @@ models = [
    {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
    {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
    {"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
    {"name": "longcat-flash", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/meituan-longcat/LongCat-Flash-Chat", },
]

# some models are known to be broken upstream, so we will skip them as exceptions
@@ -148,6 +148,7 @@ class Keys:
        EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
        DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in"
        DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
        N_ZERO_EXPERTS = "{arch}.n_zero_experts" # longcat-flash

    class Attention:
        HEAD_COUNT = "{arch}.attention.head_count"
@@ -459,6 +460,7 @@ class MODEL_ARCH(IntEnum):
    MIMO2 = auto()
    LLAMA_EMBED = auto()
    MAINCODER = auto()
    LONGCAT_FLASH = auto()


class VISION_PROJECTOR_TYPE(IntEnum):
@@ -880,6 +882,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
    MODEL_ARCH.MIMO2: "mimo2",
    MODEL_ARCH.LLAMA_EMBED: "llama-embed",
    MODEL_ARCH.MAINCODER: "maincoder",
    MODEL_ARCH.LONGCAT_FLASH: "longcat-flash",
}

VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
@@ -3377,6 +3380,36 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
    ],
    MODEL_ARCH.LONGCAT_FLASH: [
        MODEL_TENSOR.TOKEN_EMBD,
        MODEL_TENSOR.OUTPUT_NORM,
        MODEL_TENSOR.OUTPUT,
        MODEL_TENSOR.ROPE_FREQS,
        MODEL_TENSOR.ATTN_NORM,
        MODEL_TENSOR.ATTN_Q,
        MODEL_TENSOR.ATTN_Q_A,
        MODEL_TENSOR.ATTN_Q_B,
        MODEL_TENSOR.ATTN_KV_A_MQA,
        MODEL_TENSOR.ATTN_KV_B,
        MODEL_TENSOR.ATTN_K_B,
        MODEL_TENSOR.ATTN_V_B,
        MODEL_TENSOR.ATTN_Q_A_NORM,
        MODEL_TENSOR.ATTN_KV_A_NORM,
        MODEL_TENSOR.ATTN_OUT,
        MODEL_TENSOR.ATTN_ROT_EMBD,
        MODEL_TENSOR.FFN_GATE_INP,
        MODEL_TENSOR.FFN_NORM,
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_GATE_EXP,
        MODEL_TENSOR.FFN_DOWN_EXP,
        MODEL_TENSOR.FFN_UP_EXP,
        MODEL_TENSOR.FFN_GATE_SHEXP,
        MODEL_TENSOR.FFN_DOWN_SHEXP,
        MODEL_TENSOR.FFN_UP_SHEXP,
        MODEL_TENSOR.FFN_EXP_PROBS_B,
    ],
    # TODO
}
@@ -1075,6 +1075,9 @@ class GGUFWriter:
    def add_classifier_output_labels(self, labels: Sequence[str]) -> None:
        self.add_array(Keys.Classifier.OUTPUT_LABELS.format(arch=self.arch), labels)

    def add_n_zero_experts(self, n: int) -> None:
        self.add_uint32(Keys.LLM.N_ZERO_EXPERTS.format(arch=self.arch), n)

    # for vision models

    def add_clip_has_vision_encoder(self, value: bool) -> None:
@@ -89,6 +89,7 @@ add_library(llama
            models/llada.cpp
            models/llama-iswa.cpp
            models/llama.cpp
            models/longcat-flash.cpp
            models/maincoder.cpp
            models/mamba.cpp
            models/mimo2-iswa.cpp
@@ -120,6 +120,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
    { LLM_ARCH_MIMO2, "mimo2" },
    { LLM_ARCH_LLAMA_EMBED, "llama-embed" },
    { LLM_ARCH_MAINCODER, "maincoder" },
    { LLM_ARCH_LONGCAT_FLASH, "longcat-flash" },
    { LLM_ARCH_UNKNOWN, "(unknown)" },
};
@@ -191,6 +192,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
    { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" },
    { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" },
    { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" },
    { LLM_KV_N_ZERO_EXPERTS, "%s.n_zero_experts" },

    { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" },
    { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" },
@@ -1475,6 +1477,7 @@ static std::set<llm_tensor> llm_get_tensor_names(llm_arch arch) {
                LLM_TENSOR_FFN_UP_SHEXP,
            };
        case LLM_ARCH_DEEPSEEK2:
        case LLM_ARCH_LONGCAT_FLASH:
            return {
                LLM_TENSOR_TOKEN_EMBD,
                LLM_TENSOR_OUTPUT_NORM,
@@ -124,6 +124,7 @@ enum llm_arch {
    LLM_ARCH_MIMO2,
    LLM_ARCH_LLAMA_EMBED,
    LLM_ARCH_MAINCODER,
    LLM_ARCH_LONGCAT_FLASH,
    LLM_ARCH_UNKNOWN,
};
@@ -195,6 +196,7 @@ enum llm_kv {
    LLM_KV_EMBEDDING_SCALE,
    LLM_KV_TOKEN_SHIFT_COUNT,
    LLM_KV_INTERLEAVE_MOE_LAYER_STEP,
    LLM_KV_N_ZERO_EXPERTS,

    LLM_KV_ATTENTION_HEAD_COUNT,
    LLM_KV_ATTENTION_HEAD_COUNT_KV,
@@ -1114,6 +1114,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
    const int64_t n_tokens = cur->ne[1];
    const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN

    // longcat-flash uses n_zero_experts
    const int64_t n_probs = n_expert + hparams.n_zero_experts;

    ggml_tensor * logits = nullptr;

    if (probs_in == nullptr) {
@@ -1169,7 +1172,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
    // select top n_group_used expert groups
    // https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/e815299b0bcbac849fa540c768ef21845365c9eb/modeling_deepseek.py#L440-L457
    if (hparams.n_expert_groups > 1 && n_tokens > 0) {
-       const int64_t n_exp_per_group = n_expert / hparams.n_expert_groups;
+       const int64_t n_exp_per_group = n_probs / hparams.n_expert_groups;

        // organize experts into n_expert_groups
        ggml_tensor * selection_groups = ggml_reshape_3d(ctx0, selection_probs, n_exp_per_group, hparams.n_expert_groups, n_tokens); // [n_exp_per_group, n_expert_groups, n_tokens]
@@ -1187,7 +1190,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
        // mask out the other groups
-       selection_probs = ggml_get_rows(ctx0, selection_groups, expert_groups); // [n_exp_per_group, n_group_used, n_tokens]
+       selection_probs = ggml_set_rows(ctx0, ggml_fill(ctx0, selection_groups, -INFINITY), selection_probs, expert_groups); // [n_exp_per_group, n_expert_groups, n_tokens]
-       selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_expert, n_tokens); // [n_expert, n_tokens]
+       selection_probs = ggml_reshape_2d(ctx0, selection_probs, n_probs, n_tokens); // [n_probs, n_tokens]
        cb(selection_probs, "ffn_moe_probs_masked", il);
    }
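For context, the hunks around here only widen the routed probability matrix from n_expert to n_probs rows; the group-limited selection itself is the DeepSeek-V3 scheme referenced by the URL above. A rough numpy sketch of that selection, with made-up sizes and a simplified per-group score (the reference uses the sum of each group's top-2 probabilities):

import numpy as np

rng = np.random.default_rng(0)
n_probs, n_expert_groups, n_group_used, n_tokens = 8, 4, 2, 1   # hypothetical sizes
probs = rng.random((n_probs, n_tokens))

n_exp_per_group = n_probs // n_expert_groups
groups = probs.reshape(n_expert_groups, n_exp_per_group, n_tokens)   # organize experts into groups
group_scores = groups.max(axis=1)                                    # simplified group score
keep = np.argsort(-group_scores[:, 0])[:n_group_used]                # top n_group_used groups for token 0

masked = np.full_like(groups, -np.inf)
masked[keep] = groups[keep]                                          # mask out the other groups
selection_probs = masked.reshape(n_probs, n_tokens)                  # back to [n_probs, n_tokens]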
@@ -1201,6 +1204,12 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
        selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
        probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);

    } else if (arch == LLM_ARCH_LONGCAT_FLASH && hparams.n_zero_experts > 0) {
        ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
        // TODO (hard): how to implement zero-computation experts here?
        probs = ggml_reshape_3d(ctx0, probs, 1, n_probs, n_tokens);

    } else {
        probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
    }
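The TODO above is about folding zero-computation experts into the fused MoE kernel. Conceptually they behave as in the sketch below, which follows the "identity" zero_expert_type asserted by the converter earlier in this diff: the router scores n_expert + n_zero_experts slots, and a selected slot beyond n_expert contributes the input itself instead of running an FFN. The sigmoid gating and the plain weighted sum are illustrative assumptions, not the llama.cpp implementation.

import numpy as np

def moe_with_zero_experts(x, router_logits, n_expert, top_k, experts):
    # router_logits has n_expert + n_zero_experts entries; an index >= n_expert is a "zero" expert
    scores = 1.0 / (1.0 + np.exp(-router_logits))     # assumed sigmoid gating
    selected = np.argsort(-scores)[:top_k]
    out = np.zeros_like(x)
    for idx in selected:
        if idx < n_expert:
            out += scores[idx] * experts[idx](x)      # regular FFN expert
        else:
            out += scores[idx] * x                    # zero-computation expert: identity, no FFN work
    return out

# tiny usage example with made-up sizes: 4 real experts + 2 zero experts, top-2 routing
rng = np.random.default_rng(0)
experts = [lambda v, W=rng.standard_normal((3, 3)): W @ v for _ in range(4)]
x = rng.standard_normal(3)
print(moe_with_zero_experts(x, rng.standard_normal(6), n_expert=4, top_k=2, experts=experts))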
@@ -77,6 +77,7 @@ struct llama_hparams {
    uint32_t n_expert_groups = 0;
    uint32_t n_group_used = 0;
    uint32_t n_group_experts = 0;
    uint32_t n_zero_experts = 0;

    float expert_group_scale = 0.05f;
    float expert_weights_scale = 0.0f;
@@ -857,6 +857,8 @@ struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx
        n_created++;
    }

    loaded_tensor_names.insert(name);

    return tensor;
}
@@ -886,11 +888,20 @@ struct ggml_tensor * llama_model_loader::create_tensor_as_view(struct ggml_conte

    n_created++;

    loaded_tensor_names.insert(name);

    return tensor;
}

void llama_model_loader::done_getting_tensors() const {
    if (n_created != n_tensors) {
        // for debugging
        for (const auto & it : weights_map) {
            const std::string & name = it.first;
            if (loaded_tensor_names.find(name) == loaded_tensor_names.end()) {
                LLAMA_LOG_DEBUG("%s: tensor '%s' was not created\n", __func__, name.c_str());
            }
        }
        throw std::runtime_error(format("%s: wrong number of tensors; expected %d, got %d", __func__, n_tensors, n_created));
    }
}
@@ -10,6 +10,7 @@

#include <cstddef>
#include <map>
#include <set>
#include <stdexcept>
#include <unordered_map>
@@ -94,6 +95,8 @@ struct llama_model_loader {
    size_t size_data = 0;
    std::vector<std::pair<size_t, size_t>> mmaps_used;

    std::set<std::string> loaded_tensor_names; // for debugging

    llama_model_loader(
        const std::string & fname,
        std::vector<std::string> & splits, // optional, only needed if the split does not follow naming scheme
@@ -1695,6 +1695,7 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                }
            } break;
        case LLM_ARCH_DEEPSEEK2:
        case LLM_ARCH_LONGCAT_FLASH:
            {
                // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B
                const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26);
@@ -1733,6 +1734,9 @@ void llama_model::load_hparams(llama_model_loader & ml) {
                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.n_attn_temp_floor_scale, false);

                // (optional) n_zero_experts - used by longcat-flash
                ml.get_key(LLM_KV_N_ZERO_EXPERTS, hparams.n_zero_experts, false);

                hparams.f_attn_temp_offset = 0.0f;

                switch (hparams.n_layer) {
@@ -6971,6 +6975,88 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                    layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                }
            } break;
        case LLM_ARCH_LONGCAT_FLASH:
            {
                const bool is_mla = hparams.is_mla();

                // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
                const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla();
                const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla();

                const int64_t n_embd_head_qk_rope = hparams.n_rot;
                const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope;

                const int64_t q_lora_rank = hparams.n_lora_q;
                const int64_t kv_lora_rank = hparams.n_lora_kv;

                const int64_t n_ff_exp = hparams.n_ff_exp;
                const int64_t n_expert_full = n_expert + hparams.n_zero_experts;

                tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);

                // output
                output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
                // try to load output.weight; if not found, use token_embd (tied embeddings)
                output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
                if (!output) {
                    output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                }

                if (!is_mla) { throw std::runtime_error("mla is required"); }
                if (q_lora_rank <= 0) { throw std::runtime_error("q_lora_rank must be > 0"); }
                if (n_expert == 0) { throw std::runtime_error("n_expert must be > 0"); }
                if (n_expert_used == 0) { throw std::runtime_error("n_expert_used must be > 0"); }
                // NOTE: a large part of this code is copied from deepseek2
                // the main difference is that longcat has zero experts and not all layers are MoE

                for (int i = 0; i < n_layer; ++i) {
                    auto & layer = layers[i];

                    layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
                    layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, 0);

                    layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, 0);

                    layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, 0);
                    layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, 0);

                    layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, 0);

                    layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, 0);
                    layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, 0);

                    layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, 0);

                    layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);

                    // try to see if this is a dense or MoE layer
                    layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert_full}, TENSOR_NOT_REQUIRED);

                    bool is_moe = (layer.ffn_gate_inp != nullptr);
                    if (is_moe && (i % 2 != 0)) {
                        throw std::runtime_error("MoE layers must be at even indices");
                    }

                    if (!is_moe) {
                        // dense
                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                    } else {
                        // MoE
                        layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert_full}, TENSOR_NOT_REQUIRED);

                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);

                        // shared experts
                        layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0);
                        layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0);
                        layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
                    }
                }
            } break;
        default:
            throw std::runtime_error("unknown architecture");
    }
@@ -7311,7 +7397,7 @@ void llama_model::print_info() const {
        LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale);
    }

-   if (arch == LLM_ARCH_DEEPSEEK2) {
+   if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_LONGCAT_FLASH) {
        LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead);
        LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q);
        LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv);
@@ -8086,6 +8172,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
            {
                llm = std::make_unique<llm_build_mimo2_iswa>(*this, params);
            } break;
        case LLM_ARCH_LONGCAT_FLASH:
            {
                llm = std::make_unique<llm_build_longcat_flash>(*this, params);
            } break;
        default:
            GGML_ABORT("fatal error");
    }
@@ -8268,6 +8358,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
        case LLM_ARCH_MISTRAL3:
        case LLM_ARCH_LLAMA_EMBED:
        case LLM_ARCH_MAINCODER:
        case LLM_ARCH_LONGCAT_FLASH:
            return LLAMA_ROPE_TYPE_NORM;

        // the pairs of head values are offset by n_rot/2
@@ -468,6 +468,17 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?(?:\\p{L}\\p{M}*(?: \\p{L}\\p{M}*)*)+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]|\\s+(?!\\S)|\\s+",
            };
            break;
        case LLAMA_VOCAB_PRE_TYPE_LONGCAT_FLASH:
            regex_exprs = {
                // original regex from tokenizer.json
                // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+"
                "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]?|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                // " ?[!-/:-~‘-‟ -。《》「」【】]+"
                " ?[\uff01-\uff0f\uff1a-\uff5e'-\u201f\u3000-\u3002\u300a\u300b\u300c\u300d\u3010\u3011]+",
                // "[一-龥ࠀ-一가-]+"
                "[\u4e00-\u9fa5\u0800-\u4e00\uac00-\ud7ff]+",
            };
            break;
        default:
            // default regex for BPE tokenization pre-processing
            regex_exprs = {
@@ -2041,6 +2052,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                tokenizer_pre == "solar-open") {
                pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN;
                clean_spaces = false;
            } else if (
                tokenizer_pre == "longcat-flash") {
                pre_type = LLAMA_VOCAB_PRE_TYPE_LONGCAT_FLASH;
                clean_spaces = false;
            } else {
                throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
            }
@@ -54,6 +54,7 @@ enum llama_vocab_pre_type {
    LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43,
    LLAMA_VOCAB_PRE_TYPE_YOUTU = 44,
    LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45,
    LLAMA_VOCAB_PRE_TYPE_LONGCAT_FLASH = 46,
};

struct LLM_KV;
@@ -0,0 +1,210 @@
#include "models.h"

llm_build_longcat_flash::llm_build_longcat_flash(const llama_model & model, const llm_graph_params & params) :
    llm_graph_context(params) {
    const bool is_mla = hparams.is_mla();

    // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA
    const int64_t n_embd_head_k = hparams.n_embd_head_k_mla();
    // const int64_t n_embd_head_v = hparams.n_embd_head_v_mla();

    const int64_t n_embd_head_qk_rope = hparams.n_rot;
    const int64_t n_embd_head_qk_nope = n_embd_head_k - n_embd_head_qk_rope;

    const uint32_t kv_lora_rank = hparams.n_lora_kv;

    // large part of the code is copied from deepseek2
    // we only use a subset of features here
    // TODO: dedup the code by abstracting common parts
    GGML_ASSERT(is_mla);
    GGML_ASSERT(kv_lora_rank > 0);

    // longcat-flash uses double attention + MLP, so n_layer must be even
    GGML_ASSERT(n_layer % 2 == 0);

    const float kq_scale = 1.0f / sqrtf(float(n_embd_head_k));

    ggml_tensor * cur;
    ggml_tensor * inpL;

    // {n_embd, n_tokens}
    inpL = build_inp_embd(model.tok_embd);

    // inp_pos - contains the positions
    ggml_tensor * inp_pos = build_inp_pos();

    auto * inp_attn_k = build_attn_inp_k(); // MLA-only

    ggml_tensor * inp_out_ids = build_inp_out_ids();

    for (int il = 0; il < n_layer; ++il) {
        ggml_tensor * inpSA = inpL;

        // norm
        cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "attn_norm", il);

        // self_attention
        {
            ggml_tensor * q = NULL;

            ///////// MLA implementation - exactly the same as deepseek2 /////////

            q = ggml_mul_mat(ctx0, model.layers[il].wq_a, cur);
            cb(q, "q", il);

            q = build_norm(q, model.layers[il].attn_q_a_norm, nullptr, LLM_NORM_RMS, il);
            cb(q, "q", il);

            q = ggml_mul_mat(ctx0, model.layers[il].wq_b, q);
            cb(q, "q", il);

            // split into {n_embd_head_qk_nope, n_head, n_tokens}
            ggml_tensor * q_nope =
                ggml_view_3d(ctx0, q, n_embd_head_qk_nope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k),
                             ggml_row_size(q->type, n_embd_head_k) * n_head, 0);
            cb(q_nope, "q_nope", il);

            // and {n_embd_head_qk_rope, n_head, n_tokens}
            ggml_tensor * q_pe = ggml_view_3d(
                ctx0, q, n_embd_head_qk_rope, n_head, n_tokens, ggml_row_size(q->type, n_embd_head_k),
                ggml_row_size(q->type, n_embd_head_k) * n_head, ggml_row_size(q->type, n_embd_head_qk_nope));
            cb(q_pe, "q_pe", il);

            ggml_tensor * kv_cmpr_pe = ggml_mul_mat(ctx0, model.layers[il].wkv_a_mqa, cur);
            cb(kv_cmpr_pe, "kv_cmpr_pe", il);

            // split into {kv_lora_rank, n_tokens}
            ggml_tensor * kv_cmpr =
                ggml_view_2d(ctx0, kv_cmpr_pe, kv_lora_rank, n_tokens,
                             ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope), 0);
            cb(kv_cmpr, "kv_cmpr", il);

            // and {n_embd_head_qk_rope, 1, n_tokens}
            ggml_tensor * k_pe = ggml_view_3d(ctx0, kv_cmpr_pe, n_embd_head_qk_rope, 1, n_tokens,
                                              ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
                                              ggml_row_size(kv_cmpr_pe->type, kv_lora_rank + n_embd_head_qk_rope),
                                              ggml_row_size(kv_cmpr_pe->type, kv_lora_rank));
            cb(k_pe, "k_pe", il);

            q_pe = ggml_rope_ext(ctx0, q_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                                 ext_factor, attn_factor, beta_fast, beta_slow);
            cb(q_pe, "q_pe", il);

            k_pe = ggml_rope_ext(ctx0, k_pe, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
                                 ext_factor, attn_factor, beta_fast, beta_slow);
            cb(k_pe, "k_pe", il);

            kv_cmpr = build_norm(kv_cmpr, model.layers[il].attn_kv_a_norm, nullptr, LLM_NORM_RMS, il);
            cb(kv_cmpr, "kv_cmpr", il);

            {
                // {n_embd_head_qk_nope, n_tokens, n_head}
                q_nope = ggml_permute(ctx0, q_nope, 0, 2, 1, 3);
                cb(q_nope, "q_nope_perm", il);

                // {n_embd_head_qk_nope, kv_lora_rank, n_head} x {n_embd_head_qk_nope, n_tokens, n_head}
                ggml_tensor * q_nope_absorbed = ggml_mul_mat(ctx0, model.layers[il].wk_b, q_nope);
                cb(q_nope_absorbed, "q_nope_absorbed", il);

                // {kv_lora_rank, n_head, n_tokens}
                q_nope_absorbed = ggml_permute(ctx0, q_nope_absorbed, 0, 2, 1, 3);
                cb(q_nope_absorbed, "q_nope_absorbed_perm", il);

                // {n_embd_head_qk_rope + kv_lora_rank, n_head, n_tokens}
                // note: rope must go first for in-place context shifting in build_rope_shift()
                ggml_tensor * Qcur = ggml_concat(ctx0, q_nope_absorbed, q_pe, 0);
                cb(Qcur, "Qcur", il);

                kv_cmpr = ggml_reshape_3d(ctx0, kv_cmpr, kv_lora_rank, 1, n_tokens);
                cb(kv_cmpr, "kv_cmpr_reshape", il);

                // {n_embd_head_qk_rope + kv_lora_rank, 1, n_tokens}
                ggml_tensor * Kcur = ggml_concat(ctx0, kv_cmpr, k_pe, 0);
                cb(Kcur, "Kcur", il);

                // {kv_lora_rank, 1, n_tokens}
                ggml_tensor * Vcur = kv_cmpr;
                cb(Vcur, "Vcur", il);

                // note: MLA with the absorption optimization converts into MQA (ie: GQA with 1 group)
                cur = build_attn(inp_attn_k,
                        model.layers[il].wo, NULL,
                        Qcur, Kcur, Vcur, nullptr, nullptr, model.layers[il].wv_b, kq_scale, il);
            }

            ///////// End of MLA implementation /////////
        }

        if (il == n_layer - 1 && inp_out_ids) {
            cur = ggml_get_rows(ctx0, cur, inp_out_ids);
            inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
        }

        ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
        cb(ffn_inp, "ffn_inp", il);

        cur = build_norm(ffn_inp, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il);
        cb(cur, "ffn_norm", il);

        bool is_moe = model.layers[il].ffn_gate_inp != nullptr;

        if (!is_moe) {
            cur = build_ffn(cur,
                    model.layers[il].ffn_up, NULL, NULL,
                    model.layers[il].ffn_gate, NULL, NULL,
                    model.layers[il].ffn_down, NULL, NULL,
                    NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
            cb(cur, "ffn_out", il);
        } else {
            // MoE branch
            ggml_tensor * moe_out = build_moe_ffn(cur,
                    model.layers[il].ffn_gate_inp,
                    model.layers[il].ffn_up_exps,
                    model.layers[il].ffn_gate_exps,
                    model.layers[il].ffn_down_exps,
                    model.layers[il].ffn_exp_probs_b,
                    n_expert, n_expert_used,
                    LLM_FFN_SILU, hparams.expert_weights_norm,
                    hparams.expert_weights_scale, hparams.expert_weights_scale,
                    (llama_expert_gating_func_type) hparams.expert_gating_func,
                    il);
            cb(moe_out, "ffn_moe_out", il);

            // FFN shared expert
            {
                ggml_tensor * ffn_shexp =
                    build_ffn(cur,
                        model.layers[il].ffn_up_shexp, NULL, NULL,
                        model.layers[il].ffn_gate_shexp, NULL, NULL,
                        model.layers[il].ffn_down_shexp, NULL, NULL,
                        NULL, LLM_FFN_SILU, LLM_FFN_PAR, il);
                cb(ffn_shexp, "ffn_shexp", il);

                cur = ggml_add(ctx0, moe_out, ffn_shexp);
                cb(cur, "ffn_out", il);
            }
        }
        cur = ggml_add(ctx0, cur, ffn_inp);

        cur = build_cvec(cur, il);
        cb(cur, "l_out", il);

        // input for next layer
        inpL = cur;
    }
    cur = inpL;

    cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);

    cb(cur, "result_norm", -1);
    res->t_embd = cur;

    // lm_head
    cur = ggml_mul_mat(ctx0, model.output, cur);

    cb(cur, "result_output", -1);
    res->t_logits = cur;

    ggml_build_forward_expand(gf, cur);
}
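The absorption note above is the core trick of this builder: instead of decompressing K to full head size with wk_b, the per-head wk_b is folded into the query, so every head attends over the same compressed kv_cmpr and the kernel effectively becomes MQA (the rope slice of Q/K is concatenated separately and is unaffected). A single-head numpy sketch of the equivalence, with made-up sizes:

import numpy as np

rng = np.random.default_rng(0)
d_nope, kv_lora_rank, n_tokens = 16, 8, 4                  # hypothetical sizes, one head only
wk_b = rng.standard_normal((d_nope, kv_lora_rank))         # per-head "decompression" matrix
q_nope = rng.standard_normal((d_nope, n_tokens))
kv_cmpr = rng.standard_normal((kv_lora_rank, n_tokens))    # compressed KV, shared by all heads

# naive MLA: decompress K to full head size, then score the queries against it
scores_naive = q_nope.T @ (wk_b @ kv_cmpr)                 # [n_tokens, n_tokens]

# absorbed MLA: fold wk_b into the query instead; K stays compressed (MQA over kv_cmpr)
q_absorbed = wk_b.T @ q_nope                               # [kv_lora_rank, n_tokens]
scores_absorbed = q_absorbed.T @ kv_cmpr                   # [n_tokens, n_tokens]

assert np.allclose(scores_naive, scores_absorbed)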
@@ -316,6 +316,10 @@ struct llm_build_llama_iswa : public llm_graph_context {
    llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_longcat_flash : public llm_graph_context {
    llm_build_longcat_flash(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_maincoder : public llm_graph_context {
    llm_build_maincoder(const llama_model & model, const llm_graph_params & params);
};