convert: allow using quantized Mistral weight (#17889)
* convert: allow using quantized Mistral weight
* data_torch.ndim
* update dequant fn

Co-authored-by: compilade <compilade@users.noreply.github.com>
parent 2e9eab80c2
commit 9e79b0116e
@@ -383,6 +383,17 @@ class ModelBase:
                     s = self.model_tensors[name]
                     self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
                     tensors_to_remove.append(name)
+                if name.endswith(".activation_scale"): # unused
+                    tensors_to_remove.append(name)
+                # mistral format
+                if name.endswith(".qscale_weight"):
+                    weight_name = name.removesuffix("qscale_weight") + "weight"
+                    w = self.model_tensors[weight_name]
+                    s = self.model_tensors[name]
+                    self.model_tensors[weight_name] = lambda w=w, s=s, bs=block_size: dequant_simple(w(), s(), bs)
+                    tensors_to_remove.append(name)
+                if name.endswith(".qscale_act"):
+                    tensors_to_remove.append(name)
             elif quant_method == "gptq":
                 for name in self.model_tensors.keys():
                     if name.endswith(".qweight"):
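The fp8 branch above folds each scale tensor into its weight lazily: the `lambda w=w, s=s, bs=block_size:` default arguments capture the current loop values so every rebound entry dequantizes the right weight/scale pair. For orientation only, here is a minimal sketch of what a block-wise helper like dequant_simple could look like; the function name exists in the converter, but this body and signature are assumptions, not the actual implementation. For the Mistral fp8 path the weight_block_size is None (see the MistralModel.dequant_model hunk below), so the per-tensor branch would apply.

# Sketch only: assumed behavior of a simple dequantization helper that
# multiplies a quantized weight by its (optionally block-wise) scale.
import torch
from torch import Tensor

def dequant_simple_sketch(weight: Tensor, scale: Tensor, block_size: list[int] | None) -> Tensor:
    # Upcast first so fp8_e4m3 weights can be multiplied safely.
    weight = weight.to(torch.float32)
    scale = scale.to(torch.float32)
    if block_size is None:
        # Per-tensor (or broadcastable per-channel) scale.
        return weight * scale
    # Block-wise scale: expand each scale entry over its block, then trim
    # to the weight's shape in case the last block is partial.
    rows, cols = block_size
    scale = scale.repeat_interleave(rows, dim=0)[: weight.shape[0]]
    scale = scale.repeat_interleave(cols, dim=1)[:, : weight.shape[1]]
    return weight * scale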
@@ -2854,13 +2865,10 @@ class Mistral3Model(LlamaModel):
             self.gguf_writer.add_attn_temperature_scale(rope_params["llama_4_scaling_beta"])
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
-        # TODO: probably not worth supporting quantized weight, as official BF16 is also available
-        if name.endswith("weight_scale_inv"):
-            raise ValueError("This is a quantized weight, please use BF16 weight instead")
         name = name.replace("language_model.", "")
         if "multi_modal_projector" in name or "vision_tower" in name:
             return []
 
         return super().modify_tensors(data_torch, name, bid)
 
 
@@ -9898,6 +9906,18 @@ class MistralModel(LlamaModel):
         self.gguf_writer.add_architecture()
         self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
 
+    def dequant_model(self):
+        # transform quantization config into HF format
+        quant_config = self.hparams.get("quantization")
+        if quant_config is not None:
+            assert quant_config["qformat_weight"] == "fp8_e4m3"
+            self.hparams["quantization_config"] = {
+                "activation_scheme": "static",
+                "quant_method": "fp8",
+                "weight_block_size": None,
+            }
+        return super().dequant_model()
+
     @staticmethod
     def get_community_chat_template(vocab: MistralVocab, templates_dir: Path, is_mistral_format: bool):
         assert TokenizerVersion is not None and Tekkenizer is not None and SentencePieceTokenizer is not None, _mistral_import_error_msg
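The new MistralModel.dequant_model override only translates the Mistral-native "quantization" block from the checkpoint's hparams into the HF-style "quantization_config" that the generic fp8 path in ModelBase.dequant_model already handles. A standalone sketch of that translation, using a hypothetical minimal hparams dict (only the "qformat_weight" key is checked by the converter), looks like this:

# Hypothetical params.json fragment for an fp8 Mistral checkpoint.
hparams = {"quantization": {"qformat_weight": "fp8_e4m3"}}

quant_config = hparams.get("quantization")
if quant_config is not None:
    assert quant_config["qformat_weight"] == "fp8_e4m3"
    # Rewritten into the HF-style block handled by the base class:
    hparams["quantization_config"] = {
        "activation_scheme": "static",
        "quant_method": "fp8",
        "weight_block_size": None,
    }

print(hparams["quantization_config"]["quant_method"])  # -> fp8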