nvidia nemotron nano v2 (nemotronh) (#15507)
* feat: Add NEMOTRONH to python arch enum https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Add NEMOTRONH to c++ arch enum https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Add NEMOTRONH to llama-arch layer map https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: First pass at conversion for nemotronh https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: Add a verbose log for each tensor loaded This is really helpful for diagnosing mismatches between the expected and received tensors https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * feat: First (broken) pass at nemotronh model architecture It generates tokens, just not valid ones! https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Explicitly enable add_bos_token during conversion The `tokenizer.json`/`tokenizer_config.json` in the model are a bit contradictory. In the config, add_bos_token is set to False, but the tokenizer model itself has a post_processor that adds the BOS token via type: TemplateProcessing https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Use relu2 (LLM_FFN_RELU_SQR) for activation in FFN layers https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Only allocate attention cache for attention layers (not non-recurrent) https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Move residual add to after every block https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * fix: Use the correct norm tensor for the MLP blocks https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> * Nemotron-H: MLP gate cleanup (pass NULL for unused gate) This model does not use a gate in MLP blocks; pass NULLs for gate tensors to make intent clear and avoid unused-pointer noise. * SSM: respect ssm_dt_rank for dt_dim when provided Use GGUF-provided time_step_rank (ssm_dt_rank) to set dt_dim when > 0; fallback to max(64, n_embd/16). * fix: plamo2 - revert dt_dim to default (remove ssm_dt_rank usage) * Rename nemotronh to nemotron_h for consistency - Update architecture name from NEMOTRONH to NEMOTRON_H in constants.py - Change architecture string from 'nemotronh' to 'nemotron_h' in all files - Update enum LLM_ARCH_NEMOTRONH to LLM_ARCH_NEMOTRON_H - Update class name llm_build_nemotronh to llm_build_nemotron_h - Consistent naming with underscore convention (nemotron_h vs nemotronh) * feat: Support conversion for older NemotronH models https://github.com/ggml-org/llama.cpp/issues/nemotron-nano-15409 Branch: gabe-l-hart/nvidia-nemotron-nano-15409 Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> --------- Signed-off-by: Gabe Goodhart <ghart@us.ibm.com> Co-authored-by: Maicon Domingues <dominguesm@outlook.com> Co-authored-by: weatherman <fxdstudios@gmail.com>
This commit is contained in:
parent
a8bca68f72
commit
e8d99dd0b6
|
|
@ -7546,9 +7546,13 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
|||
]
|
||||
|
||||
# n_group and d_inner are used during reshape_tensors for mamba2
|
||||
self.d_model = self.find_hparam(["hidden_size", "d_model"])
|
||||
self.n_group = self.find_hparam(["n_groups"])
|
||||
self.d_inner = self.find_hparam(["expand"]) * self.d_model
|
||||
# NOTE: Explicitly include hparam prefix prefix for d_model to
|
||||
# disambiguate with top-level head_dim
|
||||
# NOTE 2: If needed for future models, this can be isolated in a method
|
||||
# to separate the prefix setting and teh keys used
|
||||
self.d_model = self.find_hparam([f"{self.hparam_prefixes[0]}_head_dim", "hidden_size", "d_model"])
|
||||
self.n_group = self.find_hparam(["n_groups", "num_groups"])
|
||||
self.d_inner = self.find_hparam(["expand", "num_heads"]) * self.d_model
|
||||
|
||||
def get_attn_layers(self):
|
||||
# Explicit list of layer type names
|
||||
|
|
@ -7609,12 +7613,12 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
|||
|
||||
## Mamba mixer params ##
|
||||
self.gguf_writer.add_ssm_conv_kernel(self.find_hparam(["conv_kernel", "d_conv"]))
|
||||
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state"]))
|
||||
self.gguf_writer.add_ssm_state_size(self.find_hparam(["state_size", "d_state", "state_dim", "ssm_state_size"]))
|
||||
self.gguf_writer.add_ssm_group_count(self.n_group)
|
||||
self.gguf_writer.add_ssm_inner_size(self.d_inner)
|
||||
# NOTE: The mamba_dt_rank is _not_ the right field for how this is used
|
||||
# in llama.cpp
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads"]))
|
||||
self.gguf_writer.add_ssm_time_step_rank(self.find_hparam(["n_heads", "num_heads"]))
|
||||
|
||||
## Attention params ##
|
||||
head_count_kv = self.find_hparam(["num_key_value_heads", "n_head_kv"])
|
||||
|
|
@ -7641,6 +7645,55 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel):
|
|||
Mamba2Model.set_vocab(self)
|
||||
|
||||
|
||||
@ModelBase.register("NemotronHForCausalLM")
|
||||
class NemotronHModel(GraniteHybridModel):
|
||||
"""Hybrid mamba2/attention model from NVIDIA"""
|
||||
model_arch = gguf.MODEL_ARCH.NEMOTRON_H
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
# Save the top-level head_dim for later
|
||||
self.head_dim = self.hparams.get("head_dim", self.hparams.get("attention_head_dim"))
|
||||
assert self.head_dim is not None, "Could not find the attention head dim in config"
|
||||
|
||||
# Don't use expand to calculate d_inner
|
||||
self.d_inner = self.find_hparam(["num_heads"]) * self.d_model
|
||||
|
||||
# Update the ssm / attn / mlp layers
|
||||
# M: Mamba2, *: Attention, -: MLP
|
||||
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
|
||||
self._ssm_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "M"]
|
||||
self._mlp_layers = [i for i, val in enumerate(hybrid_override_pattern) if val == "-"]
|
||||
|
||||
def get_attn_layers(self):
|
||||
hybrid_override_pattern = self.hparams["hybrid_override_pattern"]
|
||||
assert len(hybrid_override_pattern) == self.block_count, "Mismatch between hybrid override and num_hidden_layers!"
|
||||
return [i for i, val in enumerate(hybrid_override_pattern) if val == "*"]
|
||||
|
||||
def set_gguf_parameters(self):
|
||||
super().set_gguf_parameters()
|
||||
|
||||
self.gguf_writer.add_key_length(self.head_dim)
|
||||
self.gguf_writer.add_value_length(self.head_dim)
|
||||
|
||||
# Set feed_forward_length
|
||||
# NOTE: This will trigger an override warning. This is preferrable to
|
||||
# duplicating all the parent logic
|
||||
n_ff = self.find_hparam(["intermediate_size", "n_inner", "hidden_dim"])
|
||||
self.gguf_writer.add_feed_forward_length([
|
||||
n_ff if i in self._mlp_layers else 0 for i in range(self.block_count)
|
||||
])
|
||||
|
||||
def set_vocab(self):
|
||||
super().set_vocab()
|
||||
|
||||
# The tokenizer _does_ add a BOS token (via post_processor type
|
||||
# TemplateProcessing) but does not set add_bos_token to true in the
|
||||
# config, so we need to explicitly override it here.
|
||||
self.gguf_writer.add_add_bos_token(True)
|
||||
|
||||
|
||||
@ModelBase.register("BailingMoeForCausalLM")
|
||||
class BailingMoeModel(TextModel):
|
||||
model_arch = gguf.MODEL_ARCH.BAILINGMOE
|
||||
|
|
|
|||
|
|
@ -367,6 +367,7 @@ class MODEL_ARCH(IntEnum):
|
|||
T5ENCODER = auto()
|
||||
JAIS = auto()
|
||||
NEMOTRON = auto()
|
||||
NEMOTRON_H = auto()
|
||||
EXAONE = auto()
|
||||
EXAONE4 = auto()
|
||||
GRANITE = auto()
|
||||
|
|
@ -700,6 +701,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
|
|||
MODEL_ARCH.T5ENCODER: "t5encoder",
|
||||
MODEL_ARCH.JAIS: "jais",
|
||||
MODEL_ARCH.NEMOTRON: "nemotron",
|
||||
MODEL_ARCH.NEMOTRON_H: "nemotron_h",
|
||||
MODEL_ARCH.EXAONE: "exaone",
|
||||
MODEL_ARCH.EXAONE4: "exaone4",
|
||||
MODEL_ARCH.GRANITE: "granite",
|
||||
|
|
@ -2297,6 +2299,25 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
|
|||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.NEMOTRON_H: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
MODEL_TENSOR.OUTPUT,
|
||||
MODEL_TENSOR.ATTN_NORM,
|
||||
MODEL_TENSOR.SSM_IN,
|
||||
MODEL_TENSOR.SSM_CONV1D,
|
||||
MODEL_TENSOR.SSM_DT,
|
||||
MODEL_TENSOR.SSM_A,
|
||||
MODEL_TENSOR.SSM_D,
|
||||
MODEL_TENSOR.SSM_NORM,
|
||||
MODEL_TENSOR.SSM_OUT,
|
||||
MODEL_TENSOR.ATTN_Q,
|
||||
MODEL_TENSOR.ATTN_K,
|
||||
MODEL_TENSOR.ATTN_V,
|
||||
MODEL_TENSOR.ATTN_OUT,
|
||||
MODEL_TENSOR.FFN_DOWN,
|
||||
MODEL_TENSOR.FFN_UP,
|
||||
],
|
||||
MODEL_ARCH.EXAONE: [
|
||||
MODEL_TENSOR.TOKEN_EMBD,
|
||||
MODEL_TENSOR.OUTPUT_NORM,
|
||||
|
|
|
|||
|
|
@ -191,6 +191,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.self_attn.q_proj", # llama4
|
||||
"model.transformer.blocks.{bid}.q_proj", # llada
|
||||
"layers.{bid}.self_attn.q_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.q_proj", # nemotron-h
|
||||
),
|
||||
|
||||
# Attention key
|
||||
|
|
@ -209,6 +210,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.self_attn.k_proj", # llama4
|
||||
"model.transformer.blocks.{bid}.k_proj", # llada
|
||||
"layers.{bid}.self_attn.k_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.k_proj", # nemotron-h
|
||||
),
|
||||
|
||||
# Attention value
|
||||
|
|
@ -226,6 +228,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.self_attn.v_proj", # llama4
|
||||
"model.transformer.blocks.{bid}.v_proj", # llada
|
||||
"layers.{bid}.self_attn.v_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.v_proj", # nemotron-h
|
||||
),
|
||||
|
||||
# Attention output
|
||||
|
|
@ -260,6 +263,7 @@ class TensorNameMap:
|
|||
"transformer_encoder.{bid}.wo", # neobert
|
||||
"model.transformer.blocks.{bid}.attn_out", # llada
|
||||
"layers.{bid}.self_attn.o_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.o_proj", # nemotron-h
|
||||
),
|
||||
|
||||
# Attention output norm
|
||||
|
|
@ -387,6 +391,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.block_sparse_moe.up", # smallthinker
|
||||
"model.transformer.blocks.{bid}.up_proj", # llada
|
||||
"layers.{bid}.mlp.up_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.up_proj", # nemotron-h
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_UP_EXP: (
|
||||
|
|
@ -480,6 +485,7 @@ class TensorNameMap:
|
|||
"model.layers.{bid}.block_sparse_moe.down", # smallthinker
|
||||
"model.transformer.blocks.{bid}.ff_out", # llada
|
||||
"layers.{bid}.mlp.down_proj", # qwen3-embedding
|
||||
"backbone.layers.{bid}.mixer.down_proj", # nemotron-h
|
||||
),
|
||||
|
||||
MODEL_TENSOR.FFN_DOWN_EXP: (
|
||||
|
|
|
|||
|
|
@ -69,6 +69,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
|
|||
{ LLM_ARCH_T5ENCODER, "t5encoder" },
|
||||
{ LLM_ARCH_JAIS, "jais" },
|
||||
{ LLM_ARCH_NEMOTRON, "nemotron" },
|
||||
{ LLM_ARCH_NEMOTRON_H, "nemotron_h" },
|
||||
{ LLM_ARCH_EXAONE, "exaone" },
|
||||
{ LLM_ARCH_EXAONE4, "exaone4" },
|
||||
{ LLM_ARCH_RWKV6, "rwkv6" },
|
||||
|
|
@ -1550,6 +1551,31 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
|
|||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_NEMOTRON_H,
|
||||
{
|
||||
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
|
||||
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
|
||||
{ LLM_TENSOR_OUTPUT, "output" },
|
||||
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
|
||||
// mamba(2) ssm layers
|
||||
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
|
||||
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
|
||||
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
|
||||
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
|
||||
{ LLM_TENSOR_SSM_D, "blk.%d.ssm_d" },
|
||||
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
|
||||
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
|
||||
// attention layers
|
||||
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
|
||||
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
|
||||
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
|
||||
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
|
||||
// dense FFN
|
||||
{ LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" },
|
||||
{ LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" },
|
||||
},
|
||||
},
|
||||
{
|
||||
LLM_ARCH_EXAONE,
|
||||
{
|
||||
|
|
@ -2355,6 +2381,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
|
|||
case LLM_ARCH_PLAMO2:
|
||||
case LLM_ARCH_GRANITE_HYBRID:
|
||||
case LLM_ARCH_LFM2:
|
||||
case LLM_ARCH_NEMOTRON_H:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ enum llm_arch {
|
|||
LLM_ARCH_T5ENCODER,
|
||||
LLM_ARCH_JAIS,
|
||||
LLM_ARCH_NEMOTRON,
|
||||
LLM_ARCH_NEMOTRON_H,
|
||||
LLM_ARCH_EXAONE,
|
||||
LLM_ARCH_EXAONE4,
|
||||
LLM_ARCH_RWKV6,
|
||||
|
|
|
|||
|
|
@ -788,6 +788,7 @@ const struct ggml_tensor * llama_model_loader::check_tensor_dims(const std::stri
|
|||
}
|
||||
|
||||
struct ggml_tensor * llama_model_loader::create_tensor(struct ggml_context * ctx, const std::string & name, const std::initializer_list<int64_t> & ne, int flags) {
|
||||
LLAMA_LOG_DEBUG("%s: loading tensor %s\n", __func__, name.c_str());
|
||||
const struct ggml_tensor * cur = check_tensor_dims(name, ne, !(flags & TENSOR_NOT_REQUIRED));
|
||||
|
||||
if (cur == NULL) {
|
||||
|
|
|
|||
|
|
@ -1570,6 +1570,27 @@ void llama_model::load_hparams(llama_model_loader & ml) {
|
|||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_NEMOTRON_H:
|
||||
{
|
||||
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
|
||||
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
|
||||
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
|
||||
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
|
||||
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
|
||||
|
||||
// A layer is recurrent IFF the n_head_kv value is set to 0 and
|
||||
// the n_ff value is set to 0
|
||||
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
|
||||
hparams.recurrent_layer_arr[i] = (hparams.n_head_kv(i) == 0 && hparams.n_ff(i) == 0);
|
||||
}
|
||||
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
||||
switch (hparams.n_layer) {
|
||||
case 56: type = LLM_TYPE_9B; break;
|
||||
default: type = LLM_TYPE_UNKNOWN;
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_EXAONE:
|
||||
{
|
||||
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
|
||||
|
|
@ -4688,6 +4709,75 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
|
|||
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_NEMOTRON_H:
|
||||
{
|
||||
// mamba2 Mixer SSM params
|
||||
// NOTE: int64_t for tensor dimensions
|
||||
const int64_t d_conv = hparams.ssm_d_conv;
|
||||
const int64_t d_inner = hparams.ssm_d_inner;
|
||||
const int64_t d_state = hparams.ssm_d_state;
|
||||
const int64_t n_ssm_head = hparams.ssm_dt_rank;
|
||||
const int64_t n_group = hparams.ssm_n_group;
|
||||
const int64_t d_in_proj = 2*d_inner + 2*n_group*d_state + n_ssm_head;
|
||||
|
||||
// embeddings
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
||||
// output
|
||||
{
|
||||
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
|
||||
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
|
||||
// if output is NULL, init from the input tok embed, duplicated to allow offloading
|
||||
if (output == NULL) {
|
||||
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_layer; ++i) {
|
||||
auto & layer = layers[i];
|
||||
|
||||
// all blocks use the attn norm
|
||||
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0);
|
||||
|
||||
if (hparams.is_recurrent(i)) {
|
||||
// ssm layers
|
||||
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), {n_embd, d_in_proj}, 0);
|
||||
|
||||
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), {d_conv, d_inner + 2*n_group*d_state}, 0);
|
||||
layer.ssm_conv1d_b = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "bias", i), {d_inner + 2*n_group*d_state}, TENSOR_NOT_REQUIRED);
|
||||
|
||||
layer.ssm_dt_b = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), {n_ssm_head}, 0);
|
||||
|
||||
// no "weight" suffix for these
|
||||
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), {1, n_ssm_head}, 0);
|
||||
layer.ssm_d = create_tensor(tn(LLM_TENSOR_SSM_D, i), {1, n_ssm_head}, 0);
|
||||
|
||||
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), {d_inner / n_group, n_group}, 0);
|
||||
|
||||
// out_proj
|
||||
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), {d_inner, n_embd}, 0);
|
||||
} else if (hparams.n_ff(i) == 0) {
|
||||
// attention layers (with optional bias)
|
||||
const int64_t n_head_i = hparams.n_head(i);
|
||||
const int64_t n_embd_k_gqa_i = hparams.n_embd_k_gqa(i);
|
||||
const int64_t n_embd_v_gqa_i = hparams.n_embd_v_gqa(i);
|
||||
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head_i}, 0);
|
||||
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa_i}, 0);
|
||||
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa_i}, 0);
|
||||
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head_i, n_embd}, 0);
|
||||
layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_k_gqa_i}, TENSOR_NOT_REQUIRED);
|
||||
layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_v_gqa_i}, TENSOR_NOT_REQUIRED);
|
||||
layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
} else {
|
||||
// mlp layers
|
||||
layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { hparams.n_ff(i), n_embd}, 0);
|
||||
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, hparams.n_ff(i)}, 0);
|
||||
layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED);
|
||||
layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {hparams.n_ff(i)}, TENSOR_NOT_REQUIRED);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case LLM_ARCH_EXAONE:
|
||||
{
|
||||
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
|
||||
|
|
@ -5862,7 +5952,8 @@ void llama_model::print_info() const {
|
|||
arch == LLM_ARCH_JAMBA ||
|
||||
arch == LLM_ARCH_FALCON_H1 ||
|
||||
arch == LLM_ARCH_PLAMO2 ||
|
||||
arch == LLM_ARCH_GRANITE_HYBRID) {
|
||||
arch == LLM_ARCH_GRANITE_HYBRID ||
|
||||
arch == LLM_ARCH_NEMOTRON_H) {
|
||||
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
|
||||
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
|
||||
LLAMA_LOG_INFO("%s: ssm_d_state = %u\n", __func__, hparams.ssm_d_state);
|
||||
|
|
@ -14129,6 +14220,138 @@ struct llm_build_nemotron : public llm_graph_context {
|
|||
}
|
||||
};
|
||||
|
||||
struct llm_build_nemotron_h : public llm_graph_context_mamba {
|
||||
llm_build_nemotron_h(
|
||||
const llama_model & model,
|
||||
const llm_graph_params & params) :
|
||||
llm_graph_context_mamba(params) {
|
||||
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
GGML_ASSERT(n_embd_head == hparams.n_embd_head_k);
|
||||
|
||||
ggml_tensor * cur;
|
||||
ggml_tensor * inpL;
|
||||
|
||||
inpL = build_inp_embd(model.tok_embd);
|
||||
|
||||
auto * inp = build_inp_mem_hybrid();
|
||||
|
||||
ggml_tensor * inp_out_ids = build_inp_out_ids();
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
// norm
|
||||
cur = build_norm(inpL,
|
||||
model.layers[il].attn_norm, NULL,
|
||||
LLM_NORM_RMS, il);
|
||||
cb(cur, "attn_norm", il);
|
||||
|
||||
if (hparams.is_recurrent(il)) {
|
||||
// ssm layer //
|
||||
cur = build_mamba2_layer(inp->get_recr(), cur, model, ubatch, il);
|
||||
} else if (hparams.n_ff(il) == 0) {
|
||||
// attention layer //
|
||||
cur = build_attention_layer(cur, inp->get_attn(), model, n_embd_head, il);
|
||||
} else {
|
||||
cur = build_ffn_layer(cur, model, il);
|
||||
}
|
||||
|
||||
if (il == n_layer - 1 && inp_out_ids) {
|
||||
cur = ggml_get_rows(ctx0, cur, inp_out_ids);
|
||||
inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids);
|
||||
}
|
||||
|
||||
// add residual
|
||||
cur = ggml_add(ctx0, cur, inpSA);
|
||||
cb(cur, "block_out", il);
|
||||
|
||||
// input for next layer
|
||||
inpL = cur;
|
||||
}
|
||||
|
||||
cur = inpL;
|
||||
|
||||
cur = build_norm(cur,
|
||||
model.output_norm, NULL,
|
||||
LLM_NORM_RMS, -1);
|
||||
|
||||
cb(cur, "result_norm", -1);
|
||||
res->t_embd = cur;
|
||||
|
||||
// lm_head
|
||||
cur = build_lora_mm(model.output, cur);
|
||||
cb(cur, "result_output", -1);
|
||||
res->t_logits = cur;
|
||||
|
||||
ggml_build_forward_expand(gf, cur);
|
||||
}
|
||||
|
||||
ggml_tensor * build_attention_layer(
|
||||
ggml_tensor * cur,
|
||||
llm_graph_input_attn_kv * inp_attn,
|
||||
const llama_model & model,
|
||||
const int64_t n_embd_head,
|
||||
const int il) {
|
||||
|
||||
// compute Q and K and (optionally) RoPE them
|
||||
ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur);
|
||||
cb(Qcur, "Qcur", il);
|
||||
if (model.layers[il].bq) {
|
||||
Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq);
|
||||
cb(Qcur, "Qcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur);
|
||||
cb(Kcur, "Kcur", il);
|
||||
if (model.layers[il].bk) {
|
||||
Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk);
|
||||
cb(Kcur, "Kcur", il);
|
||||
}
|
||||
|
||||
ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur);
|
||||
cb(Vcur, "Vcur", il);
|
||||
if (model.layers[il].bv) {
|
||||
Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv);
|
||||
cb(Vcur, "Vcur", il);
|
||||
}
|
||||
|
||||
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, hparams.n_head(il), n_tokens);
|
||||
Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
|
||||
Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, hparams.n_head_kv(il), n_tokens);
|
||||
|
||||
cb(Qcur, "Qcur", il);
|
||||
cb(Kcur, "Kcur", il);
|
||||
cb(Vcur, "Vcur", il);
|
||||
|
||||
const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f/sqrtf(float(n_embd_head)) : hparams.f_attention_scale;
|
||||
cur = build_attn(inp_attn,
|
||||
model.layers[il].wo, model.layers[il].bo,
|
||||
Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il);
|
||||
cb(cur, "attn_out", il);
|
||||
return cur;
|
||||
}
|
||||
|
||||
ggml_tensor * build_ffn_layer(
|
||||
ggml_tensor * cur,
|
||||
const llama_model & model,
|
||||
const int il) {
|
||||
|
||||
cur = build_ffn(cur,
|
||||
model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL,
|
||||
NULL, NULL, NULL,
|
||||
model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL,
|
||||
NULL,
|
||||
LLM_FFN_RELU_SQR, LLM_FFN_PAR, il);
|
||||
cb(cur, "ffn_out", il);
|
||||
|
||||
cur = build_cvec(cur, il);
|
||||
cb(cur, "l_out", il);
|
||||
|
||||
return cur;
|
||||
}
|
||||
};
|
||||
|
||||
struct llm_build_exaone : public llm_graph_context {
|
||||
llm_build_exaone(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) {
|
||||
const int64_t n_embd_head = hparams.n_embd_head_v;
|
||||
|
|
@ -18277,6 +18500,23 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
cparams.n_seq_max,
|
||||
nullptr);
|
||||
} else if (llm_arch_is_hybrid(arch)) {
|
||||
|
||||
// The main difference between hybrid architectures is the
|
||||
// layer filters, so pick the right one here
|
||||
llama_memory_hybrid::layer_filter_cb filter_attn = nullptr;
|
||||
llama_memory_hybrid::layer_filter_cb filter_recr = nullptr;
|
||||
if (arch == LLM_ARCH_FALCON_H1) {
|
||||
filter_attn = [&](int32_t) { return true; };
|
||||
filter_recr = [&](int32_t) { return true; };
|
||||
} else if (arch == LLM_ARCH_NEMOTRON_H) {
|
||||
filter_attn = [&](int32_t il) {
|
||||
return !hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
|
||||
};
|
||||
filter_recr = [&](int32_t il) {
|
||||
return hparams.is_recurrent(il) && hparams.n_ff(il) == 0;
|
||||
};
|
||||
}
|
||||
|
||||
const auto padding = llama_kv_cache::get_padding(cparams);
|
||||
|
||||
cparams.n_ctx = GGML_PAD(cparams.n_ctx, padding);
|
||||
|
|
@ -18296,8 +18536,8 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params,
|
|||
/* n_seq_max */ cparams.n_seq_max,
|
||||
/* offload */ cparams.offload_kqv,
|
||||
/* unified */ cparams.kv_unified,
|
||||
/* filter_attn */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr,
|
||||
/* filter_recr */ (arch == LLM_ARCH_FALCON_H1) ? [&](int32_t) { return true; } : (llama_memory_hybrid::layer_filter_cb)nullptr);
|
||||
/* filter_attn */ std::move(filter_attn),
|
||||
/* filter_recr */ std::move(filter_recr));
|
||||
} else {
|
||||
const auto padding = llama_kv_cache::get_padding(cparams);
|
||||
|
||||
|
|
@ -18625,6 +18865,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
|
|||
{
|
||||
llm = std::make_unique<llm_build_nemotron>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_NEMOTRON_H:
|
||||
{
|
||||
llm = std::make_unique<llm_build_nemotron_h>(*this, params);
|
||||
} break;
|
||||
case LLM_ARCH_EXAONE:
|
||||
{
|
||||
llm = std::make_unique<llm_build_exaone>(*this, params);
|
||||
|
|
@ -18860,6 +19104,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
|
|||
case LLM_ARCH_RWKV7:
|
||||
case LLM_ARCH_ARWKV7:
|
||||
case LLM_ARCH_WAVTOKENIZER_DEC:
|
||||
case LLM_ARCH_NEMOTRON_H:
|
||||
return LLAMA_ROPE_TYPE_NONE;
|
||||
|
||||
// use what we call a normal RoPE, operating on pairs of consecutive head values
|
||||
|
|
|
|||
Loading…
Reference in New Issue