convert : support mixed-precision ModelOpt models with per-tensor NVFP4/FP8 quantization (#20539)
* support mixed-precision ModelOpt models with per-tensor NVFP4/FP8 quantization * cleanup * fallback --------- Co-authored-by: Sigbjørn Skjæret <sigbjorn.skjaeret@scala.com>
This commit is contained in:
parent
d3936498a3
commit
079e5a45f0
|
|
@ -272,8 +272,9 @@ class ModelBase:
|
||||||
return tensors
|
return tensors
|
||||||
|
|
||||||
def dequant_model(self):
|
def dequant_model(self):
|
||||||
if self._is_nvfp4:
|
# If all quantized tensors were already handled (e.g. pure NVFP4), skip
|
||||||
return # NVFP4 weights are repacked in _generate_nvfp4_tensors
|
if self._is_nvfp4 and not any(k.endswith((".weight_scale", ".weight_scale_inv")) for k in self.model_tensors):
|
||||||
|
return
|
||||||
|
|
||||||
tensors_to_remove: list[str] = []
|
tensors_to_remove: list[str] = []
|
||||||
new_tensors: dict[str, Callable[[], Tensor]] = {}
|
new_tensors: dict[str, Callable[[], Tensor]] = {}
|
||||||
|
|
@ -474,7 +475,20 @@ class ModelBase:
|
||||||
tensors_to_remove.append(base_name + "_zero_point")
|
tensors_to_remove.append(base_name + "_zero_point")
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
|
raise NotImplementedError(f"Quant format {quant_format!r} for method {quant_method!r} is not yet supported")
|
||||||
else:
|
elif quant_method == "modelopt":
|
||||||
|
# Mixed-precision ModelOpt models: NVFP4 tensors are handled by
|
||||||
|
# _generate_nvfp4_tensors; FP8 tensors have 1D weight_scale and
|
||||||
|
# are dequantized here. input_scale tensors are unused.
|
||||||
|
for name in self.model_tensors.keys():
|
||||||
|
if name.endswith(".weight_scale"):
|
||||||
|
weight_name = name.removesuffix("_scale")
|
||||||
|
w = self.model_tensors[weight_name]
|
||||||
|
s = self.model_tensors[name]
|
||||||
|
self.model_tensors[weight_name] = lambda w=w, s=s: dequant_simple(w(), s(), None)
|
||||||
|
tensors_to_remove.append(name)
|
||||||
|
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
|
||||||
|
tensors_to_remove.append(name)
|
||||||
|
elif quant_method is not None:
|
||||||
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
|
raise NotImplementedError(f"Quant method is not yet supported: {quant_method!r}")
|
||||||
|
|
||||||
for name in tensors_to_remove:
|
for name in tensors_to_remove:
|
||||||
|
|
@ -520,12 +534,6 @@ class ModelBase:
|
||||||
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
|
raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses")
|
||||||
|
|
||||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||||
# skip NVFP4 auxiliary tensors (handled in _generate_nvfp4_tensors)
|
|
||||||
if self._is_nvfp4:
|
|
||||||
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
|
|
||||||
return []
|
|
||||||
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
|
|
||||||
return []
|
|
||||||
|
|
||||||
new_name = self.map_tensor_name(name)
|
new_name = self.map_tensor_name(name)
|
||||||
|
|
||||||
|
|
@ -609,6 +617,7 @@ class ModelBase:
|
||||||
expert_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
|
expert_scales: dict[tuple[int, str], list[tuple[int, float]]] = {}
|
||||||
expert_shapes: dict[tuple[int, str], list[int]] = {}
|
expert_shapes: dict[tuple[int, str], list[int]] = {}
|
||||||
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
|
n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=True) or 0
|
||||||
|
consumed: list[str] = []
|
||||||
|
|
||||||
for name in list(self.model_tensors.keys()):
|
for name in list(self.model_tensors.keys()):
|
||||||
if not name.endswith(".weight"):
|
if not name.endswith(".weight"):
|
||||||
|
|
@ -620,8 +629,18 @@ class ModelBase:
|
||||||
# Force eager materialization of lazy tensors
|
# Force eager materialization of lazy tensors
|
||||||
weight = LazyTorchTensor.to_eager(self.model_tensors[name]())
|
weight = LazyTorchTensor.to_eager(self.model_tensors[name]())
|
||||||
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
|
scale = LazyTorchTensor.to_eager(self.model_tensors[scale_name]())
|
||||||
|
|
||||||
|
# Skip non-NVFP4 tensors (e.g. FP8 with per-channel 1D scales)
|
||||||
|
if scale.ndim < 2:
|
||||||
|
continue
|
||||||
|
|
||||||
scale2 = LazyTorchTensor.to_eager(self.model_tensors.get(scale2_name, lambda: torch.tensor(1.0))())
|
scale2 = LazyTorchTensor.to_eager(self.model_tensors.get(scale2_name, lambda: torch.tensor(1.0))())
|
||||||
|
|
||||||
|
# Mark tensors for removal from model_tensors (already written to gguf)
|
||||||
|
consumed.extend([name, scale_name])
|
||||||
|
if scale2_name in self.model_tensors:
|
||||||
|
consumed.append(scale2_name)
|
||||||
|
|
||||||
# Check if this is a per-expert tensor
|
# Check if this is a per-expert tensor
|
||||||
m = re.search(r'\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$', name)
|
m = re.search(r'\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.weight$', name)
|
||||||
if m:
|
if m:
|
||||||
|
|
@ -652,6 +671,15 @@ class ModelBase:
|
||||||
for (bid, proj_type) in list(expert_blocks.keys()):
|
for (bid, proj_type) in list(expert_blocks.keys()):
|
||||||
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_shapes, bid, proj_type)
|
self._flush_nvfp4_experts((bid, proj_type), expert_blocks, expert_scales, expert_shapes, bid, proj_type)
|
||||||
|
|
||||||
|
# Remove consumed tensors so get_tensors/modify_tensors won't see them
|
||||||
|
for name in consumed:
|
||||||
|
self.model_tensors.pop(name, None)
|
||||||
|
|
||||||
|
# Remove unused auxiliary tensors (input_scale, k_scale, v_scale)
|
||||||
|
for name in list(self.model_tensors.keys()):
|
||||||
|
if name.endswith((".input_scale", ".k_scale", ".v_scale")):
|
||||||
|
del self.model_tensors[name]
|
||||||
|
|
||||||
def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes, bid, proj_type):
|
def _flush_nvfp4_experts(self, key, expert_blocks, expert_scales, expert_shapes, bid, proj_type):
|
||||||
experts = expert_blocks.pop(key)
|
experts = expert_blocks.pop(key)
|
||||||
scales = expert_scales.pop(key)
|
scales = expert_scales.pop(key)
|
||||||
|
|
@ -677,20 +705,31 @@ class ModelBase:
|
||||||
def prepare_tensors(self):
|
def prepare_tensors(self):
|
||||||
# detect NVFP4 quantization (ModelOpt format)
|
# detect NVFP4 quantization (ModelOpt format)
|
||||||
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
|
quant_algo = (self.hparams.get("quantization_config") or {}).get("quant_algo")
|
||||||
|
quant_layers = (self.hparams.get("quantization_config") or {}).get("quantized_layers") or {}
|
||||||
quant_config_file = self.dir_model / "hf_quant_config.json"
|
quant_config_file = self.dir_model / "hf_quant_config.json"
|
||||||
|
|
||||||
if not quant_algo and quant_config_file.is_file():
|
if (not quant_algo or not quant_layers) and quant_config_file.is_file():
|
||||||
with open(quant_config_file, "r", encoding="utf-8") as f:
|
with open(quant_config_file, "r", encoding="utf-8") as f:
|
||||||
quant_algo = (json.load(f).get("quantization") or {}).get("quant_algo")
|
quant_config = json.load(f).get("quantization") or {}
|
||||||
|
quant_algo = quant_config.get("quant_algo", quant_algo)
|
||||||
|
quant_layers = quant_config.get("quantized_layers", quant_layers) or {}
|
||||||
|
|
||||||
|
# Some models use per-tensor quant_algo (e.g. "MIXED_PRECISION" with
|
||||||
|
# per-layer NVFP4/FP8) instead of a single global "NVFP4" value.
|
||||||
|
if quant_algo != "NVFP4":
|
||||||
|
if any(v.get("quant_algo") == "NVFP4" for v in quant_layers.values() if isinstance(v, dict)):
|
||||||
|
quant_algo = "NVFP4"
|
||||||
|
|
||||||
self._is_nvfp4 = quant_algo == "NVFP4"
|
self._is_nvfp4 = quant_algo == "NVFP4"
|
||||||
|
|
||||||
self.dequant_model()
|
# NVFP4 weights are repacked and written directly to gguf_writer.
|
||||||
|
# This must run before dequant_model so NVFP4 tensors are removed
|
||||||
# NVFP4 weights are repacked and written directly to gguf_writer
|
# from model_tensors, leaving only non-NVFP4 (e.g. FP8) for dequant.
|
||||||
if self._is_nvfp4:
|
if self._is_nvfp4:
|
||||||
self._generate_nvfp4_tensors()
|
self._generate_nvfp4_tensors()
|
||||||
|
|
||||||
|
self.dequant_model()
|
||||||
|
|
||||||
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
|
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
|
||||||
if self.tensor_map.mapping:
|
if self.tensor_map.mapping:
|
||||||
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue