convert : fix big-endian conversion (#17431)

* Fix convert_hf_to_gguf.py script on s390x

Assume converted model data is originally little-endian. On s390x, byteswap the
data after reading it so that values are in the correct representation for any
transformation that follows, such as computing weight tensors.
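
A minimal sketch of that read path, assuming little-endian float32 data as stored
in a .safetensors file (the helper name to_native is invented here for illustration,
it is not part of the converter):

import sys
import numpy as np

def to_native(raw: bytes, dtype: np.dtype) -> np.ndarray:
    # bytes as stored on disk, interpreted in the host's byte order
    arr = np.frombuffer(raw, dtype=dtype)
    if sys.byteorder == 'big':
        # on s390x, copy-and-swap so each element reads correctly
        arr = arr.byteswap(inplace=False)
    return arr

le_payload = np.arange(4, dtype='<f4').tobytes()  # little-endian float32 payload
print(to_native(le_payload, np.dtype(np.float32)))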

Then byteswap the data back to little-endian before passing it to GGUFWriter;
GGUFWriter will byteswap it to big-endian again if big-endian output is requested.

byteswap(inplace=True) calls don't work with the lazy tensor and array wrappers.
Use a copying byteswap (inplace=False) to work around this behaviour.
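
A rough analogy of why the copying variant is needed, using a plain read-only
NumPy buffer rather than the converter's lazy wrappers:

import numpy as np

# frombuffer over immutable bytes yields a read-only array, similar in spirit
# to data coming out of the lazy/mmap wrappers
arr = np.frombuffer(np.arange(3, dtype=np.int32).tobytes(), dtype=np.int32)

try:
    arr.byteswap(inplace=True)          # ValueError: the backing buffer is read-only
except ValueError as exc:
    print("in-place byteswap rejected:", exc)

swapped = arr.byteswap(inplace=False)   # works: swaps a fresh copy instead
print(swapped.dtype, swapped.shape)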

* Make GGUFWriter accept tensors in native endianness instead of little-endian

With this change, when no byteswapping is actually needed, two redundant byteswaps are avoided on s390x.
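
The rule boils down to comparing the requested file endianness with the host's;
a sketch (needs_byteswap is an illustrative helper, not part of gguf-py):

import sys

def needs_byteswap(requested: str) -> bool:
    # requested is 'big' or 'little'; with native-endian input, a swap is
    # only needed when the target differs from the host byte order
    return requested != sys.byteorder

for requested in ('little', 'big'):
    action = 'byteswap once' if needs_byteswap(requested) else 'write as-is'
    print(f"host={sys.byteorder} target={requested}: {action}")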

* Fix byteswapping in convert_hf_to_gguf.py for remote models
Authored by Aleksei Nikiforov on 2025-11-25 14:18:16 +01:00, committed by GitHub
parent 55ab25caf5
commit 05872ac885
2 changed files with 42 additions and 6 deletions

convert_hf_to_gguf.py

@@ -10061,6 +10061,25 @@ class LazyTorchTensor(gguf.LazyBase):
         torch.uint8: np.uint8,
     }
 
+    # only used when byteswapping data. Only correct size is needed
+    _dtype_byteswap_map: dict[torch.dtype, type] = {
+        torch.float64: np.float64,
+        torch.float32: np.float32,
+        torch.bfloat16: np.float16,
+        torch.float16: np.float16,
+        torch.int64: np.int64,
+        torch.uint64: np.uint64,
+        torch.int32: np.int32,
+        torch.uint32: np.uint32,
+        torch.int16: np.int16,
+        torch.uint16: np.uint16,
+        torch.int8: np.int8,
+        torch.uint8: np.uint8,
+        torch.bool: np.uint8,
+        torch.float8_e4m3fn: np.uint8,
+        torch.float8_e5m2: np.uint8,
+    }
+
     # used for safetensors slices
     # ref: https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/src/lib.rs#L1046
     # TODO: uncomment U64, U32, and U16, ref: https://github.com/pytorch/pytorch/issues/58734
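
As a side note on _dtype_byteswap_map above (illustration only, not code from the
patch): byteswapping just reverses the bytes of each element, so the mapped numpy
type only has to match the torch element size; e.g. bfloat16 data can be viewed as
np.float16 purely to get a 2-byte swap, without ever interpreting the values.

import numpy as np

# two stand-in bfloat16 elements, kept as raw 16-bit payloads
payload = np.array([0x3F80, 0x4000], dtype=np.uint16)

# any 2-byte view is enough for the swap; the values themselves are never read
swapped = payload.view(np.float16).byteswap(inplace=False).view(np.uint16)
print([hex(v) for v in swapped])   # each element's two bytes are reversed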
@@ -10104,8 +10123,14 @@ class LazyTorchTensor(gguf.LazyBase):
     @classmethod
     def from_local_tensor(cls, t: gguf.utility.LocalTensor) -> Tensor:
         def load_tensor(tensor: gguf.utility.LocalTensor) -> Tensor:
+            def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
+                if sys.byteorder == 'big':
+                    # switch data back to big endian
+                    tensor = tensor.view(dtype).byteswap(inplace=False)
+                return tensor
             dtype = cls._dtype_str_map[tensor.dtype]
-            return torch.from_numpy(tensor.mmap_bytes()).view(dtype).reshape(tensor.shape)
+            numpy_dtype = cls._dtype_byteswap_map[dtype]
+            return torch.from_numpy(byteswap_tensor(tensor.mmap_bytes(), numpy_dtype)).view(dtype).reshape(tensor.shape)
         dtype = cls._dtype_str_map[t.dtype]
         shape = t.shape
         lazy = cls(meta=cls.meta_with_dtype_and_shape(dtype, shape), args=(t,), func=lambda r: load_tensor(r))
@@ -10113,10 +10138,16 @@ class LazyTorchTensor(gguf.LazyBase):
     @classmethod
     def from_remote_tensor(cls, remote_tensor: gguf.utility.RemoteTensor):
+        def byteswap_tensor(tensor: np.ndarray, dtype: type) -> np.ndarray:
+            if sys.byteorder == 'big':
+                # switch data back to big endian
+                tensor = tensor.view(dtype).byteswap(inplace=False)
+            return tensor
         dtype = cls._dtype_str_map[remote_tensor.dtype]
+        numpy_dtype = cls._dtype_byteswap_map[dtype]
         shape = remote_tensor.shape
         meta = cls.meta_with_dtype_and_shape(dtype, shape)
-        lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.frombuffer(r.data(), dtype=dtype).reshape(shape))
+        lazy = cls(meta=meta, args=(remote_tensor,), func=lambda r: torch.from_numpy(byteswap_tensor(np.frombuffer(r.data(), dtype=numpy_dtype), numpy_dtype)).view(dtype).reshape(shape))
         return cast(torch.Tensor, lazy)
 
     @classmethod

gguf-py/gguf/gguf_writer.py

@@ -4,6 +4,7 @@ import logging
 import os
 import shutil
 import struct
+import sys
 import tempfile
 from dataclasses import dataclass
 from enum import Enum, auto
@@ -372,8 +373,10 @@ class GGUFWriter:
         self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
         raw_dtype: GGMLQuantizationType | None = None,
     ) -> None:
-        if self.endianess == GGUFEndian.BIG:
-            tensor.byteswap(inplace=True)
+        if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
+           (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
+            # Don't byteswap inplace since lazy copies cannot handle it
+            tensor = tensor.byteswap(inplace=False)
         if self.use_temp_file and self.temp_file is None:
             fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256 * 1024 * 1024)
             fp.seek(0)
@@ -399,8 +402,10 @@ class GGUFWriter:
             raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
         assert self.fout is not None
 
-        if self.endianess == GGUFEndian.BIG:
-            tensor.byteswap(inplace=True)
+        if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
+           (self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
+            # Don't byteswap inplace since lazy copies cannot handle it
+            tensor = tensor.byteswap(inplace=False)
 
         file_id = -1
         for i, tensors in enumerate(self.tensors):