Merge be3c4bbe99 into 9e2e2198b0
This commit is contained in:
commit
789120b6b9
|
|
@ -511,22 +511,26 @@ class ModelBase:
|
|||
return name == (key_name + suffix)
|
||||
|
||||
def map_tensor_name(self, name: str, try_suffixes: Sequence[str] = (".weight", ".bias")) -> str:
    """Translate a source tensor name into its mapped name via ``self.tensor_map``.

    The name is looked up exactly as given first.  If that fails and the name
    starts with ``model.language_model.`` or ``language_model.``, the lookup
    is retried with that prefix removed, both with and without a leading
    ``model.``.

    Args:
        name: Tensor name to map.
        try_suffixes: Suffixes forwarded unchanged to ``tensor_map.get_name``.

    Returns:
        The first non-``None`` result returned by ``tensor_map.get_name``.

    Raises:
        ValueError: If none of the candidate names can be mapped.
    """
    candidates = [name]

    # Build the fallback spellings.  Order matters: the first hit wins.
    if name.startswith("model.language_model."):
        stripped = name[len("model.language_model."):]
        candidates += [f"model.{stripped}", stripped]
    elif name.startswith("language_model."):
        stripped = name[len("language_model."):]
        candidates += [stripped, f"model.{stripped}"]

    for candidate in candidates:
        mapped = self.tensor_map.get_name(key=candidate, try_suffixes=try_suffixes)
        if mapped is not None:
            return mapped

    # Report the name the caller passed in, not the last candidate tried.
    raise ValueError(f"Can not map tensor {name!r}")
|
||||
|
||||
def set_gguf_parameters(self):
    """Placeholder hook: this version always raises ``NotImplementedError``."""
    message = "set_gguf_parameters() must be implemented in subclasses"
    raise NotImplementedError(message)
