convert : avoid dequantizing mxfp4 for GPT-OSS (#16756)
parent 55945d2ef5
commit 5cca2542ac
@@ -8943,6 +8943,13 @@ class SmolLM3Model(LlamaModel):
 class GptOssModel(TextModel):
     model_arch = gguf.MODEL_ARCH.GPT_OSS
 
+    # TODO: remove once MXFP4 is supported more generally
+    def dequant_model(self):
+        quant_config = self.hparams.get("quantization_config")
+        if quant_config is not None and quant_config.get("quant_method") == "mxfp4":
+            return
+        return super().dequant_model()
+
     def transform_nibble_layout(self, tensor):
         assert tensor.dtype == torch.uint8
         assert tensor.shape[-1] == 16
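The new override short-circuits dequantization whenever the checkpoint declares MXFP4 weights, so the packed tensors reach the GGUF writer untouched instead of being expanded to floats and re-quantized. The check keys off the quantization_config block that Hugging Face checkpoints carry in config.json. A minimal standalone sketch of the same test outside the converter (the helper name and file layout are assumptions for illustration, not the converter's API):

import json
from pathlib import Path

def is_mxfp4_checkpoint(model_dir: str) -> bool:
    # Hypothetical helper mirroring the guard above: a missing
    # quantization_config, or one with a different quant_method,
    # falls through to the normal dequantization path.
    config = json.loads((Path(model_dir) / "config.json").read_text())
    quant_config = config.get("quantization_config")
    return quant_config is not None and quant_config.get("quant_method") == "mxfp4"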
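The surrounding context also hints at the data layout being preserved: transform_nibble_layout asserts uint8 tensors with a trailing dimension of 16 bytes, which matches one MXFP4 block of 32 FP4 codes packed two per byte (each block additionally carries a shared E8M0 scale, stored separately). A small sketch of unpacking such a block, assuming a low-nibble-first ordering purely for illustration; the ordering the converter actually expects is whatever transform_nibble_layout produces:

import torch

def unpack_fp4_codes(blocks: torch.Tensor) -> torch.Tensor:
    # blocks: uint8 of shape (..., 16), one 32-element MXFP4 block per row.
    assert blocks.dtype == torch.uint8
    assert blocks.shape[-1] == 16
    lo = blocks & 0x0F  # low nibble of each byte
    hi = blocks >> 4    # high nibble of each byte
    # Interleave (..., 16, 2) -> (..., 32), low nibble first (assumed order).
    return torch.stack((lo, hi), dim=-1).reshape(*blocks.shape[:-1], 32)

packed = torch.randint(0, 256, (4, 16), dtype=torch.uint8)
codes = unpack_fp4_codes(packed)  # shape (4, 32), 4-bit codes in 0..15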