model : add GroveMoE support (#15510)
* add GroveMoE support * remove constexpr that fails on certain compilers * revert crude scalar div implementation, use cast * build_attn_inp_kv_unified -> build_attn_inp_kv * fix build_attn * re-apply ffn_exps regex changes
This commit is contained in:
parent
b05a9d650f
commit
835b2b915c
|
|
@ -738,7 +738,7 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count";
|
|||
// MoE utils
|
||||
//
|
||||
|
||||
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_exps";
|
||||
const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps";
|
||||
|
||||
static std::string llm_ffn_exps_block_regex(int idx) {
|
||||
return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX);
|
||||
|
|
|
|||
|
|
@ -7995,6 +7995,121 @@ class BailingMoeModel(TextModel):
|
|||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("GroveMoeForCausalLM", "modeling_grove_moe.GroveMoeForCausalLM")
|
||||
class GroveMoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.GROVEMOE
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if (n_experts := self.hparams.get("num_experts")) is not None:
|
||||
self.gguf_writer.add_expert_count(n_experts)
|
||||
if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
|
||||
self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
|
||||
logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}")
|
||||
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L299
|
||||
self.gguf_writer.add_expert_chunk_feed_forward_length(self.hparams.get("head_dim") or 128)
|
||||
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L298
|
||||
self.gguf_writer.add_experts_per_group(2)
|
||||
# FIXME?: Hardcoded https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L376
|
||||
self.gguf_writer.add_expert_group_scale(0.05)
|
||||
# YaRN is not enabled by default
|
||||
# To enable it, please refer to this guide: https://huggingface.co/Qwen/Qwen3-30B-A3B#processing-long-texts
|
||||
rope_scaling = self.hparams.get("rope_scaling") or {}
|
||||
if rope_scaling.get("rope_type", rope_scaling.get("type")) == "yarn" and "factor" in rope_scaling:
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_scaling["factor"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_scaling["original_max_position_embeddings"])
|
||||
|
||||
_experts: list[dict[str, Tensor]] | None = None
|
||||
_chunk_experts: list[dict[str, Tensor]] | None = None
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
if name.endswith(".expert_bias"):
|
||||
# FIXME?: Unused https://huggingface.co/inclusionAI/GroveMoE-Inst/blob/c4c69e5970d18907b5e6ddccdfd55176fe292df1/modeling_grove_moe.py#L303
|
||||
return []
|
||||
|
||||
# process the experts separately
|
||||
if name.find("chunk_experts") != -1:
|
||||
n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group
|
||||
assert bid is not None
|
||||
|
||||
if self._chunk_experts is None:
|
||||
self._chunk_experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._chunk_experts[bid][name] = data_torch
|
||||
|
||||
if len(self._chunk_experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.chunk_experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._chunk_experts[bid][ename])
|
||||
del self._chunk_experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.chunk_experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
elif name.find("experts") != -1:
|
||||
n_experts = self.hparams["num_experts"]
|
||||
assert bid is not None
|
||||
|
||||
if self._experts is None:
|
||||
self._experts = [{} for _ in range(self.block_count)]
|
||||
|
||||
self._experts[bid][name] = data_torch
|
||||
|
||||
if len(self._experts[bid]) >= n_experts * 3:
|
||||
tensors: list[tuple[str, Tensor]] = []
|
||||
|
||||
# merge the experts into a single 3d tensor
|
||||
for w_name in ["down_proj", "gate_proj", "up_proj"]:
|
||||
datas: list[Tensor] = []
|
||||
|
||||
for xid in range(n_experts):
|
||||
ename = f"model.layers.{bid}.mlp.experts.{xid}.{w_name}.weight"
|
||||
datas.append(self._experts[bid][ename])
|
||||
del self._experts[bid][ename]
|
||||
|
||||
data_torch = torch.stack(datas, dim=0)
|
||||
|
||||
merged_name = f"model.layers.{bid}.mlp.experts.{w_name}.weight"
|
||||
|
||||
new_name = self.map_tensor_name(merged_name)
|
||||
|
||||
tensors.append((new_name, data_torch))
|
||||
return tensors
|
||||
else:
|
||||
return []
|
||||
|
||||
return [(self.map_tensor_name(name), data_torch)]
|
||||
|
||||
def prepare_tensors(self):
|
||||
super().prepare_tensors()
|
||||
|
||||
if self._chunk_experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
chunk_experts = [k for d in self._chunk_experts for k in d.keys()]
|
||||
if len(chunk_experts) > 0:
|
||||
raise ValueError(f"Unprocessed adjugate experts: {chunk_experts}")
|
||||
|
||||
if self._experts is not None:
|
||||
# flatten `list[dict[str, Tensor]]` into `list[str]`
|
||||
experts = [k for d in self._experts for k in d.keys()]
|
||||
if len(experts) > 0:
|
||||
raise ValueError(f"Unprocessed experts: {experts}")
|
||||
|
||||
|
||||
@ModelBase.register("ChameleonForConditionalGeneration")
|
||||
@ModelBase.register("ChameleonForCausalLM") # obsolete
|
||||
class ChameleonModel(TextModel):
|
||||
|
|
|
|||
|
|
@ -96,6 +96,7 @@ class Keys:
|
|||
FEED_FORWARD_LENGTH = "{arch}.feed_forward_length"
|
||||
EXPERT_FEED_FORWARD_LENGTH = "{arch}.expert_feed_forward_length"
|
||||
EXPERT_SHARED_FEED_FORWARD_LENGTH = "{arch}.expert_shared_feed_forward_length"
|
||||
EXPERT_CHUNK_FEED_FORWARD_LENGTH = "{arch}.expert_chunk_feed_forward_length"
|
||||
USE_PARALLEL_RESIDUAL = "{arch}.use_parallel_residual"
|
||||
TENSOR_DATA_LAYOUT = "{arch}.tensor_data_layout"
|
||||
EXPERT_COUNT = "{arch}.expert_count"
|
||||
|
|
@ -104,6 +105,8 @@ class Keys:
|
|||
EXPERT_WEIGHTS_SCALE = "{arch}.expert_weights_scale"
|
||||
EXPERT_WEIGHTS_NORM = "{arch}.expert_weights_norm"
|
||||
EXPERT_GATING_FUNC = "{arch}.expert_gating_func"
|
||||
EXPERT_GROUP_SCALE = "{arch}.expert_group_scale"
|
||||
EXPERTS_PER_GROUP = "{arch}.experts_per_group"
|
||||
MOE_EVERY_N_LAYERS = "{arch}.moe_every_n_layers"
|
||||
NEXTN_PREDICT_LAYERS = "{arch}.nextn_predict_layers"
|
||||
POOLING_TYPE = "{arch}.pooling_type"
|
||||
|
|
@ -401,6 +404,7 @@ class MODEL_ARCH(IntEnum):
|
|||
LLADA = auto()
|
||||
LLADA_MOE = auto()
|
||||
SEED_OSS = auto()
|
||||
GROVEMOE = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
|
@ -450,6 +454,9 @@ class MODEL_TENSOR(IntEnum):
|
|||
FFN_GATE_SHEXP = auto()
|
||||
FFN_DOWN_SHEXP = auto()
|
||||
FFN_UP_SHEXP = auto()
|
||||
FFN_GATE_CHEXP = auto()
|
||||
FFN_DOWN_CHEXP = auto()
|
||||
FFN_UP_CHEXP = auto()
|
||||
FFN_EXP_PROBS_B = auto()
|
||||
ATTN_Q_NORM = auto()
|
||||
ATTN_K_NORM = auto()
|
||||
|
|
@ -738,6 +745,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.LLADA: "llada",
|
||||
MODEL_ARCH.LLADA_MOE: "llada-moe",
|
||||
MODEL_ARCH.SEED_OSS: "seed_oss",
|
||||
MODEL_ARCH.GROVEMOE: "grovemoe",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
|
@ -784,6 +792,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
|
|||
MODEL_TENSOR.FFN_GATE_SHEXP: "blk.{bid}.ffn_gate_shexp",
|
||||
MODEL_TENSOR.FFN_DOWN_SHEXP: "blk.{bid}.ffn_down_shexp",
|
||||
MODEL_TENSOR.FFN_UP_SHEXP: "blk.{bid}.ffn_up_shexp",
|
||||
MODEL_TENSOR.FFN_GATE_CHEXP: "blk.{bid}.ffn_gate_chexps",
|
||||
MODEL_TENSOR.FFN_DOWN_CHEXP: "blk.{bid}.ffn_down_chexps",
|
||||
MODEL_TENSOR.FFN_UP_CHEXP: "blk.{bid}.ffn_up_chexps",
|
||||
MODEL_TENSOR.FFN_ACT: "blk.{bid}.ffn",
|
||||
MODEL_TENSOR.FFN_NORM_EXP: "blk.{bid}.ffn_norm_exps",
|
||||
MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps",
|
||||
|
|
@ -2712,6 +2723,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
],
|
||||
MODEL_ARCH.GROVEMOE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_Q_NORM,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_K_NORM,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
MODEL_TENSOR.FFN_GATE_CHEXP,
|
||||
MODEL_TENSOR.FFN_DOWN_CHEXP,
|
||||
MODEL_TENSOR.FFN_UP_CHEXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -670,6 +670,9 @@ class GGUFWriter:
|
|||
def add_expert_shared_feed_forward_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_SHARED_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
def add_expert_chunk_feed_forward_length(self, length: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_CHUNK_FEED_FORWARD_LENGTH.format(arch=self.arch), length)
|
||||
|
||||
def add_parallel_residual(self, use: bool) -> None:
|
||||
self.add_bool(Keys.LLM.USE_PARALLEL_RESIDUAL.format(arch=self.arch), use)
|
||||
|
||||
|
|
@ -757,6 +760,12 @@ class GGUFWriter:
|
|||
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
|
||||
|
||||
def add_expert_group_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)
|
||||
|
||||
def add_experts_per_group(self, count: int) -> None:
|
||||
self.add_uint32(Keys.LLM.EXPERTS_PER_GROUP.format(arch=self.arch), count)
|
||||
|
||||
def add_moe_every_n_layers(self, value: int) -> None:
|
||||
self.add_uint32(Keys.LLM.MOE_EVERY_N_LAYERS.format(arch=self.arch), value)
|
||||
|
||||
|
|
|
|||
|
|
@ -427,6 +427,10 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_mlp.up_proj", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_CHEXP: (
|
||||
"model.layers.{bid}.mlp.chunk_experts.up_proj", # grovemoe
|
||||
),
|
||||
|
||||
# AWQ-activation gate
|
||||
MODEL_TENSOR.FFN_ACT: (
|
||||
"transformer.blocks.{bid}.ffn.act", # mpt
|
||||
|
|
@ -468,6 +472,10 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_GATE_CHEXP: (
|
||||
"model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe
|
||||
),
|
||||
|
||||
# Feed-forward down
|
||||
MODEL_TENSOR.FFN_DOWN: (
|
||||
"gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox
|
||||
|
|
@ -524,6 +532,10 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.mlp.shared_mlp.down_proj", # hunyuan
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_CHEXP: (
|
||||
"model.layers.{bid}.mlp.chunk_experts.down_proj", # grovemoe
|
||||
),
|
||||
|
||||
MODEL_TENSOR.ATTN_Q_NORM: (
|
||||
"language_model.encoder.layers.{bid}.self_attention.q_layernorm",
|
||||
"model.layers.{bid}.self_attn.q_layernorm", # persimmon
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_LLADA, "llada" },
|
||||
{ LLM_ARCH_LLADA_MOE, "llada-moe" },
|
||||
{ LLM_ARCH_SEED_OSS, "seed_oss" },
|
||||
{ LLM_ARCH_GROVEMOE, "grovemoe" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
|
@ -125,6 +126,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_FEED_FORWARD_LENGTH, "%s.feed_forward_length" },
|
||||
{ LLM_KV_EXPERT_FEED_FORWARD_LENGTH, "%s.expert_feed_forward_length" },
|
||||
{ LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, "%s.expert_shared_feed_forward_length" },
|
||||
{ LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, "%s.expert_chunk_feed_forward_length" },
|
||||
{ LLM_KV_USE_PARALLEL_RESIDUAL, "%s.use_parallel_residual" },
|
||||
{ LLM_KV_TENSOR_DATA_LAYOUT, "%s.tensor_data_layout" },
|
||||
{ LLM_KV_EXPERT_COUNT, "%s.expert_count" },
|
||||
|
|
@ -133,6 +135,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_EXPERT_WEIGHTS_SCALE, "%s.expert_weights_scale" },
|
||||
{ LLM_KV_EXPERT_WEIGHTS_NORM, "%s.expert_weights_norm" },
|
||||
{ LLM_KV_EXPERT_GATING_FUNC, "%s.expert_gating_func" },
|
||||
{ LLM_KV_EXPERT_GROUP_SCALE, "%s.expert_group_scale" },
|
||||
{ LLM_KV_EXPERTS_PER_GROUP, "%s.experts_per_group" },
|
||||
{ LLM_KV_MOE_EVERY_N_LAYERS, "%s.moe_every_n_layers" },
|
||||
{ LLM_KV_NEXTN_PREDICT_LAYERS, "%s.nextn_predict_layers" },
|
||||
{ LLM_KV_POOLING_TYPE, "%s.pooling_type" },
|
||||
|
|
@ -2186,6 +2190,29 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_GROVEMOE,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
{ LLM_TENSOR_FFN_GATE_CHEXPS, "blk.%d.ffn_gate_chexps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_CHEXPS, "blk.%d.ffn_down_chexps" },
|
||||
{ LLM_TENSOR_FFN_UP_CHEXPS, "blk.%d.ffn_up_chexps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
|
@ -2318,6 +2345,9 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
|
|||
{LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}},
|
||||
{LLM_TENSOR_FFN_EXP_PROBS_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_ADD}},
|
||||
// altup / laurel (gemma 3n)
|
||||
{LLM_TENSOR_PER_LAYER_TOKEN_EMBD, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_GET_ROWS}},
|
||||
|
|
|
|||
|
|
@ -102,6 +102,7 @@ enum llm_arch {
|
|||
LLM_ARCH_LLADA,
|
||||
LLM_ARCH_LLADA_MOE,
|
||||
LLM_ARCH_SEED_OSS,
|
||||
LLM_ARCH_GROVEMOE,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
@ -129,6 +130,7 @@ enum llm_kv {
|
|||
LLM_KV_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_EXPERT_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH,
|
||||
LLM_KV_USE_PARALLEL_RESIDUAL,
|
||||
LLM_KV_TENSOR_DATA_LAYOUT,
|
||||
LLM_KV_EXPERT_COUNT,
|
||||
|
|
@ -137,6 +139,8 @@ enum llm_kv {
|
|||
LLM_KV_EXPERT_WEIGHTS_SCALE,
|
||||
LLM_KV_EXPERT_WEIGHTS_NORM,
|
||||
LLM_KV_EXPERT_GATING_FUNC,
|
||||
LLM_KV_EXPERT_GROUP_SCALE,
|
||||
LLM_KV_EXPERTS_PER_GROUP,
|
||||
LLM_KV_MOE_EVERY_N_LAYERS,
|
||||
LLM_KV_NEXTN_PREDICT_LAYERS,
|
||||
LLM_KV_POOLING_TYPE,
|
||||
|
|
@ -301,6 +305,9 @@ enum llm_tensor {
|
|||
LLM_TENSOR_FFN_DOWN_SHEXP,
|
||||
LLM_TENSOR_FFN_GATE_SHEXP,
|
||||
LLM_TENSOR_FFN_UP_SHEXP,
|
||||
LLM_TENSOR_FFN_DOWN_CHEXPS,
|
||||
LLM_TENSOR_FFN_GATE_CHEXPS,
|
||||
LLM_TENSOR_FFN_UP_CHEXPS,
|
||||
LLM_TENSOR_FFN_EXP_PROBS_B,
|
||||
LLM_TENSOR_ATTN_Q_NORM,
|
||||
LLM_TENSOR_ATTN_K_NORM,
|
||||
|
|
|
|||
|
|
@ -923,13 +923,26 @@ ggml_tensor * llm_graph_context::build_moe_ffn(
|
|||
selection_probs = logits;
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_GROVEMOE) {
|
||||
selection_probs = ggml_sigmoid(ctx0, logits); // [n_expert, n_tokens]
|
||||
cb(selection_probs, "ffn_moe_probs_biased", il);
|
||||
}
|
||||
|
||||
// select experts
|
||||
ggml_tensor * selected_experts = ggml_top_k(ctx0, selection_probs, n_expert_used); // [n_expert_used, n_tokens]
|
||||
cb(selected_experts->src[0], "ffn_moe_argsort", il);
|
||||
cb(selected_experts, "ffn_moe_topk", il);
|
||||
|
||||
ggml_tensor * weights = ggml_get_rows(ctx0,
|
||||
ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens), selected_experts); // [1, n_expert_used, n_tokens]
|
||||
if (arch == LLM_ARCH_GROVEMOE && n_expert != hparams.n_expert) {
|
||||
// TODO: Use scalar div instead when/if implemented
|
||||
ggml_tensor * f_sel = ggml_cast(ctx0, selected_experts, GGML_TYPE_F32);
|
||||
selected_experts = ggml_cast(ctx0, ggml_scale(ctx0, f_sel, 1.0f / float(hparams.n_group_experts)), GGML_TYPE_I32);
|
||||
probs = ggml_reshape_3d(ctx0, probs, 1, hparams.n_expert, n_tokens);
|
||||
} else {
|
||||
probs = ggml_reshape_3d(ctx0, probs, 1, n_expert, n_tokens);
|
||||
}
|
||||
|
||||
ggml_tensor * weights = ggml_get_rows(ctx0, probs, selected_experts); // [1, n_expert_used, n_tokens]
|
||||
cb(weights, "ffn_moe_weights", il);
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -69,10 +69,13 @@ struct llama_hparams {
|
|||
uint32_t n_lora_kv = 0;
|
||||
uint32_t n_ff_exp = 0;
|
||||
uint32_t n_ff_shexp = 0;
|
||||
uint32_t n_ff_chexp = 0;
|
||||
uint32_t n_expert_shared = 0;
|
||||
uint32_t n_norm_groups = 0;
|
||||
uint32_t n_group_experts = 0;
|
||||
|
||||
float expert_weights_scale = 0.0;
|
||||
float expert_group_scale = 0.05f;
|
||||
float expert_weights_scale = 0.0f;
|
||||
bool expert_weights_norm = false;
|
||||
uint32_t expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_NONE;
|
||||
uint32_t moe_every_n_layers = 0;
|
||||
|
|
|
|||
|
|
@ -2009,6 +2009,19 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GROVEMOE:
|
||||
{
|
||||
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp);
|
||||
ml.get_key(LLM_KV_EXPERT_CHUNK_FEED_FORWARD_LENGTH, hparams.n_ff_chexp);
|
||||
ml.get_key(LLM_KV_EXPERT_GROUP_SCALE, hparams.expert_group_scale);
|
||||
ml.get_key(LLM_KV_EXPERTS_PER_GROUP, hparams.n_group_experts);
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 48: type = LLM_TYPE_30B_A3B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
|
@ -5840,6 +5853,53 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GROVEMOE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
|
||||
GGML_ASSERT(n_expert > 0 && "n_expert must be > 0 for GROVEMOE");
|
||||
GGML_ASSERT(n_expert_used > 0 && "n_expert_used must be > 0 for GROVEMOE");
|
||||
GGML_ASSERT(hparams.n_group_experts > 0 && "n_group_experts must be > 0 for GROVEMOE");
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0);
|
||||
|
||||
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0);
|
||||
|
||||
layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
|
||||
|
||||
// MoE branch
|
||||
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
|
||||
const int64_t n_ff_chexp = hparams.n_ff_chexp ? hparams.n_ff_chexp : n_embd_head_k;
|
||||
const int64_t n_chunk_expert = n_expert / hparams.n_group_experts;
|
||||
|
||||
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
|
||||
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0);
|
||||
|
||||
layer.ffn_gate_chexps = create_tensor(tn(LLM_TENSOR_FFN_GATE_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0);
|
||||
layer.ffn_down_chexps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_CHEXPS, "weight", i), {n_ff_chexp, n_embd, n_chunk_expert}, 0);
|
||||
layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
throw std::runtime_error("unknown architecture");
|
||||
}
|
||||
|
|
@ -6179,6 +6239,13 @@ void llama_model::print_info() const {
|
|||
LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func));
|
||||
}
|
||||
|
||||
if (arch == LLM_ARCH_GROVEMOE) {
|
||||
LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp);
|
||||
LLAMA_LOG_INFO("%s: n_ff_chexp = %d\n", __func__, hparams.n_ff_chexp);
|
||||
LLAMA_LOG_INFO("%s: n_group_experts = %d\n", __func__, hparams.n_group_experts);
|
||||
LLAMA_LOG_INFO("%s: expert_group_scale = %.2f\n", __func__, hparams.expert_group_scale);
|
||||
}
|
||||
|
||||
vocab.print_info();
|
||||
}
|
||||
|
||||
|
|
@ -18864,6 +18931,156 @@ struct llm_build_smallthinker : public llm_graph_context{
|
|||
}
|
||||
};
|
||||
|
||||
struct llm_build_grovemoe : public llm_graph_context {
|
||||
llm_build_grovemoe(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
const int64_t n_chunk_expert = n_expert / hparams.n_group_experts;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self_attention
|
||||
{
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Qcur, "Qcur_normed", il);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il);
|
||||
cb(Kcur, "Kcur_normed", il);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, nullptr,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// MoE branch
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
ggml_tensor * probs = build_lora_mm(model.layers[il].ffn_gate_inp, cur); // [n_expert, n_tokens]
|
||||
cb(probs, "ffn_moe_logits", il);
|
||||
|
||||
ggml_tensor * moe_out =
|
||||
build_moe_ffn(cur,
|
||||
nullptr,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il, probs);
|
||||
cb(moe_out, "ffn_moe_out", il);
|
||||
cur = moe_out;
|
||||
|
||||
// TODO: Only do the expert selection and weights once
|
||||
moe_out =
|
||||
build_moe_ffn(cur,
|
||||
nullptr,
|
||||
model.layers[il].ffn_up_chexps,
|
||||
model.layers[il].ffn_gate_chexps,
|
||||
model.layers[il].ffn_down_chexps,
|
||||
nullptr,
|
||||
n_chunk_expert, n_expert_used > n_chunk_expert ? n_chunk_expert : n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il, probs);
|
||||
cb(moe_out, "ffn_adj_moe_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ggml_scale(ctx0, moe_out, hparams.expert_group_scale));
|
||||
cb(cur, "ffn_final_moe_out", il);
|
||||
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
};
|
||||
|
||||
llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const {
|
||||
llama_memory_i * res;
|
||||
|
||||
|
|
@ -19390,6 +19607,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
llm = std::make_unique<llm_build_smallthinker<false>>(*this, params);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_GROVEMOE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_grovemoe>(*this, params);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -19595,6 +19816,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_SMALLTHINKER:
|
||||
case LLM_ARCH_GLM4_MOE:
|
||||
case LLM_ARCH_SEED_OSS:
|
||||
case LLM_ARCH_GROVEMOE:
|
||||
return LLAMA_ROPE_TYPE_NEOX;
|
||||
|
||||
case LLM_ARCH_QWEN2VL:
|
||||
|
|
|
|||
|
|
@ -275,6 +275,11 @@ struct llama_layer {
|
|||
struct ggml_tensor * ffn_down_shexp = nullptr;
|
||||
struct ggml_tensor * ffn_up_shexp = nullptr;
|
||||
|
||||
// ff adjugate experts (chexps)
|
||||
struct ggml_tensor * ffn_gate_chexps = nullptr;
|
||||
struct ggml_tensor * ffn_down_chexps = nullptr;
|
||||
struct ggml_tensor * ffn_up_chexps = nullptr;
|
||||
|
||||
// ff bias
|
||||
struct ggml_tensor * ffn_gate_b = nullptr;
|
||||
struct ggml_tensor * ffn_down_b = nullptr; // b2
|
||||
|
|
|
|||
Loading…
Reference in New Issue