Updated copilot concerns and rebased

This commit is contained in:
Michael Wand 2026-03-17 10:42:39 -07:00
parent ddab59e9c8
commit 84c04f0a01
1 changed files with 6 additions and 7 deletions

View File

@ -616,7 +616,7 @@ class ModelBase:
def _nvfp4_scale2_is_trivial(scale2: Tensor) -> bool:
return scale2.numel() <= 1 and abs(float(scale2.float().sum()) - 1.0) < 1e-6
def _transform_nvfp4_weight(self, raw_weight_name: str, weight: Tensor, scale: Tensor, bid: int | None) -> tuple[str, Tensor, Tensor]:
def _transform_nvfp4_weight(self, raw_weight_name: str, weight: Tensor, scale: Tensor) -> tuple[str, Tensor, Tensor]:
if not isinstance(self, (Qwen3_5TextModel, Qwen3_5MoeTextModel)) or not raw_weight_name.endswith((
".linear_attn.in_proj_qkv.weight",
".linear_attn.in_proj_z.weight",
@ -703,8 +703,9 @@ class ModelBase:
1, num_k_heads, num_v_per_k, head_v_dim,
).squeeze(0)
transformed_components = apply_col_perm(weight, scale, col_perm)
if transformed_components is not None:
weight, scale = transformed_components
if transformed_components is None:
raise ValueError(f"Can not apply NVFP4 Quwen3.5 permutation for tensor {raw_weight_name!r}")
weight, scale = transformed_components
return new_name, weight, scale
@ -773,9 +774,7 @@ class ModelBase:
if n_experts > 0 and len(expert_blocks[key]) >= n_experts:
self._flush_nvfp4_experts(key, expert_blocks, expert_scales, expert_shapes, bid, proj_type)
else:
bid_m = re.search(r'\.layers\.(\d+)\.', name)
bid = int(bid_m.group(1)) if bid_m else None
new_name, weight, scale = self._transform_nvfp4_weight(name, weight, scale, bid)
new_name, weight, scale = self._transform_nvfp4_weight(name, weight, scale)
self._repack_nvfp4(new_name, weight, scale, scale2)
# Flush any remaining experts (fallback if n_experts was unknown)
@ -855,7 +854,7 @@ class ModelBase:
if self._is_nvfp4:
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
continue
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale")):
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
continue
old_dtype = data_torch.dtype