model : add grok-2 support (#15539)
* add grok-2 support
* type fix
* type fix
* type fix
* "fix" vocab for invalid sequences
* fix expert tensor mapping and spaces in vocab
* add chat template
* fix norm tensor mapping
* rename layer_out_norm to ffn_post_norm
* ensure ffn_post_norm is mapped
* fix experts merging
* remove erroneous FFN_GATE entry
* concatenate split tensors and add more metadata
* process all expert layers and try cat instead of hstack
* add support for community BPE vocab
* fix expert feed forward length and ffn_down concat
* commit this too
* add ffn_up/gate/down, unsure if sequence is right
* add ffn_gate/down/up to tensor names
* correct residual moe (still not working)
* mess--
* fix embedding scale being applied twice
* add built in chat template
* change beta fast for grok if default value
* remove spm vocab in favor of community bpe vocab
* change attention temp length metadata type to integer
* update attention temp length metadata
* remove comment
* replace M_SQRT2 with std::sqrt(2)
* add yarn metadata, move defaults to hparams
This commit is contained in:
parent 6c019cb04e
commit b8e09f08b9
@@ -288,9 +288,9 @@ struct common_params {
     float rope_freq_base  = 0.0f;  // RoPE base frequency
     float rope_freq_scale = 0.0f;  // RoPE frequency scaling factor
     float yarn_ext_factor = -1.0f; // YaRN extrapolation mix factor
-    float yarn_attn_factor = 1.0f; // YaRN magnitude scaling factor
-    float yarn_beta_fast = 32.0f;  // YaRN low correction dim
-    float yarn_beta_slow = 1.0f;   // YaRN high correction dim
+    float yarn_attn_factor = -1.0f; // YaRN magnitude scaling factor
+    float yarn_beta_fast = -1.0f;   // YaRN low correction dim
+    float yarn_beta_slow = -1.0f;   // YaRN high correction dim
     int32_t yarn_orig_ctx = 0;     // YaRN original context length
 
     // offload params
@@ -735,6 +735,9 @@ class TextModel(ModelBase):
         if chkhsh == "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c":
             # ref: https://huggingface.co/Qwen/Qwen3-Embedding-0.6B
             res = "qwen2"
+        if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
+            # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
+            res = "grok-2"
         if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
             # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
             res = "llama-bpe"
@@ -2682,12 +2685,20 @@ class BitnetModel(TextModel):
         yield (new_name, data_torch)


-@ModelBase.register("GrokForCausalLM")
+@ModelBase.register("GrokForCausalLM", "Grok1ForCausalLM")
 class GrokModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GROK

     def set_vocab(self):
-        self._set_vocab_sentencepiece()
+        if (self.dir_model / 'tokenizer.model').is_file():
+            self._set_vocab_sentencepiece()
+            return
+
+        if not (self.dir_model / 'tokenizer.json').is_file() or not (self.dir_model / 'chat_template.jinja').is_file():
+            logger.error('Error: Missing vocab and chat template, download files from https://huggingface.co/alvarobartt/grok-2-tokenizer')
+            sys.exit(1)
+
+        self._set_vocab_gpt2()

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -2695,11 +2706,46 @@ class GrokModel(TextModel):
     def set_gguf_parameters(self):
         super().set_gguf_parameters()
 
-    _experts: list[dict[str, Tensor]] | None = None
+        self.gguf_writer.add_attn_logit_softcapping(self.hparams.get("attn_logit_softcapping", 30.0))
+        self.gguf_writer.add_router_logit_softcapping(self.hparams.get("router_logit_softcapping", 30.0))
+        if (final_logit_softcap := self.hparams.get("final_logit_softcapping")):
+            self.gguf_writer.add_final_logit_softcapping(final_logit_softcap)
+
+        if (rope_dim := self.hparams.get("head_dim")) is None:
+            rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
+
+        if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None:
+            self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
+
+        # Treat "original" as "yarn", seems to have been a mistake
+        if self.hparams.get("rope_type") in ("yarn", "original"):
+            self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
+            self.gguf_writer.add_rope_scaling_factor(self.hparams["scaling_factor"])
+            self.gguf_writer.add_rope_scaling_orig_ctx_len(self.hparams["original_max_position_embeddings"])
+            self.gguf_writer.add_rope_scaling_yarn_ext_factor(self.hparams["extrapolation_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_attn_factor(self.hparams["attn_factor"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_fast(self.hparams["beta_fast"])
+            self.gguf_writer.add_rope_scaling_yarn_beta_slow(self.hparams["beta_slow"])
+
+        if temp_len := self.hparams.get("attn_temperature_len"):
+            self.gguf_writer.add_attn_temperature_length(temp_len)
+
+        self.gguf_writer.add_attn_output_scale(self.hparams.get("attn_output_multiplier", rope_dim**-0.5))
+        self.gguf_writer.add_embedding_scale(self.hparams["embedding_multiplier_scale"])
+        self.gguf_writer.add_logit_scale(self.hparams["output_multiplier_scale"])
+
+    _experts: list[dict[str, list[Tensor]]] | None = None
+    _cur_expert = ""
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
+        tensors: list[tuple[str, Tensor]] = []
+        is_expert = ".moe." in name or ".block_sparse_moe.experts." in name
+
+        if not is_expert:
+            tensors.append((self.map_tensor_name(name), data_torch))
+
         # process the experts separately
-        if name.find(".moe.") != -1:
+        if is_expert or self._cur_expert:
             n_experts = self.hparams["num_local_experts"]
 
             assert bid is not None
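A quick way to see what set_gguf_parameters() actually wrote is to dump the new keys with the gguf-py reader. This is an illustrative sketch only: the file path is hypothetical and the key names assume the `grok` architecture prefix.

```python
from gguf import GGUFReader

reader = GGUFReader("grok-2.gguf")  # hypothetical output of convert_hf_to_gguf.py
for key in (
    "grok.attn_logit_softcapping",
    "grok.router_logit_softcapping",
    "grok.attention.output_scale",
    "grok.attention.temperature_length",
    "grok.rope.scaling.yarn_beta_fast",
    "grok.rope.scaling.yarn_beta_slow",
):
    field = reader.get_field(key)
    if field is not None:
        # raw value bytes live in field.parts, indexed by field.data
        print(key, "=", field.parts[field.data[0]])
```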
@@ -2707,32 +2753,41 @@ class GrokModel(TextModel):
             if self._experts is None:
                 self._experts = [{} for _ in range(self.block_count)]
 
-            self._experts[bid][name] = data_torch
+            # concatenate split tensors
+            if name in self._experts[bid]:
+                self._cur_expert = name
+                self._experts[bid][name].append(data_torch)
+                return []
+            elif is_expert:
+                self._cur_expert = name
+                self._experts[bid][name] = [data_torch]
+                return []
+            else:
+                self._cur_expert = ""
 
+        for bid in range(self.block_count):
             if len(self._experts[bid]) >= n_experts * 3:
-                tensors: list[tuple[str, Tensor]] = []
-
                 # merge the experts into a single 3d tensor
-                for wid in ["linear", "linear_1", "linear_v"]:
+                for wid in [("linear", "w1", 0), ("linear_1", "w2", 1), ("linear_v", "w3", 0)]:
                     datas: list[Tensor] = []
 
                     for xid in range(n_experts):
-                        ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid}.weight"
-                        datas.append(self._experts[bid][ename])
+                        ename = f"transformer.decoder_layer.{bid}.moe.{xid}.{wid[0]}.weight"
+                        if ename not in self._experts[bid]:
+                            ename = f"model.layers.{bid}.block_sparse_moe.experts.{xid}.{wid[1]}.weight"
+                        tensor_list = self._experts[bid][ename]
+                        datas.append(torch.cat(tensor_list, dim=wid[2]) if len(tensor_list) > 1 else tensor_list[0])
                         del self._experts[bid][ename]
 
                     data_torch = torch.stack(datas, dim=0)
 
-                    merged_name = f"transformer.decoder_layer.{bid}.moe.{wid}.weight"
+                    merged_name = f"transformer.decoder_layer.{bid}.moe.{wid[0]}.weight"
 
                     new_name = self.map_tensor_name(merged_name)
 
-                    tensors.append((new_name, data_torch))
-                return tensors
-            else:
-                return []
+                    yield (new_name, data_torch)
 
-        return [(self.map_tensor_name(name), data_torch)]
+        yield from tensors
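The `wid` tuples above carry (grok-1 tensor name, grok-2 tensor name, concat dim): shards of one expert weight are first concatenated along their split dimension, then all experts are stacked into a single 3D tensor. A minimal shape sketch with made-up sizes:

```python
import torch

n_expert, n_embd, n_ff = 4, 8, 16

# e.g. "linear" / "w1": two shards of shape [n_ff/2, n_embd], concatenated along dim 0
shards_per_expert = [[torch.zeros(n_ff // 2, n_embd) for _ in range(2)] for _ in range(n_expert)]
per_expert = [torch.cat(shards, dim=0) for shards in shards_per_expert]  # each [n_ff, n_embd]
merged = torch.stack(per_expert, dim=0)                                  # [n_expert, n_ff, n_embd]
assert merged.shape == (n_expert, n_ff, n_embd)
```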
@@ -158,6 +158,7 @@ pre_computed_hashes = [
     {"name": "falcon-h1", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tiiuae/Falcon-H1-34B-Base", "chkhsh": "48f8e02c0359c0bbdd82f26909171fac1c18a457bb47573ed1fe3bbb2c1cfd4b"},
     {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
     {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
+    {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
 ]
@@ -111,6 +111,7 @@ class Keys:
         DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
         DECODER_BLOCK_COUNT = "{arch}.decoder_block_count"
         ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
+        ROUTER_LOGIT_SOFTCAPPING = "{arch}.router_logit_softcapping"
         FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
         SWIN_NORM = "{arch}.swin_norm"
         RESCALE_EVERY_N_LAYERS = "{arch}.rescale_every_n_layers"
@@ -146,6 +147,8 @@ class Keys:
         REL_BUCKETS_COUNT = "{arch}.attention.relative_buckets_count"
         SLIDING_WINDOW = "{arch}.attention.sliding_window"
         SCALE = "{arch}.attention.scale"
+        OUTPUT_SCALE = "{arch}.attention.output_scale"
+        TEMPERATURE_LENGTH = "{arch}.attention.temperature_length"
         KEY_LENGTH_MLA = "{arch}.attention.key_length_mla"
         VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
         SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
@@ -161,6 +164,10 @@ class Keys:
         SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
         SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
         SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
+        SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor"
+        SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
+        SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast"
+        SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow"
 
 class Split:
     LLM_KV_SPLIT_NO = "split.no"
@@ -1114,6 +1121,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
         MODEL_TENSOR.FFN_GATE_EXP,
         MODEL_TENSOR.FFN_DOWN_EXP,
         MODEL_TENSOR.FFN_UP_EXP,
+        MODEL_TENSOR.FFN_POST_NORM,
         MODEL_TENSOR.LAYER_OUT_NORM,
     ],
     MODEL_ARCH.GPTNEOX: [
@@ -733,6 +733,9 @@ class GGUFWriter:
     def add_attn_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.ATTN_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
 
+    def add_router_logit_softcapping(self, value: float) -> None:
+        self.add_float32(Keys.LLM.ROUTER_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
+
     def add_final_logit_softcapping(self, value: float) -> None:
         self.add_float32(Keys.LLM.FINAL_LOGIT_SOFTCAPPING.format(arch=self.arch), value)
@@ -829,6 +832,12 @@ class GGUFWriter:
     def add_attention_scale(self, value: float) -> None:
         self.add_float32(Keys.Attention.SCALE.format(arch=self.arch), value)
 
+    def add_attn_output_scale(self, value: float) -> None:
+        self.add_float32(Keys.Attention.OUTPUT_SCALE.format(arch=self.arch), value)
+
+    def add_attn_temperature_length(self, value: int) -> None:
+        self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
+
     def add_pooling_type(self, value: PoolingType) -> None:
         self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
@@ -859,6 +868,18 @@ class GGUFWriter:
     def add_rope_scaling_yarn_log_mul(self, value: float) -> None:
         self.add_float32(Keys.Rope.SCALING_YARN_LOG_MUL.format(arch=self.arch), value)
 
+    def add_rope_scaling_yarn_ext_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_EXT_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_attn_factor(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_ATTN_FACTOR.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_fast(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_FAST.format(arch=self.arch), value)
+
+    def add_rope_scaling_yarn_beta_slow(self, value: float) -> None:
+        self.add_float32(Keys.Rope.SCALING_YARN_BETA_SLOW.format(arch=self.arch), value)
+
     def add_ssm_conv_kernel(self, value: int) -> None:
         self.add_uint32(Keys.SSM.CONV_KERNEL.format(arch=self.arch), value)
@@ -136,6 +136,7 @@ class TensorNameMap:
             "model.layers.{bid}.norm", # mamba-qbert
             "backbone.layers.{bid}.norm", # mamba
             "transformer.decoder_layer.{bid}.rms_norm", # Grok
+            "model.layers.{bid}.pre_attn_norm", # grok-2
             "transformer.blocks.{bid}.norm_attn_norm.norm_1", # dbrx
             "encoder.layers.{bid}.input_layernorm", # chatglm
             "transformer.layers.{bid}.attn_norm", # openelm
@@ -278,6 +279,7 @@ class TensorNameMap:
             "transformer.layer.{bid}.sa_layer_norm", # distillbert
             "encoder.layers.{bid}.norm1", # nomic-bert
             "transformer.decoder_layer.{bid}.rms_norm_1", # Grok
+            "model.layers.{bid}.post_attn_norm", # grok-2
             "transformer.blocks.{bid}.norm_attn_norm.norm_2", # dbrx
         ),
@@ -313,6 +315,7 @@ class TensorNameMap:
             "h.{bid}.ln_2", # gpt2
             "model.layers.{bid}.ffn_norm", # internlm2
             "transformer.decoder_layer.{bid}.rms_norm_2", # Grok
+            "model.layers.{bid}.pre_moe_norm", # grok-2
             "encoder.layers.{bid}.post_attention_layernorm", # chatglm
             "transformer.layers.{bid}.ffn_norm", # openelm
             "model.layers.{bid}.pre_ff_layernorm", # jamba granite-hybrid
@@ -338,6 +341,7 @@ class TensorNameMap:
             "model.layers.{bid}.post_mlp_layernorm", # glm-4-0414
             "model.layers.layers.{bid}.post_mlp_norm.weight", # plamo2
             "model.layers.{bid}.feed_forward.up_proj",
+            "model.layers.{bid}.post_moe_norm", # grok-2
         ),
 
         MODEL_TENSOR.FFN_GATE_INP: (
@@ -139,6 +139,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_DECODER_START_TOKEN_ID, "%s.decoder_start_token_id" },
     { LLM_KV_DECODER_BLOCK_COUNT, "%s.decoder_block_count" },
     { LLM_KV_ATTN_LOGIT_SOFTCAPPING, "%s.attn_logit_softcapping" },
+    { LLM_KV_ROUTER_LOGIT_SOFTCAPPING, "%s.router_logit_softcapping" },
     { LLM_KV_FINAL_LOGIT_SOFTCAPPING, "%s.final_logit_softcapping" },
     { LLM_KV_SWIN_NORM, "%s.swin_norm" },
     { LLM_KV_RESCALE_EVERY_N_LAYERS, "%s.rescale_every_n_layers" },
@@ -169,6 +170,8 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT, "%s.attention.relative_buckets_count" },
     { LLM_KV_ATTENTION_SLIDING_WINDOW, "%s.attention.sliding_window" },
     { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
+    { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
+    { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
     { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
     { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
@@ -182,6 +185,10 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
     { LLM_KV_ROPE_SCALING_ORIG_CTX_LEN, "%s.rope.scaling.original_context_length" },
     { LLM_KV_ROPE_SCALING_FINETUNED, "%s.rope.scaling.finetuned" },
     { LLM_KV_ROPE_SCALING_YARN_LOG_MUL, "%s.rope.scaling.yarn_log_multiplier" },
+    { LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, "%s.rope.scaling.yarn_ext_factor" },
+    { LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, "%s.rope.scaling.yarn_attn_factor" },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_FAST, "%s.rope.scaling.yarn_beta_fast" },
+    { LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, "%s.rope.scaling.yarn_beta_slow" },
 
     { LLM_KV_SPLIT_NO, "split.no" },
     { LLM_KV_SPLIT_COUNT, "split.count" },
@@ -398,12 +405,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
     { LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
     { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
     { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
+    { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
+    { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
+    { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
     { LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
     { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
     { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
     { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
     { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
     { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
+    { LLM_TENSOR_FFN_POST_NORM, "blk.%d.post_ffw_norm" },
     { LLM_TENSOR_LAYER_OUT_NORM, "blk.%d.layer_output_norm" },
     { LLM_TENSOR_ATTN_OUT_NORM, "blk.%d.attn_output_norm" },
 },
@@ -143,6 +143,7 @@ enum llm_kv {
     LLM_KV_DECODER_START_TOKEN_ID,
     LLM_KV_DECODER_BLOCK_COUNT,
     LLM_KV_ATTN_LOGIT_SOFTCAPPING,
+    LLM_KV_ROUTER_LOGIT_SOFTCAPPING,
     LLM_KV_FINAL_LOGIT_SOFTCAPPING,
     LLM_KV_SWIN_NORM,
     LLM_KV_RESCALE_EVERY_N_LAYERS,
@@ -173,6 +174,8 @@ enum llm_kv {
     LLM_KV_ATTENTION_RELATIVE_BUCKETS_COUNT,
     LLM_KV_ATTENTION_SLIDING_WINDOW,
     LLM_KV_ATTENTION_SCALE,
+    LLM_KV_ATTENTION_OUTPUT_SCALE,
+    LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
     LLM_KV_ATTENTION_KEY_LENGTH_MLA,
     LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
@@ -186,6 +189,10 @@ enum llm_kv {
     LLM_KV_ROPE_SCALING_ORIG_CTX_LEN,
     LLM_KV_ROPE_SCALING_FINETUNED,
     LLM_KV_ROPE_SCALING_YARN_LOG_MUL,
+    LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR,
+    LLM_KV_ROPE_SCALING_YARN_BETA_FAST,
+    LLM_KV_ROPE_SCALING_YARN_BETA_SLOW,
 
     LLM_KV_SPLIT_NO,
     LLM_KV_SPLIT_COUNT,
|
@ -70,6 +70,7 @@ static const std::map<std::string, llm_chat_template> LLM_CHAT_TEMPLATES = {
|
||||||
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
|
{ "hunyuan-dense", LLM_CHAT_TEMPLATE_HUNYUAN_DENSE },
|
||||||
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
|
{ "kimi-k2", LLM_CHAT_TEMPLATE_KIMI_K2 },
|
||||||
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
|
{ "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS },
|
||||||
|
{ "grok-2", LLM_CHAT_TEMPLATE_GROK_2 },
|
||||||
};
|
};
|
||||||
|
|
||||||
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
llm_chat_template llm_chat_template_from_str(const std::string & name) {
|
||||||
|
|
@@ -204,6 +205,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) {
         return LLM_CHAT_TEMPLATE_KIMI_K2;
     } else if (tmpl_contains("<seed:bos>")) {
         return LLM_CHAT_TEMPLATE_SEED_OSS;
+    } else if (tmpl_contains("'Assistant: ' + message['content'] + '<|separator|>")) {
+        return LLM_CHAT_TEMPLATE_GROK_2;
     }
     return LLM_CHAT_TEMPLATE_UNKNOWN;
 }
@@ -763,6 +766,20 @@ int32_t llm_chat_apply_template(
         if (add_ass) {
             ss << "<seed:bos>assistant\n";
         }
+    } else if (tmpl == LLM_CHAT_TEMPLATE_GROK_2) {
+        for (auto message : chat) {
+            std::string role(message->role);
+            if (role == "system") {
+                ss << "System: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "user") {
+                ss << "Human: " << trim(message->content) << "<|separator|>\n\n";
+            } else if (role == "assistant") {
+                ss << "Assistant: " << message->content << "<|separator|>\n\n";
+            }
+        }
+        if (add_ass) {
+            ss << "Assistant:";
+        }
     } else {
         // template not supported
         return -1;
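For reference, a rough Python rendering of the grok-2 template implemented above (a sketch of the C++ logic, not code from this commit), to show the prompt produced for a short conversation:

```python
def render_grok2(chat: list[tuple[str, str]], add_ass: bool = True) -> str:
    out = ""
    for role, content in chat:
        if role == "system":
            out += "System: " + content.strip() + "<|separator|>\n\n"
        elif role == "user":
            out += "Human: " + content.strip() + "<|separator|>\n\n"
        elif role == "assistant":
            out += "Assistant: " + content + "<|separator|>\n\n"  # assistant content is not trimmed
    if add_ass:
        out += "Assistant:"
    return out

print(render_grok2([("system", "You are Grok."), ("user", "Hello!")]))
# System: You are Grok.<|separator|>
#
# Human: Hello!<|separator|>
#
# Assistant:
```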
@@ -50,6 +50,7 @@ enum llm_chat_template {
     LLM_CHAT_TEMPLATE_HUNYUAN_DENSE,
     LLM_CHAT_TEMPLATE_KIMI_K2,
     LLM_CHAT_TEMPLATE_SEED_OSS,
+    LLM_CHAT_TEMPLATE_GROK_2,
     LLM_CHAT_TEMPLATE_UNKNOWN,
 };
@@ -35,10 +35,10 @@ llama_context::llama_context(
 
     cparams.n_threads       = params.n_threads;
     cparams.n_threads_batch = params.n_threads_batch;
-    cparams.yarn_ext_factor  = params.yarn_ext_factor;
-    cparams.yarn_attn_factor = params.yarn_attn_factor;
-    cparams.yarn_beta_fast   = params.yarn_beta_fast;
-    cparams.yarn_beta_slow   = params.yarn_beta_slow;
+    cparams.yarn_ext_factor  = params.yarn_ext_factor  >= 0.0f ? params.yarn_ext_factor  : hparams.yarn_ext_factor;
+    cparams.yarn_attn_factor = params.yarn_attn_factor >= 0.0f ? params.yarn_attn_factor : hparams.yarn_attn_factor;
+    cparams.yarn_beta_fast   = params.yarn_beta_fast   >= 0.0f ? params.yarn_beta_fast   : hparams.yarn_beta_fast;
+    cparams.yarn_beta_slow   = params.yarn_beta_slow   >= 0.0f ? params.yarn_beta_slow   : hparams.yarn_beta_slow;
     cparams.embeddings       = params.embeddings;
     cparams.offload_kqv      = params.offload_kqv;
     cparams.no_perf          = params.no_perf;
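In words: the YaRN context parameters now default to -1 ("unset") and are resolved against the model's hparams at context creation. A one-line sketch of the rule, with illustrative numbers:

```python
# A negative user/CLI value means "unset", so the GGUF-provided model default (hparams) wins;
# any non-negative value still overrides it.
def resolve_yarn(param: float, hparam: float) -> float:
    return param if param >= 0.0 else hparam

assert resolve_yarn(-1.0, 8.0) == 8.0   # e.g. grok defaults yarn_beta_fast to 8.0
assert resolve_yarn(32.0, 8.0) == 32.0  # explicit value takes precedence
```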
@@ -2263,9 +2263,9 @@ llama_context_params llama_context_default_params() {
         /*.rope_freq_base  =*/ 0.0f,
         /*.rope_freq_scale =*/ 0.0f,
         /*.yarn_ext_factor =*/ -1.0f,
-        /*.yarn_attn_factor =*/ 1.0f,
-        /*.yarn_beta_fast   =*/ 32.0f,
-        /*.yarn_beta_slow   =*/ 1.0f,
+        /*.yarn_attn_factor =*/ -1.0f,
+        /*.yarn_beta_fast   =*/ -1.0f,
+        /*.yarn_beta_slow   =*/ -1.0f,
         /*.yarn_orig_ctx  =*/ 0,
         /*.defrag_thold   =*/ -1.0f,
         /*.cb_eval        =*/ nullptr,
@@ -1335,14 +1335,14 @@ ggml_tensor * llm_graph_context::build_attn_mha(
 
         if (arch == LLM_ARCH_GROK) {
             // need to do the following:
-            // multiply by attn_output_multiplyer of 0.08838834764831845
+            // multiply by attn_output_multiplier
             // and then :
             // kq = 30 * tanh(kq / 30)
             // before the softmax below
 
-            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, 0.08838834764831845f/30.0f));
+            kq = ggml_tanh(ctx0, ggml_scale(ctx0, kq, hparams.f_attn_out_scale / hparams.f_attn_logit_softcapping));
             cb(kq, "kq_tanh", il);
 
-            kq = ggml_scale(ctx0, kq, 30);
+            kq = ggml_scale(ctx0, kq, hparams.f_attn_logit_softcapping);
             cb(kq, "kq_scaled", il);
         }
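The two ggml_scale/ggml_tanh calls implement attention-logit soft-capping with the output scale folded into the first multiply. A quick numeric check of that fused form (illustrative values only):

```python
# cap * tanh((x * scale) / cap) is computed as tanh(x * (scale / cap)) followed by a multiply with cap.
import math

scale = 0.08838834764831845  # grok-1 attention output scale (now hparams.f_attn_out_scale)
cap   = 30.0                 # hparams.f_attn_logit_softcapping

x = 123.4
assert math.isclose(cap * math.tanh(x * (scale / cap)),
                    cap * math.tanh((x * scale) / cap))
```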
@@ -83,6 +83,7 @@ struct llama_hparams {
     float f_norm_group_eps;
 
     float f_attn_logit_softcapping   = 50.0f;
+    float f_router_logit_softcapping = 30.0f;
     float f_final_logit_softcapping  = 30.0f;
 
     // for RWKV
@@ -104,6 +105,11 @@ struct llama_hparams {
     uint32_t n_ctx_orig_yarn;
     float rope_yarn_log_mul = 0.0f;
 
+    float yarn_ext_factor  = -1.0f;
+    float yarn_attn_factor = 1.0f;
+    float yarn_beta_fast   = 32.0f;
+    float yarn_beta_slow   = 1.0f;
+
     std::array<int, 4> rope_sections;
 
     // Sliding Window Attention (SWA)
@@ -136,6 +142,10 @@ struct llama_hparams {
     float f_embedding_scale = 0.0f;
     float f_attention_scale = 0.0f;
 
+    // grok-2
+    float    f_attn_out_scale = 0.0f;
+    uint32_t attn_temp_length = 0;
+
     bool causal_attn   = true;
     bool use_alibi     = false;
     bool attn_soft_cap = false;
@@ -685,7 +685,30 @@ void llama_model::load_hparams(llama_model_loader & ml) {
             } break;
         case LLM_ARCH_GROK:
             {
+                // defaults for old GGUFs
+                hparams.yarn_beta_fast = 8.0f;
+                hparams.f_logit_scale = 0.5773502691896257f;
+                hparams.f_embedding_scale = 78.38367176906169f;
+                hparams.f_attn_out_scale = 0.08838834764831845f;
+                hparams.f_attn_logit_softcapping = 30.0f;
+                hparams.f_router_logit_softcapping = 30.0f;
+                // no final_logit_softcapping in grok-1
+                hparams.f_final_logit_softcapping = 0.0f;
+
                 ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
+                ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
+                ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale, false);
+                ml.get_key(LLM_KV_EMBEDDING_SCALE, hparams.f_embedding_scale, false);
+                ml.get_key(LLM_KV_ATTENTION_OUTPUT_SCALE, hparams.f_attn_out_scale, false);
+                ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false);
+                ml.get_key(LLM_KV_ROUTER_LOGIT_SOFTCAPPING, hparams.f_router_logit_softcapping, false);
+                ml.get_key(LLM_KV_FINAL_LOGIT_SOFTCAPPING, hparams.f_final_logit_softcapping, false);
+
+                ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_LENGTH, hparams.attn_temp_length, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_EXT_FACTOR, hparams.yarn_ext_factor, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_ATTN_FACTOR, hparams.yarn_attn_factor, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
+                ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
 
                 switch (hparams.n_layer) {
                     case 64: type = LLM_TYPE_314B; break;
@@ -2540,6 +2563,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
                         output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
                     }
 
+                    const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff/* / n_expert_used*/; // grok-1 n_ff_exp == n_ff
                     for (int i = 0; i < n_layer; ++i) {
                         auto & layer = layers[i];
@@ -2554,12 +2578,19 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
 
                         layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0);
 
-                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
-                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff, n_expert}, TENSOR_NOT_REQUIRED);
-                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff, n_embd, n_expert}, 0);
-                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff, n_expert}, 0);
+                        layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED);
 
-                        layer.layer_out_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, 0);
+                        layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0);
+                        layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, TENSOR_NOT_REQUIRED);
+                        layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0);
+                        layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, n_ff_exp, n_expert}, 0);
+
+                        layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_LAYER_OUT_NORM, "weight", i), {n_embd}, TENSOR_NOT_REQUIRED);
+                        if (!layer.ffn_post_norm) {
+                            layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0);
+                        }
                     }
                 } break;
             case LLM_ARCH_DBRX:
@@ -7028,9 +7059,6 @@ struct llm_build_grok : public llm_graph_context {
 
         inpL = build_inp_embd(model.tok_embd);
 
-        // multiply by embedding_multiplier_scale of 78.38367176906169
-        inpL = ggml_scale(ctx0, inpL, 78.38367176906169f);
-
         // inp_pos - contains the positions
         ggml_tensor * inp_pos = build_inp_pos();
@@ -7102,26 +7130,22 @@ struct llm_build_grok : public llm_graph_context {
                 inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
             }
 
-            // Grok
-            // if attn_out_norm is present then apply it before adding the input
-            if (model.layers[il].attn_out_norm) {
             cur = build_norm(cur,
                     model.layers[il].attn_out_norm, NULL,
                     LLM_NORM_RMS, il);
             cb(cur, "attn_out_norm", il);
-            }
 
             ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
             cb(ffn_inp, "ffn_inp", il);
 
             // feed-forward network
-            // MoE branch
             cur = build_norm(ffn_inp,
                     model.layers[il].ffn_norm, NULL,
                     LLM_NORM_RMS, il);
             cb(cur, "ffn_norm", il);
 
-            cur = build_moe_ffn(cur,
+            // MoE branch
+            ggml_tensor * moe_out = build_moe_ffn(cur,
                     model.layers[il].ffn_gate_inp,
                     model.layers[il].ffn_up_exps,
                     model.layers[il].ffn_gate_exps,
@@ -7132,18 +7156,28 @@ struct llm_build_grok : public llm_graph_context {
                     false, 0.0,
                     LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
                     il);
-            cb(cur, "ffn_moe_out", il);
+            cb(moe_out, "ffn_moe_out", il);
 
-            // Grok
-            // if layer_out_norm is present then apply it before adding the input
-            // Idea: maybe ffn_out_norm is a better name
-            if (model.layers[il].layer_out_norm) {
-                cur = build_norm(cur,
-                        model.layers[il].layer_out_norm, NULL,
-                        LLM_NORM_RMS, il);
-                cb(cur, "layer_out_norm", il);
+            if (model.layers[il].ffn_up) {
+                ggml_tensor * ffn_out = build_ffn(cur,
+                        model.layers[il].ffn_up,   NULL, NULL,
+                        model.layers[il].ffn_gate, NULL, NULL,
+                        model.layers[il].ffn_down, NULL, NULL,
+                        NULL,
+                        LLM_FFN_GELU, LLM_FFN_PAR, il);
+                cb(ffn_out, "ffn_out", il);
+
+                cur = ggml_scale(ctx0, ggml_add(ctx0, ffn_out, moe_out), std::sqrt(2) / 2);
+                cb(cur, "ffn_out", il);
+            } else {
+                cur = moe_out;
             }
 
+            cur = build_norm(cur,
+                    model.layers[il].ffn_post_norm, NULL,
+                    LLM_NORM_RMS, il);
+            cb(cur, "ffn_post_norm", il);
+
             cur = ggml_add(ctx0, cur, ffn_inp);
             cb(cur, "ffn_out", il);
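The new path runs the dense FFN in parallel with the MoE branch and merges the two outputs; the `std::sqrt(2) / 2` factor is the same as dividing the sum by sqrt(2). A tiny check of that identity (illustrative values):

```python
import math

a, b = 0.7, -1.3  # stand-ins for ffn_out and moe_out
assert math.isclose((a + b) * math.sqrt(2) / 2, (a + b) / math.sqrt(2))
```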
@@ -7166,10 +7200,14 @@ struct llm_build_grok : public llm_graph_context {
         // lm_head
         cur = build_lora_mm(model.output, cur);
 
-        // Grok
-        // multiply logits by output_multiplier_scale of 0.5773502691896257
+        cur = ggml_scale(ctx0, cur, hparams.f_logit_scale);
 
-        cur = ggml_scale(ctx0, cur, 0.5773502691896257f);
+        // final logit soft-capping
+        if (hparams.f_final_logit_softcapping) {
+            cur = ggml_scale(ctx0, cur, 1.0f / hparams.f_final_logit_softcapping);
+            cur = ggml_tanh(ctx0, cur);
+            cur = ggml_scale(ctx0, cur, hparams.f_final_logit_softcapping);
+        }
 
         cb(cur, "result_output", -1);
         res->t_logits = cur;
@@ -434,6 +434,13 @@ struct llm_tokenizer_bpe : llm_tokenizer {
                     "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1}| ?[^\\s\\p{L}\\p{N}\\r\\n]+|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
                 };
                 break;
+            case LLAMA_VOCAB_PRE_TYPE_GROK_2:
+                regex_exprs = {
+                    // original regex from tokenizer.json
+                    // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+                    "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+",
+                };
+                break;
             default:
                 // default regex for BPE tokenization pre-processing
                 regex_exprs = {
@@ -1974,6 +1981,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) {
                 tokenizer_pre == "kimi-k2") {
                 pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2;
                 clean_spaces = false;
+            } else if (
+                tokenizer_pre == "grok-2") {
+                pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2;
+                clean_spaces = false;
             } else {
                 throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str()));
             }
@@ -47,6 +47,7 @@ enum llama_vocab_pre_type {
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36,
     LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37,
     LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38,
+    LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39,
 };
 
 struct LLM_KV;