|
||||
|
||||
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
|
||||
# skip NVFP4 auxiliary tensors (handled in _generate_nvfp4_tensors)
|
||||
if self._is_nvfp4:
|
||||
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
|
||||
return []
|
||||
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
|
||||
return []
|
||||
|
||||
new_name = self.map_tensor_name(name)
|
||||
|
||||
# Handle gate/up expert tensor fusion if enabled
|
||||
|
|
@ -591,6 +595,98 @@ class ModelBase:
|
|||
def _nvfp4_scale2_is_trivial(scale2: Tensor) -> bool:
|
||||
return scale2.numel() <= 1 and abs(float(scale2.float().sum()) - 1.0) < 1e-6
|
||||
|
||||
def _transform_nvfp4_weight(self, raw_weight_name: str, weight: Tensor, scale: Tensor, bid: int | None) -> tuple[str, Tensor, Tensor]:
    """Map an NVFP4 weight's name and, for selected linear-attention tensors, reorder it.

    For any model other than ``Qwen3_5TextModel`` / ``Qwen3_5MoeTextModel``, or
    any tensor whose name does not end in one of the ``.linear_attn.*``
    projection suffixes listed below, the weight and scale are returned
    untouched alongside the mapped name.

    Otherwise a permutation obtained from
    ``_LinearAttentionVReorderBase._reorder_v_heads`` is applied to both
    ``weight`` and ``scale``:

    * ``in_proj_qkv``: only the rows after the Q and K sections are reordered.
    * ``in_proj_z``: all rows, with ``linear_value_head_dim`` rows per head.
    * ``in_proj_a`` / ``in_proj_b``: all rows, one row per head.
    * ``out_proj``: columns, moved in whole groups of 16 together with the
      matching entries on the last axis of ``scale``.

    Args:
        raw_weight_name: Tensor name before mapping.
        weight: Packed weight; the code treats each element on the last axis
            as two 4-bit codes (low nibble first).
        scale: Scale tensor; row-indexed like ``weight`` on axis 0, and its
            last axis is checked to hold one entry per 16 weight columns.
        bid: Not used by this method.

    Returns:
        ``(mapped_name, weight, scale)``.

    Raises:
        KeyError: If a required ``linear_*`` entry is missing from ``self.hparams``.
        ValueError: Propagated from ``map_tensor_name`` when the name cannot be mapped.
    """
    reordered_suffixes = (
        ".linear_attn.in_proj_qkv.weight",
        ".linear_attn.in_proj_z.weight",
        ".linear_attn.in_proj_a.weight",
        ".linear_attn.in_proj_b.weight",
        ".linear_attn.out_proj.weight",
    )
    if not isinstance(self, (Qwen3_5TextModel, Qwen3_5MoeTextModel)) or not raw_weight_name.endswith(reordered_suffixes):
        return self.map_tensor_name(raw_weight_name), weight, scale

    num_k_heads = self.hparams["linear_num_key_heads"]
    num_v_heads = self.hparams["linear_num_value_heads"]
    head_k_dim = self.hparams["linear_key_head_dim"]
    head_v_dim = self.hparams["linear_value_head_dim"]
    num_v_per_k = num_v_heads // num_k_heads
    new_name = self.map_tensor_name(raw_weight_name)

    def unpack_nibbles(qs: Tensor) -> Tensor:
        # Split every element into (low nibble, high nibble); last axis doubles in length.
        lo = torch.bitwise_and(qs, 0x0F)
        hi = torch.bitwise_right_shift(qs, 4)
        return torch.stack((lo, hi), dim=-1).reshape(*qs.shape[:-1], qs.shape[-1] * 2)

    def pack_nibbles(codes: Tensor) -> Tensor:
        # Inverse of unpack_nibbles: even positions become the low nibble, odd the high nibble.
        codes = codes.reshape(*codes.shape[:-1], codes.shape[-1] // 2, 2)
        lo = torch.bitwise_and(codes[..., 0], 0x0F)
        hi = torch.bitwise_left_shift(torch.bitwise_and(codes[..., 1], 0x0F), 4)
        return torch.bitwise_or(lo, hi).contiguous()

    def apply_col_perm(qs: Tensor, scales: Tensor, col_perm: Tensor) -> tuple[Tensor, Tensor] | None:
        # Permute unpacked columns of `qs` and the matching groups of `scales`.
        # Returns None when the permutation cannot be expressed as a reordering
        # of whole, 16-aligned column groups.
        if qs.ndim < 2 or scales.ndim < 2:
            return None

        k = qs.shape[-1] * 2  # number of 4-bit codes per row
        if col_perm.numel() != k or k % 16 != 0:
            return None

        # Every run of 16 permutation entries must be consecutive ...
        group_cols = col_perm.reshape(-1, 16)
        group_starts = group_cols[:, 0]
        expected = group_starts.unsqueeze(1) + torch.arange(16, dtype=col_perm.dtype)
        if not torch.equal(group_cols, expected):
            return None
        # ... and must start on a multiple of 16.
        if torch.any(group_starts % 16 != 0):
            return None

        # The group-level permutation must cover each scale column exactly once.
        group_perm = (group_starts // 16).to(dtype=torch.long)
        expected_groups = torch.arange(scales.shape[-1], dtype=torch.long)
        if group_perm.numel() != scales.shape[-1] or not torch.equal(torch.sort(group_perm).values, expected_groups):
            return None

        codes = unpack_nibbles(qs)
        codes = codes.index_select(-1, col_perm.to(device=qs.device, dtype=torch.long))
        qs = pack_nibbles(codes)
        scales = scales.index_select(-1, group_perm.to(device=scales.device))
        return qs, scales

    def reorder_rows(qs: Tensor, scales: Tensor, head_dim: int) -> tuple[Tensor, Tensor]:
        # Run an index vector through _reorder_v_heads to obtain the row
        # permutation, then apply that same permutation to both tensors.
        row_perm = _LinearAttentionVReorderBase._reorder_v_heads(
            torch.arange(num_v_heads * head_dim, dtype=torch.long).unsqueeze(-1),
            0, num_k_heads, num_v_per_k, head_dim,
        ).squeeze(-1)
        return (
            qs.index_select(0, row_perm.to(device=qs.device)),
            scales.index_select(0, row_perm.to(device=scales.device)),
        )

    if raw_weight_name.endswith(".linear_attn.in_proj_qkv.weight"):
        # Rows are laid out as [Q | K | V]; Q and K have the same row count.
        q_dim = head_k_dim * num_k_heads
        k_dim = head_k_dim * num_k_heads
        qk_rows = q_dim + k_dim
        q, k, v = weight[:q_dim], weight[q_dim:qk_rows], weight[qk_rows:]
        q_scale, k_scale, v_scale = scale[:q_dim], scale[q_dim:qk_rows], scale[qk_rows:]
        v, v_scale = reorder_rows(v, v_scale, head_v_dim)
        return new_name, torch.cat([q, k, v], dim=0), torch.cat([q_scale, k_scale, v_scale], dim=0)

    if raw_weight_name.endswith(".linear_attn.in_proj_z.weight"):
        weight, scale = reorder_rows(weight, scale, head_v_dim)
    elif raw_weight_name.endswith((".linear_attn.in_proj_a.weight", ".linear_attn.in_proj_b.weight")):
        weight, scale = reorder_rows(weight, scale, 1)
    elif raw_weight_name.endswith(".linear_attn.out_proj.weight"):
        col_perm = _LinearAttentionVReorderBase._reorder_v_heads(
            torch.arange(num_v_heads * head_v_dim, dtype=torch.long).unsqueeze(0),
            1, num_k_heads, num_v_per_k, head_v_dim,
        ).squeeze(0)
        permuted = apply_col_perm(weight, scale, col_perm)
        if permuted is None:
            # Keep the best-effort fallback (tensor passes through unchanged),
            # but make it visible instead of skipping the reorder silently.
            logger.warning(f"Could not apply V-head column reorder to NVFP4 tensor {new_name}; leaving weight and scale columns unchanged")
        else:
            weight, scale = permuted

    return new_name, weight, scale
|
||||
|
||||
def _repack_nvfp4(self, new_name: str, weight: Tensor, scale: Tensor, scale2: Tensor):
|
||||
raw, shape = self._nvfp4_pack(weight, scale)
|
||||
logger.info(f"Repacked {new_name} with shape {shape} and quantization NVFP4")
|
||||
|
|
@ -645,7 +741,9 @@ class ModelBase:
|
|||
if n_experts > 0 and len(expert_blocks[key]) >= n_experts:
|
||||
self._flush_nvfp4_experts(key, expert_blocks, expert_scales, expert_shapes, bid, proj_type)
|
||||
else:
|
||||
new_name = self.map_tensor_name(name)
|
||||
bid_m = re.search(r'\.layers\.(\d+)\.', name)
|
||||
bid = int(bid_m.group(1)) if bid_m else None
|
||||
new_name, weight, scale = self._transform_nvfp4_weight(name, weight, scale, bid)
|
||||
self._repack_nvfp4(new_name, weight, scale, scale2)
|
||||
|
||||
# Flush any remaining experts (fallback if n_experts was unknown)
|
||||
|
|
@ -702,6 +800,12 @@ class ModelBase:
|
|||
if name.endswith((".attention.masked_bias", ".attention.bias", ".rotary_emb.inv_freq")):
|
||||
continue
|
||||
|
||||
if self._is_nvfp4:
|
||||
if name.endswith(".weight") and name.replace(".weight", ".weight_scale") in self.model_tensors:
|
||||
continue
|
||||
if name.endswith((".weight_scale", ".weight_scale_2", ".input_scale", ".k_scale", ".v_scale")):
|
||||
continue
|
||||
|
||||
old_dtype = data_torch.dtype
|
||||
|
||||
# convert any unsupported data types to float32
|
||||
|
|
|
|||
Loading…
Reference in New Issue