model: support Ministral3 (#17644)
* conversion script * support ministral 3 * maybe this is better? * add TODO for rope_yarn_log_mul * better ppl (tested on 14B-Instruct) * Add Ministral3 support to Mistral format * improve arch handling * add sizes * Apply suggestions from code review Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com> * nits --------- Co-authored-by: Julien Denize <julien.denize@mistral.ai> Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
649495c9d9
commit
cd3c118908
|
|
@ -1581,10 +1581,27 @@ class MmprojModel(ModelBase):
|
|||
|
||||
# load preprocessor config
|
||||
self.preprocessor_config = {}
|
||||
if not self.is_mistral_format:
|
||||
with open(self.dir_model / "preprocessor_config.json", "r", encoding="utf-8") as f:
|
||||
|
||||
# prefer preprocessor_config.json if possible
|
||||
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
|
||||
if preprocessor_config_path.is_file():
|
||||
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
|
||||
self.preprocessor_config = json.load(f)
|
||||
|
||||
# prefer processor_config.json if possible
|
||||
processor_config_path = self.dir_model / "processor_config.json"
|
||||
if processor_config_path.is_file():
|
||||
with open(processor_config_path, "r", encoding="utf-8") as f:
|
||||
cfg = json.load(f)
|
||||
# move image_processor to root level for compat
|
||||
if "image_processor" in cfg:
|
||||
cfg = {
|
||||
**cfg,
|
||||
**cfg["image_processor"],
|
||||
}
|
||||
# merge configs
|
||||
self.preprocessor_config = {**self.preprocessor_config, **cfg}
|
||||
|
||||
def get_vision_config(self) -> dict[str, Any] | None:
|
||||
config_name = "vision_config" if not self.is_mistral_format else "vision_encoder"
|
||||
return self.global_config.get(config_name)
|
||||
|
|
@ -2797,7 +2814,32 @@ class Llama4VisionModel(MmprojModel):
|
|||
|
||||
@ModelBase.register("Mistral3ForConditionalGeneration")
|
||||
class Mistral3Model(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# for compatibility, we use LLAMA arch for older models
|
||||
# TODO: remove this once everyone has migrated to newer version of llama.cpp
|
||||
if self.hparams.get("model_type") != "ministral3":
|
||||
self.model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
rope_params = self.hparams.get("rope_parameters")
|
||||
if self.hparams.get("model_type") == "ministral3":
|
||||
assert rope_params is not None, "ministral3 must have 'rope_parameters' config"
|
||||
assert rope_params["rope_type"] == "yarn", "ministral3 rope_type must be 'yarn'"
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(rope_params["factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(rope_params["beta_fast"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(rope_params["beta_slow"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(rope_params["mscale_all_dim"])
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(rope_params["original_max_position_embeddings"])
|
||||
self.gguf_writer.add_rope_freq_base(rope_params["rope_theta"])
|
||||
self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
|
||||
name = name.replace("language_model.", "")
|
||||
|
|
@ -9809,12 +9851,22 @@ class ApertusModel(LlamaModel):
|
|||
|
||||
|
||||
class MistralModel(LlamaModel):
|
||||
model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
model_arch = gguf.MODEL_ARCH.MISTRAL3
|
||||
model_name = "Mistral"
|
||||
hf_arch = ""
|
||||
is_mistral_format = True
|
||||
undo_permute = False
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
# for compatibility, we use LLAMA arch for older models
|
||||
# TODO: remove this once everyone migrates to newer version of llama.cpp
|
||||
if "llama_4_scaling" not in self.hparams:
|
||||
self.model_arch = gguf.MODEL_ARCH.LLAMA
|
||||
self.gguf_writer.arch = gguf.MODEL_ARCH_NAMES[self.model_arch]
|
||||
self.gguf_writer.add_architecture()
|
||||
self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
|
||||
|
||||
@staticmethod
|
||||
def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
|
||||
assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
|
||||
|
|
@ -9854,6 +9906,20 @@ class MistralModel(LlamaModel):
|
|||
|
||||
return template
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
if "yarn" in self.hparams:
|
||||
yarn_params = self.hparams["yarn"]
|
||||
self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.YARN)
|
||||
self.gguf_writer.add_rope_scaling_factor(yarn_params["factor"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_fast(yarn_params["beta"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_beta_slow(yarn_params["alpha"])
|
||||
self.gguf_writer.add_rope_scaling_yarn_log_mul(1.0) # mscale_all_dim
|
||||
self.gguf_writer.add_rope_scaling_orig_ctx_len(yarn_params["original_max_position_embeddings"])
|
||||
|
||||
if "llama_4_scaling" in self.hparams:
|
||||
self.gguf_writer.add_attn_temperature_scale(self.hparams["llama_4_scaling"]["beta"])
|
||||
|
||||
|
||||
class PixtralModel(LlavaVisionModel):
|
||||
model_name = "Pixtral"
|
||||
|
|
|
|||
|
|
@ -175,6 +175,7 @@ class Keys:
|
|||
VALUE_LENGTH_MLA = "{arch}.attention.value_length_mla"
|
||||
SHARED_KV_LAYERS = "{arch}.attention.shared_kv_layers"
|
||||
SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern"
|
||||
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
|
||||
|
||||
class Rope:
|
||||
DIMENSION_COUNT = "{arch}.rope.dimension_count"
|
||||
|
|
@ -444,6 +445,7 @@ class MODEL_ARCH(IntEnum):
|
|||
MINIMAXM2 = auto()
|
||||
RND1 = auto()
|
||||
PANGU_EMBED = auto()
|
||||
MISTRAL3 = auto()
|
||||
|
||||
|
||||
class VISION_PROJECTOR_TYPE(IntEnum):
|
||||
|
|
@ -817,6 +819,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.COGVLM: "cogvlm",
|
||||
MODEL_ARCH.RND1: "rnd1",
|
||||
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
|
||||
MODEL_ARCH.MISTRAL3: "mistral3",
|
||||
}
|
||||
|
||||
VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = {
|
||||
|
|
@ -3071,6 +3074,26 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.MISTRAL3: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ROPE_FREQS,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.ATTN_ROT_EMBD,
|
||||
MODEL_TENSOR.FFN_GATE_INP,
|
||||
MODEL_TENSOR.FFN_NORM,
|
||||
MODEL_TENSOR.FFN_GATE,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
MODEL_TENSOR.FFN_GATE_EXP,
|
||||
MODEL_TENSOR.FFN_DOWN_EXP,
|
||||
MODEL_TENSOR.FFN_UP_EXP,
|
||||
],
|
||||
# TODO
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -904,6 +904,9 @@ class GGUFWriter:
|
|||
def add_attn_temperature_length(self, value: int) -> None:
|
||||
self.add_uint32(Keys.Attention.TEMPERATURE_LENGTH.format(arch=self.arch), value)
|
||||
|
||||
def add_attn_temperature_scale(self, value: float) -> None:
|
||||
self.add_float32(Keys.Attention.TEMPERATURE_SCALE.format(arch=self.arch), value)
|
||||
|
||||
def add_pooling_type(self, value: PoolingType) -> None:
|
||||
self.add_uint32(Keys.LLM.POOLING_TYPE.format(arch=self.arch), value.value)
|
||||
|
||||
|
|
|
|||
|
|
@ -132,6 +132,7 @@ add_library(llama
|
|||
models/t5-enc.cpp
|
||||
models/wavtokenizer-dec.cpp
|
||||
models/xverse.cpp
|
||||
models/mistral3.cpp
|
||||
models/graph-context-mamba.cpp
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_COGVLM, "cogvlm" },
|
||||
{ LLM_ARCH_RND1, "rnd1" },
|
||||
{ LLM_ARCH_PANGU_EMBED, "pangu-embedded" },
|
||||
{ LLM_ARCH_MISTRAL3, "mistral3" },
|
||||
{ LLM_ARCH_UNKNOWN, "(unknown)" },
|
||||
};
|
||||
|
||||
|
|
@ -204,6 +205,7 @@ static const std::map<llm_kv, const char *> LLM_KV_NAMES = {
|
|||
{ LLM_KV_ATTENTION_SCALE, "%s.attention.scale" },
|
||||
{ LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" },
|
||||
{ LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" },
|
||||
{ LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" },
|
||||
{ LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" },
|
||||
|
||||
|
|
@ -2512,6 +2514,32 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_MISTRAL3,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ROPE_FREQS, "rope_freqs" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
{ LLM_TENSOR_ATTN_ROT_EMBD, "blk.%d.attn_rot_embd" },
|
||||
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
|
||||
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
|
||||
{ LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" },
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXP, "blk.%d.ffn_gate.%d" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" },
|
||||
{ LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" },
|
||||
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
|
||||
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
|
||||
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_UNKNOWN,
|
||||
{
|
||||
|
|
|
|||
|
|
@ -115,6 +115,7 @@ enum llm_arch {
|
|||
LLM_ARCH_COGVLM,
|
||||
LLM_ARCH_RND1,
|
||||
LLM_ARCH_PANGU_EMBED,
|
||||
LLM_ARCH_MISTRAL3,
|
||||
LLM_ARCH_UNKNOWN,
|
||||
};
|
||||
|
||||
|
|
@ -208,6 +209,7 @@ enum llm_kv {
|
|||
LLM_KV_ATTENTION_SCALE,
|
||||
LLM_KV_ATTENTION_OUTPUT_SCALE,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_LENGTH,
|
||||
LLM_KV_ATTENTION_TEMPERATURE_SCALE,
|
||||
LLM_KV_ATTENTION_KEY_LENGTH_MLA,
|
||||
LLM_KV_ATTENTION_VALUE_LENGTH_MLA,
|
||||
|
||||
|
|
|
|||
|
|
@ -71,6 +71,9 @@ void llm_graph_input_attn_temp::set_input(const llama_ubatch * ubatch) {
|
|||
if (ubatch->pos && attn_scale) {
|
||||
const int64_t n_tokens = ubatch->n_tokens;
|
||||
|
||||
GGML_ASSERT(f_attn_temp_scale != 0.0f);
|
||||
GGML_ASSERT(n_attn_temp_floor_scale != 0);
|
||||
|
||||
std::vector<float> attn_scale_data(n_tokens, 0.0f);
|
||||
for (int i = 0; i < n_tokens; ++i) {
|
||||
const float pos = ubatch->pos[i];
|
||||
|
|
|
|||
|
|
@ -162,8 +162,8 @@ struct llama_hparams {
|
|||
// llama4 smallthinker
|
||||
uint32_t n_moe_layer_step = 0;
|
||||
uint32_t n_no_rope_layer_step = 4;
|
||||
uint32_t n_attn_temp_floor_scale = 8192;
|
||||
float f_attn_temp_scale = 0.1;
|
||||
uint32_t n_attn_temp_floor_scale = 0;
|
||||
float f_attn_temp_scale = 0.0f;
|
||||
|
||||
// gemma3n altup
|
||||
uint32_t n_altup = 4; // altup_num_inputs
|
||||
|
|
|
|||
|
|
@ -626,8 +626,6 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
switch (arch) {
|
||||
case LLM_ARCH_LLAMA:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
if (hparams.n_expert == 8) {
|
||||
switch (hparams.n_layer) {
|
||||
case 32: type = LLM_TYPE_8x7B; break;
|
||||
|
|
@ -665,6 +663,8 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
} else {
|
||||
hparams.swa_type = LLAMA_SWA_TYPE_CHUNKED;
|
||||
hparams.n_swa = 8192;
|
||||
hparams.n_attn_temp_floor_scale = 8192;
|
||||
hparams.f_attn_temp_scale = 0.1f;
|
||||
hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full
|
||||
}
|
||||
|
||||
|
|
@ -2247,6 +2247,42 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
ml.get_key(LLM_KV_ATTENTION_TEMPERATURE_SCALE, hparams.f_attn_temp_scale, false);
|
||||
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_FAST, hparams.yarn_beta_fast, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_BETA_SLOW, hparams.yarn_beta_slow, false);
|
||||
ml.get_key(LLM_KV_ROPE_SCALING_YARN_LOG_MUL, hparams.rope_yarn_log_mul, false);
|
||||
|
||||
// TODO: maybe add n_attn_temp_floor_scale as a separate KV?
|
||||
if (hparams.f_attn_temp_scale != 0.0f) {
|
||||
hparams.n_attn_temp_floor_scale = hparams.n_ctx_orig_yarn;
|
||||
if (hparams.n_attn_temp_floor_scale == 0) {
|
||||
throw std::runtime_error("invalid n_ctx_orig_yarn for attention temperature scaling");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: this seems to be correct with the case of mscale == mscale_all_dims == 1.0f
|
||||
// but may need further verification with other values
|
||||
if (hparams.rope_yarn_log_mul != 0.0f) {
|
||||
float factor = 1.0f / hparams.rope_freq_scale_train;
|
||||
float mscale = 1.0f;
|
||||
float mscale_all_dims = hparams.rope_yarn_log_mul;
|
||||
static auto get_mscale = [](float scale, float mscale) {
|
||||
return scale <= 1.0f ? 1.0f : (0.1f * mscale * logf(scale) + 1.0f);
|
||||
};
|
||||
hparams.yarn_attn_factor = get_mscale(factor, mscale) / get_mscale(factor, mscale_all_dims);
|
||||
}
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 26: type = LLM_TYPE_3B; break;
|
||||
case 34: type = LLM_TYPE_8B; break;
|
||||
case 40: type = LLM_TYPE_14B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
default: throw std::runtime_error("unsupported model architecture");
|
||||
}
|
||||
|
||||
|
|
@ -2560,6 +2596,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
case LLM_ARCH_MINICPM:
|
||||
case LLM_ARCH_GRANITE:
|
||||
case LLM_ARCH_GRANITE_MOE:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
|
|
@ -7522,6 +7559,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_qwen3next>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
{
|
||||
llm = std::make_unique<llm_build_mistral3>(*this, params);
|
||||
} break;
|
||||
default:
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
|
|
@ -7690,6 +7731,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_ARCEE:
|
||||
case LLM_ARCH_ERNIE4_5:
|
||||
case LLM_ARCH_ERNIE4_5_MOE:
|
||||
case LLM_ARCH_MISTRAL3:
|
||||
return LLAMA_ROPE_TYPE_NORM;
|
||||
|
||||
// the pairs of head values are offset by n_rot/2
|
||||
|
|
|
|||
|
|
@ -0,0 +1,160 @@
|
|||
#include "models.h"
|
||||
|
||||
llm_build_mistral3::llm_build_mistral3(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
GGML_ASSERT(n_embd_head == hparams.n_rot);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
// inp_pos - contains the positions
|
||||
ggml_tensor * inp_pos = build_inp_pos();
|
||||
|
||||
// (optional) temperature tuning
|
||||
ggml_tensor * inp_attn_scale = nullptr;
|
||||
if (hparams.f_attn_temp_scale != 0.0f) {
|
||||
inp_attn_scale = build_inp_attn_scale();
|
||||
}
|
||||
|
||||
auto * inp_attn = build_attn_inp_kv();
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
// self-attention
|
||||
{
|
||||
// rope freq factors for llama3; may return nullptr for llama2 and other models
|
||||
ggml_tensor * rope_factors = model.get_rope_factors(cparams, il);
|
||||
|
||||
// compute Q and K and RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens);
|
||||
|
||||
Qcur = ggml_rope_ext(
|
||||
ctx0, Qcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
Kcur = ggml_rope_ext(
|
||||
ctx0, Kcur, inp_pos, rope_factors,
|
||||
n_rot, rope_type, n_ctx_orig, freq_base, freq_scale,
|
||||
ext_factor, attn_factor, beta_fast, beta_slow
|
||||
);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
if (inp_attn_scale) {
|
||||
// apply llama 4 temperature scaling
|
||||
Qcur = ggml_mul(ctx0, Qcur, inp_attn_scale);
|
||||
cb(Qcur, "Qcur_attn_temp_scaled", il);
|
||||
}
|
||||
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
}
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA);
|
||||
cb(ffn_inp, "ffn_inp", il);
|
||||
|
||||
// feed-forward network (non-MoE)
|
||||
if (model.layers[il].ffn_gate_inp == nullptr) {
|
||||
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||
model.layers[il].ffn_gate, model.layers[il].ffn_gate_b, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
NULL,
|
||||
LLM_FFN_SILU, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
} else {
|
||||
// MoE branch
|
||||
cur = build_norm(ffn_inp,
|
||||
model.layers[il].ffn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "ffn_norm", il);
|
||||
|
||||
cur = build_moe_ffn(cur,
|
||||
model.layers[il].ffn_gate_inp,
|
||||
model.layers[il].ffn_up_exps,
|
||||
model.layers[il].ffn_gate_exps,
|
||||
model.layers[il].ffn_down_exps,
|
||||
nullptr,
|
||||
n_expert, n_expert_used,
|
||||
LLM_FFN_SILU, true,
|
||||
false, 0.0,
|
||||
LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX,
|
||||
il);
|
||||
cb(cur, "ffn_moe_out", il);
|
||||
}
|
||||
cur = ggml_add(ctx0, cur, ffn_inp);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
|
@ -322,6 +322,10 @@ struct llm_build_minimax_m2 : public llm_graph_context {
|
|||
llm_build_minimax_m2(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_mistral3 : public llm_graph_context {
|
||||
llm_build_mistral3(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
||||
struct llm_build_mpt : public llm_graph_context {
|
||||
llm_build_mpt(const llama_model & model, const llm_graph_params & params);
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in New Issue