Convert: Fix NemotronH Config Parsing (#21664)

* fix NemotronH vocab loading by using trust_remote_code for unsupported config patterns

* fix NemotronH tokenizer loading by overriding set_vocab with trust_remote_code
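
As the diff below shows, the converter now passes trust_remote_code=True so that transformers resolves the model's own config class instead of failing inside AutoConfig on unsupported pattern characters. A minimal sketch of the loading pattern, with a hypothetical model_dir path (not taken from the commit):

    from transformers import AutoConfig, AutoTokenizer

    model_dir = "path/to/NemotronH"  # illustrative local checkout, not from the commit
    # Plain AutoConfig/AutoTokenizer parsing of the shipped config can fail on
    # transformers versions that do not recognize its pattern characters;
    # trust_remote_code=True loads the model's own config class instead.
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_dir, trust_remote_code=True)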
Anav Prasad 2026-04-16 10:11:45 +00:00 committed by GitHub
parent 3f7c29d318
commit 03b3d07798
1 changed file with 58 additions and 1 deletion


@@ -10893,7 +10893,64 @@ class NemotronHModel(GraniteHybridModel):
        self.gguf_writer.add_moe_latent_size(latent_size)

    def set_vocab(self):
        super().set_vocab()
        # The NemotronH config uses pattern characters (e.g. '-') that may not
        # be supported by the installed transformers version. AutoTokenizer
        # internally calls AutoConfig which triggers this parsing failure.
        # Using trust_remote_code=True to load the model's own config class.
        tokens: list[str] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(self.dir_model, trust_remote_code=True)

        # Pad vocab size (from Mamba2Model/GraniteHybridModel)
        self.hparams["pad_vocab_size_multiple"] = 8  # Setting this here since GraniteHybridModel.set_vocab() isn't being invoked now.
        # From Mamba2Model.set_vocab():
        vocab_size = self.hparams["vocab_size"]
        pad_vocab = self.hparams.get("pad_vocab_size_multiple", 16)
        # ref: https://stackoverflow.com/a/17511341/22827863
        vocab_size = -(vocab_size // -pad_vocab) * pad_vocab
        self.hparams["vocab_size"] = vocab_size

        assert max(tokenizer.vocab.values()) < vocab_size

        tokpre = self.get_vocab_base_pre(tokenizer)
        reverse_vocab = {id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()}
        added_vocab = tokenizer.get_added_vocab()
        added_tokens_decoder = tokenizer.added_tokens_decoder
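        # added_tokens_decoder maps token id -> AddedToken; its .special and
        # .normalized flags drive the classification in the loop below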
        for i in range(vocab_size):
            if i not in reverse_vocab:
                tokens.append(f"[PAD{i}]")
                toktypes.append(gguf.TokenType.UNUSED)
            else:
                token: str = reverse_vocab[i]
                if token in added_vocab:
                    if not added_tokens_decoder[i].normalized:
                        previous_token = token
                        token = tokenizer.decode(tokenizer.encode(token, add_special_tokens=False))
                        if previous_token != token:
                            logger.info(f"{repr(previous_token)} is encoded and decoded back to {repr(token)} using AutoTokenizer")
                    if added_tokens_decoder[i].special or self.does_token_look_special(token):
                        toktypes.append(gguf.TokenType.CONTROL)
                    else:
                        token = token.replace(b"\xe2\x96\x81".decode("utf-8"), " ")  # pre-normalize user-defined spaces
                        toktypes.append(gguf.TokenType.USER_DEFINED)
                else:
                    toktypes.append(gguf.TokenType.NORMAL)
                tokens.append(token)
        # From TextModel.set_vocab_gpt2():
        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_tokenizer_pre(tokpre)
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)

        # The tokenizer _does_ add a BOS token (via post_processor type
        # TemplateProcessing) but does not set add_bos_token to true in the