Updated copilot concerns and rebased

This commit is contained in:
Michael Wand 2026-03-17 10:42:39 -07:00
parent ddab59e9c8
commit 84c04f0a01
1 changed files with 6 additions and 7 deletions

View File

@ -616,7 +616,7 @@ class ModelBase:
def _nvfp4_scale2_is_trivial(scale2: Tensor) -> bool: def _nvfp4_scale2_is_trivial(scale2: Tensor) -> bool:
return scale2.numel() <= 1 and abs(float(scale2.float().sum()) - 1.0) < 1e-6 return scale2.numel() <= 1 and abs(float(scale2.float().sum()) - 1.0) < 1e-6
def _transform_nvfp4_weight(self, raw_weight_name: str, weight: Tensor, scale: Tensor, bid: int | None) -> tuple[str, Tensor, Tensor]: def _transform_nvfp4_weight(self, raw_weight_name: str, weight: Tensor, scale: Tensor) -> tuple[str, Tensor, Tensor]:
if not isinstance(self, (Qwen3_5TextModel, Qwen3_5MoeTextModel)) or not raw_weight_name.endswith(( if not isinstance(self, (Qwen3_5TextModel, Qwen3_5MoeTextModel)) or not raw_weight_name.endswith((
".linear_attn.in_proj_qkv.weight", ".linear_attn.in_proj_qkv.weight",
".linear_attn.in_proj_z.weight", ".linear_attn.in_proj_z.weight",
@ -703,8 +703,9 @@ class ModelBase:
1, num_k_heads, num_v_per_k, head_v_dim, 1, num_k_heads, num_v_per_k, head_v_dim,
).squeeze(0) ).squeeze(0)
transformed_components = apply_col_perm(weight, scale, col_perm) transformed_components = apply_col_perm(weight, scale, col_perm)
if transformed_components is not None: if transformed_components is None:
weight, scale = transformed_components raise ValueError(f"Can not apply NVFP4 Quwen3.5 permutation for tensor {raw_weight_name!r}")
weight, scale = transformed_components
return new_name, weight, scale return new_name, weight, scale
@ -773,9 +774,7 @@ class ModelBase:
if n_experts > 0 and len(expert_blocks[key]) >= n_experts: if n_experts > 0 and len(expert_blocks[key]) >= n_experts:
self._flush_nvfp4_experts(key, expert_blocks, expert_scales, expert_shapes, bid, proj_type) self._flush_nvfp4_experts(key, expert_blocks, expert_scales, expert_shapes, bid, proj_type)
else: else:
bid_m = re.search(r'\.layers\.(\d+)\.', name) new_name, weight, scale = self._transform_nvfp4_weight(name, weight, scale)
bid = int(bid_m.group(1)) if bid_m else None
new_name, weight, scale = self._transform_nvfp4_weight(name, weight, scale, bid)
self._repack_nvfp4(new_name, weight, scale, scale2) self._repack_nvfp4(new_name, weight, scale, scale2)
# Flush any remaining experts (fallback if n_experts was unknown) # Flush any remaining experts (fallback if n_experts was unknown)
@ -855,7 +854,7 @@ class ModelBase:
if self._is_nvfp4: if self._is_nvfp4:
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors: if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
continue continue
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale")): if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
continue continue
old_dtype = data_torch.dtype old_dtype = data_torch.dtype