kimi linear convert_hf_to_gguf
parent 27baad43d5
commit 84f822c5a5
@@ -563,6 +563,10 @@ class ModelBase:
                        gguf.MODEL_TENSOR.A_ENC_EMBD_POS,
                        gguf.MODEL_TENSOR.ALTUP_CORRECT_COEF,
                        gguf.MODEL_TENSOR.ALTUP_PREDICT_COEF,
                        # Kimi KDA conv weights should be F32
                        gguf.MODEL_TENSOR.SSM_CONV1D_Q,
                        gguf.MODEL_TENSOR.SSM_CONV1D_K,
                        gguf.MODEL_TENSOR.SSM_CONV1D_V,
                    )
                )
                or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")
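For context, the tuple that this hunk extends feeds the check that keeps certain tensor types unquantized. A loose paraphrase of the shape of that check (not the verbatim upstream logic; it assumes a gguf-py build that already defines the new SSM_CONV1D_Q/K/V constants added alongside this change):

import gguf

# Tensor types that are always written as F32 regardless of the requested
# file type; the three SSM_CONV1D_* entries are the new Kimi KDA conv kernels.
FORCE_F32 = (
    gguf.MODEL_TENSOR.SSM_CONV1D_Q,
    gguf.MODEL_TENSOR.SSM_CONV1D_K,
    gguf.MODEL_TENSOR.SSM_CONV1D_V,
)

def keep_in_f32(tensor_type: gguf.MODEL_TENSOR, new_name: str) -> bool:
    # Paraphrase of the condition in the hunk above: force F32 for the listed
    # tensor types, or for anything that is not a weight / LoRA tensor.
    return tensor_type in FORCE_F32 or new_name[-7:] not in (".weight", ".lora_a", ".lora_b")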
@@ -4976,6 +4980,295 @@ class CodeShellModel(TextModel):
        self.gguf_writer.add_rope_scaling_factor(1.0)


@ModelBase.register("KimiLinearModel", "KimiLinearForCausalLM")
class KimiLinearModel(TextModel):
    """Kimi-Linear model with hybrid MLA+KDA architecture"""
    model_arch = gguf.MODEL_ARCH.KIMI_LINEAR

    _experts: list[dict[str, Tensor]] | None = None

    def set_gguf_parameters(self):
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])

        # Use find_hparam for context length; Kimi uses model_max_length
        n_ctx = self.find_hparam(["max_position_embeddings", "model_max_length", "n_ctx", "n_positions"], optional=True)
        if n_ctx is not None:
            self.gguf_writer.add_context_length(n_ctx)
        else:
            # Default to 4096 if not found
            logger.warning("No context length found in config, defaulting to 4096")
            self.gguf_writer.add_context_length(4096)

        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_file_type(self.ftype)

        # KDA & MLA params
        # Get ssm_d_conv from linear_attn_config.short_conv_kernel_size or ssm_d_conv
        linear_attn_config = self.hparams.get("linear_attn_config", {})
        ssm_d_conv = self.hparams.get("ssm_d_conv") or linear_attn_config.get("short_conv_kernel_size")
        if ssm_d_conv is not None:
            self.gguf_writer.add_ssm_conv_kernel(ssm_d_conv)

        # MLA params - use add_* methods that handle arch substitution
        # Support both HuggingFace naming (q_lora_rank, kv_lora_rank) and internal naming (n_lora_q, n_lora_kv)
        q_lora_rank = self.hparams.get("q_lora_rank", self.hparams.get("n_lora_q"))
        kv_lora_rank = self.hparams.get("kv_lora_rank", self.hparams.get("n_lora_kv"))

        if q_lora_rank is not None:
            self.gguf_writer.add_q_lora_rank(q_lora_rank)
        if kv_lora_rank is not None:
            self.gguf_writer.add_kv_lora_rank(kv_lora_rank)

        # MLA head dimensions
        # Support HuggingFace naming: qk_nope_head_dim, qk_rope_head_dim, v_head_dim
        qk_nope_head_dim = self.hparams.get("qk_nope_head_dim")
        qk_rope_head_dim = self.hparams.get("qk_rope_head_dim", self.hparams.get("n_rot"))
        v_head_dim = self.hparams.get("v_head_dim")

        # Calculate n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
        if "n_embd_head_k_mla" in self.hparams:
            self.gguf_writer.add_key_length_mla(self.hparams["n_embd_head_k_mla"])
        elif qk_nope_head_dim is not None and qk_rope_head_dim is not None:
            n_embd_head_k_mla = qk_nope_head_dim + qk_rope_head_dim
            self.gguf_writer.add_key_length_mla(n_embd_head_k_mla)

        # n_embd_head_v_mla = v_head_dim
        if "n_embd_head_v_mla" in self.hparams:
            self.gguf_writer.add_value_length_mla(self.hparams["n_embd_head_v_mla"])
        elif v_head_dim is not None:
            self.gguf_writer.add_value_length_mla(v_head_dim)

        # Rotation - use qk_rope_head_dim for Kimi
        rope_dim = self.hparams.get("qk_rope_head_dim") or self.hparams.get("n_rot")
        if rope_dim is not None:
            self.gguf_writer.add_rope_dimension_count(rope_dim)
        else:
            # Default to head_dim
            head_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
            self.gguf_writer.add_rope_dimension_count(head_dim)

        self.gguf_writer.add_rope_freq_base(self.hparams.get("rope_theta", 10000.0))

        # MoE params
        n_experts = self.hparams.get("num_local_experts", self.hparams.get("num_experts"))
        if n_experts is not None:
            self.gguf_writer.add_expert_count(n_experts)

        # Support both num_experts_per_tok and num_experts_per_token
        n_experts_used = self.hparams.get("num_experts_per_tok", self.hparams.get("num_experts_per_token"))
        if n_experts_used is not None:
            self.gguf_writer.add_expert_used_count(n_experts_used)

        # moe_intermediate_size (1024 for Kimi)
        moe_intermediate_size = self.hparams.get("moe_intermediate_size")
        if moe_intermediate_size is not None:
            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)

        # num_shared_experts (1 for Kimi)
        num_shared_experts = self.hparams.get("num_shared_experts")
        if num_shared_experts is not None:
            self.gguf_writer.add_expert_shared_count(num_shared_experts)

        # first_k_dense_replace (1 for Kimi - first layer uses dense MLP)
        first_k_dense_replace = self.hparams.get("first_k_dense_replace")
        if first_k_dense_replace is not None:
            self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)

        # Expert gating function (sigmoid for Kimi)
        moe_router_activation_func = self.hparams.get("moe_router_activation_func", "sigmoid")
        if moe_router_activation_func == "sigmoid":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
        elif moe_router_activation_func == "softmax":
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SOFTMAX)
        else:
            logger.warning(f"Unknown moe_router_activation_func: {moe_router_activation_func}, defaulting to sigmoid")
            self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)

        # Routed scaling factor (expert_weights_scale = 2.446 for Kimi)
        routed_scaling_factor = self.hparams.get("routed_scaling_factor")
        if routed_scaling_factor is not None:
            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
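As a quick sanity check on the MLA head-size bookkeeping above, here is the fallback arithmetic with illustrative numbers (the hparams below are hypothetical placeholders, not values taken from a released Kimi-Linear config):

# Hypothetical hparams, for illustration only.
hparams = {
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "v_head_dim": 128,
}

# Mirrors the fallback path used when n_embd_head_k_mla / n_embd_head_v_mla
# are absent from the config.
n_embd_head_k_mla = hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]
n_embd_head_v_mla = hparams["v_head_dim"]
rope_dim = hparams["qk_rope_head_dim"]

print(n_embd_head_k_mla, n_embd_head_v_mla, rope_dim)  # 192 128 64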
    def set_vocab(self):
        # Kimi uses a TikToken tokenizer - load it via transformers
        from transformers import AutoTokenizer

        dir_model = self.dir_model
        vocab_size = self.hparams["vocab_size"]

        logger.info(f"Loading TikToken tokenizer from {dir_model}")
        tokenizer = AutoTokenizer.from_pretrained(dir_model, trust_remote_code=True)

        tokens: list[str] = []
        toktypes: list[int] = []

        # Get the tokenizer pre string
        tokpre = self.get_vocab_base_pre(tokenizer)

        # Build the vocab from the tokenizer
        merges = []
        vocab = {}

        # TikToken stores its vocab in mergeable_ranks
        if hasattr(tokenizer, 'mergeable_ranks'):
            mergeable_ranks = tokenizer.mergeable_ranks
            for token, rank in mergeable_ranks.items():
                vocab[self._token_bytes_to_string(token)] = rank
                if len(token) == 1:
                    continue
                # Build merges
                merged = self._bpe(mergeable_ranks, token, max_rank=rank)
                if len(merged) == 2:
                    merges.append(' '.join(map(self._token_bytes_to_string, merged)))
        else:
            # Fallback: read the vocab directly
            vocab = {tok: idx for tok, idx in tokenizer.get_vocab().items()}

        # Get special tokens
        added_vocab = {}
        if hasattr(tokenizer, 'special_tokens'):
            added_vocab = tokenizer.special_tokens
        elif hasattr(tokenizer, 'added_tokens_encoder'):
            added_vocab = tokenizer.added_tokens_encoder

        # Combine the vocabs
        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in {**vocab, **added_vocab}.items()}

        for i in range(vocab_size):
            if i not in reverse_vocab:
                tokens.append(f"[PAD{i}]")
                toktypes.append(gguf.TokenType.UNUSED)
            elif added_vocab and i in added_vocab.values():
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.CONTROL)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_tokenizer_pre(tokpre)
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=False)
        special_vocab.merges = merges
        special_vocab.add_to_gguf(self.gguf_writer)
        logger.info(f"Loaded {len(tokens)} tokens, {len(merges)} merges")
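To make the token-type assignment above concrete, here is the same reverse-vocab and padding logic run on a toy vocabulary (tokens and ids are made up; plain strings stand in for gguf.TokenType):

# Toy data, not from any real tokenizer.
vocab = {"a": 0, "b": 1, "ab": 2}
added_vocab = {"<|im_end|>": 4}  # hypothetical special token

reverse_vocab = {id_: tok for tok, id_ in {**vocab, **added_vocab}.items()}

tokens, toktypes = [], []
for i in range(6):  # pretend vocab_size == 6
    if i not in reverse_vocab:
        tokens.append(f"[PAD{i}]")      # ids 3 and 5 have no token -> padded
        toktypes.append("UNUSED")
    elif added_vocab and i in added_vocab.values():
        tokens.append(reverse_vocab[i])
        toktypes.append("CONTROL")      # id 4 is the special token
    else:
        tokens.append(reverse_vocab[i])
        toktypes.append("NORMAL")

print(tokens)    # ['a', 'b', 'ab', '[PAD3]', '<|im_end|>', '[PAD5]']
print(toktypes)  # ['NORMAL', 'NORMAL', 'NORMAL', 'UNUSED', 'CONTROL', 'UNUSED']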
    @staticmethod
    def _token_bytes_to_string(b: bytes) -> str:
        """Convert token bytes to their string representation for the vocab."""
        return ''.join([chr(byte) if byte < 128 else f'<0x{byte:02X}>' for byte in b])
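For reference, the mapping behaves like this (shown standalone so it can be run without the class):

# Same mapping as _token_bytes_to_string: ASCII bytes pass through,
# everything else becomes an <0xNN> escape.
def token_bytes_to_string(b: bytes) -> str:
    return ''.join(chr(x) if x < 128 else f'<0x{x:02X}>' for x in b)

assert token_bytes_to_string(b"hi") == "hi"
assert token_bytes_to_string(b"\xe2\x96\x81a") == "<0xE2><0x96><0x81>a"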
    @staticmethod
    def _bpe(mergeable_ranks: dict[bytes, int], token: bytes, max_rank: int | None = None) -> list[bytes]:
        """Greedy BPE split of a token, used to reconstruct the merge list."""
        parts = [bytes([b]) for b in token]
        while True:
            min_idx = None
            min_rank = None
            for i, pair in enumerate(zip(parts[:-1], parts[1:])):
                rank = mergeable_ranks.get(pair[0] + pair[1])
                if rank is not None and (min_rank is None or rank < min_rank):
                    min_idx = i
                    min_rank = rank
            if min_rank is None or (max_rank is not None and min_rank >= max_rank):
                break
            parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
        return parts
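A small worked example of how set_vocab uses _bpe to recover merges (the ranks are a made-up toy vocabulary, and it assumes the KimiLinearModel class from this diff is importable):

# Toy mergeable_ranks, not Kimi's real vocabulary.
ranks = {b"a": 0, b"b": 1, b"c": 2, b"ab": 3, b"abc": 4}

# Passing max_rank equal to the token's own rank stops the merge loop one step
# early, so the two returned parts are exactly the pair whose merge creates it.
parts = KimiLinearModel._bpe(ranks, b"abc", max_rank=4)
assert parts == [b"ab", b"c"]   # recorded in set_vocab as the merge "ab c"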
    def prepare_tensors(self):
        super().prepare_tensors()
        if self._experts is not None:
            experts = [k for d in self._experts for k in d.keys()]
            if len(experts) > 0:
                raise ValueError(f"Unprocessed experts: {experts}")
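prepare_tensors only verifies that every per-expert tensor collected below has been merged; the merge itself, done in modify_tensors, stacks the per-expert w1/w2/w3 matrices into one 3D tensor per projection, roughly like this (sizes are illustrative, not the real Kimi-Linear dimensions):

import torch

n_experts, n_ff_exp, n_embd = 8, 1024, 2048

# One weight matrix per expert for a single projection (e.g. w1 / FFN_GATE_EXP).
per_expert = [torch.randn(n_ff_exp, n_embd) for _ in range(n_experts)]

merged = torch.stack(per_expert, dim=0)   # expert index becomes the leading dim
print(tuple(merged.shape))                # (8, 1024, 2048)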
    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
        logger.info(f"Processing {name}: shape before = {tuple(data_torch.shape)}")

        # Handle KDA conv1d weights.
        # HuggingFace/vLLM stores them as [d_inner, d_conv] (2D); memory layout: the conv step changes fastest.
        # llama.cpp expects ggml ne = [d_conv, 1, d_inner, 1]; memory layout: ne[0] = d_conv changes fastest.
        # GGUF reverses the numpy shape when writing, so numpy (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1].
        # The memory layouts match: both have the conv step (d_conv) changing fastest.
        if name.endswith((".q_conv1d.weight", ".k_conv1d.weight", ".v_conv1d.weight")):
            # HF shape: [d_inner, d_conv] e.g. [4096, 4]
            # Target numpy shape: (1, d_inner, 1, d_conv) -> ggml ne = [d_conv, 1, d_inner, 1]
            if data_torch.ndim == 2:
                d_inner, d_conv = data_torch.shape
                # Reshape to (1, d_inner, 1, d_conv) - memory layout preserved (d_conv fastest)
                data_torch = data_torch.reshape(1, d_inner, 1, d_conv)
                logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]")
            elif data_torch.ndim == 3:
                # Already 3D [d_inner, 1, d_conv] from an earlier unsqueeze
                d_inner, _, d_conv = data_torch.shape
                data_torch = data_torch.reshape(1, d_inner, 1, d_conv)
                logger.info(f"Reshaped conv1d weight {name}: [d_inner={d_inner}, 1, d_conv={d_conv}] -> numpy {tuple(data_torch.shape)} -> ggml ne=[{d_conv}, 1, {d_inner}, 1]")

        # Handle A_log: HF stores it as [1, 1, num_heads, 1] and llama.cpp expects ggml ne = [1, num_heads, 1, 1].
        # GGUF reverses the numpy shape: numpy (1, 1, num_heads, 1) -> ggml ne = [1, num_heads, 1, 1],
        # so no transformation is needed - the shapes already match after the GGUF reversal.
        if name.endswith(".A_log"):
            if data_torch.ndim == 4:
                logger.info(f"A_log {name}: numpy {tuple(data_torch.shape)} -> ggml ne={list(reversed(data_torch.shape))}")

        # Kimi-specific expert-score correction bias
        if name.endswith("block_sparse_moe.gate.e_score_correction_bias"):
            new_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_EXP_PROBS_B, bid)
            return [(new_name, data_torch)]

        # process the experts separately
        if name.find("block_sparse_moe.experts") != -1:
            n_experts = self.hparams.get("num_local_experts", self.hparams.get("num_experts"))
            assert bid is not None

            if self._experts is None:
                self._experts = [{} for _ in range(self.block_count)]

            self._experts[bid][name] = data_torch

            if len(self._experts[bid]) >= n_experts * 3:
                # merge the experts into a single 3d tensor
                tensors = []
                # w1: gate, w2: down, w3: up
                for wid, tname in [("w1", gguf.MODEL_TENSOR.FFN_GATE_EXP),
                                   ("w2", gguf.MODEL_TENSOR.FFN_DOWN_EXP),
                                   ("w3", gguf.MODEL_TENSOR.FFN_UP_EXP)]:
                    datas: list[Tensor] = []
                    for xid in range(n_experts):
                        ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid}.weight"
                        datas.append(self._experts[bid][ename])
                        del self._experts[bid][ename]

                    data_torch = torch.stack(datas, dim=0)
                    new_name = self.format_tensor_name(tname, bid)
                    tensors.append((new_name, data_torch))
                return tensors
            return []

        mapped_name = self.map_tensor_name(name)
        logger.info(f"Returning {mapped_name}: shape after = {tuple(data_torch.shape)}")
        return [(mapped_name, data_torch)]
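The conv1d shape handling is the subtle part of modify_tensors, so here is the reshape traced with the example sizes from the comments above (d_inner = 4096, d_conv = 4; treat them as illustrative rather than guaranteed checkpoint values):

import torch

d_inner, d_conv = 4096, 4
w = torch.randn(d_inner, d_conv)        # HF layout: [d_inner, d_conv]
w4 = w.reshape(1, d_inner, 1, d_conv)   # shape handed to the GGUF writer

# GGUF stores dimensions in reverse order, so ggml sees ne = [d_conv, 1, d_inner, 1].
# The reshape never moves data: d_conv stays the fastest-varying axis either way.
print(tuple(w4.shape))                  # (1, 4096, 1, 4)
print(list(reversed(w4.shape)))         # [4, 1, 4096, 1]  == ggml ne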
    def get_vocab_base(self) -> tuple[list[str], list[int], str]:
        # This method is not used when set_vocab is overridden,
        # but it is kept for completeness in case it is called elsewhere.
        logger.warning("get_vocab_base called, but set_vocab is already overridden")
        vocab_size = self.hparams.get("vocab_size", 100)
        tokens = [f"<token_{i}>" for i in range(vocab_size)]
        tokens[0] = "<unk>"
        tokens[1] = "<s>"
        tokens[2] = "</s>"
        toktypes = [gguf.TokenType.NORMAL] * vocab_size
        return tokens, toktypes, "gpt-2"


@ModelBase.register("InternLM2ForCausalLM")
class InternLM2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.INTERNLM2
@@ -283,6 +283,12 @@ struct llm_build_jamba : public llm_graph_context_mamba {
    llm_build_jamba(const llama_model & model, const llm_graph_params & params);
};

struct llm_build_kimi_linear : public llm_graph_context_mamba {
    llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params);
private:
    const llama_model & model;
};

struct llm_build_lfm2 : public llm_graph_context {
    const llama_model & model;