150 lines
4.8 KiB
Python
150 lines
4.8 KiB
Python
import argparse
import os
import sys

import requests
import safetensors
import safetensors.torch
import torch
|
|
|
|
# Checkpoint keys that are left untouched by the merge loops.
keys_to_skip_on_merge = [
    "cond_stage_model.transformer.text_model.embeddings.position_ids",
]

# Prefix rewrites applied to checkpoint keys: each listed sub-module of
# "cond_stage_model.transformer." is moved under "...transformer.text_model.".
# Insertion order is: embeddings, encoder, final_layer_norm.
key_replacements = {
    f"cond_stage_model.transformer.{part}.": f"cond_stage_model.transformer.text_model.{part}."
    for part in ("embeddings", "encoder", "final_layer_norm")
}
|
|
|
|
|
|
def transform_checkpoint_key(key):
    """Return *key* with any matching prefix from ``key_replacements`` rewritten.

    Every entry of ``key_replacements`` is tried in order against the current
    (possibly already rewritten) key; keys that match no prefix are returned
    unchanged.
    """
    for old_prefix, new_prefix in key_replacements.items():
        if not key.startswith(old_prefix):
            continue
        remainder = key[len(old_prefix):]
        key = new_prefix + remainder
    return key
|
|
|
|
|
|
def get_state_dict_from_checkpoint(ckpt_dict):
    """Extract a state dict with normalised keys from a loaded checkpoint.

    Args:
        ckpt_dict: Mapping loaded from a checkpoint file. If it contains a
            ``"state_dict"`` entry, that nested mapping is used; otherwise the
            mapping itself is treated as the state dict.

    Returns:
        A new dict whose keys have been passed through
        ``transform_checkpoint_key``. Entries whose transformed key is
        ``None`` are dropped.
    """
    # Use .get rather than .pop so the caller's mapping is not mutated.
    state_dict = ckpt_dict.get("state_dict", ckpt_dict)

    result = {}
    for key, value in state_dict.items():
        # Transform once per key instead of once for the filter and once
        # for the result.
        new_key = transform_checkpoint_key(key)
        if new_key is not None:
            result[new_key] = value
    return result
|
|
|
|
|
|
def download_model(url):
    """Download *url* into ``/tmp`` and return the local file path.

    The file is stored as ``/tmp/<basename of url>``. If that file already
    exists it is returned immediately without any network access.

    Args:
        url: HTTP(S) URL of the file to fetch.

    Returns:
        Path of the downloaded (or previously cached) file.

    Raises:
        ValueError: If no file name can be derived from *url*.
        requests.HTTPError: If the server answers with an error status.
    """
    filename = os.path.basename(url)
    if not filename:
        raise ValueError(f"cannot derive a file name from URL: {url}")

    filepath = f"/tmp/{filename}"
    if os.path.isfile(filepath):
        return filepath

    # Download into a side file and rename on success, so an interrupted
    # transfer is never mistaken for a complete cached model on the next run.
    partial_path = f"{filepath}.part"
    block_size = 1048576  # 1 MB
    downloaded_size = 0

    try:
        with requests.get(url, stream=True) as response:
            # Do not save (and later cache) an HTTP error page as the model.
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))

            with open(partial_path, "wb") as file:
                for data in response.iter_content(block_size):
                    downloaded_size += len(data)
                    file.write(data)
                    if total_size > 0:
                        progress = downloaded_size / total_size * 100
                        print(f"Download progress: {progress:.2f}%")
                    else:
                        # No Content-Length header: a percentage cannot be
                        # computed, so report the byte count instead.
                        print(f"Downloaded {downloaded_size} bytes")

        os.replace(partial_path, filepath)
    finally:
        # After a successful rename the side file is gone; on any failure
        # remove the truncated leftover.
        if os.path.exists(partial_path):
            os.remove(partial_path)

    return filepath
|
|
|
|
|
|
def load_model(checkpoint_file):
    """Load a checkpoint from a local path or an HTTP(S) URL.

    Args:
        checkpoint_file: Local file path, or a URL starting with ``http://``
            or ``https://`` which is first fetched via ``download_model``.

    Returns:
        The state dict produced by ``get_state_dict_from_checkpoint``.
        Tensors are loaded onto the CPU.

    Raises:
        ValueError: If *checkpoint_file* (or the downloaded path) is empty.
    """
    if not checkpoint_file:
        raise ValueError(f"empty filepath for {checkpoint_file}")

    # Only treat real URLs as remote; a substring test would send local files
    # such as "my-http-model.ckpt" to the downloader.
    if checkpoint_file.lower().startswith(("http://", "https://")):
        filepath = download_model(checkpoint_file)
    else:
        filepath = checkpoint_file

    if not filepath:
        raise ValueError(f"empty filepath for {checkpoint_file}")

    _, extension = os.path.splitext(filepath)
    if extension.lower() == ".safetensors":
        model = safetensors.torch.load_file(filepath, device="cpu")
    else:
        # NOTE(security): torch.load unpickles the file, which can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        model = torch.load(filepath, map_location="cpu")

    return get_state_dict_from_checkpoint(model)
|
|
|
|
|
|
def _add_difference(theta0, theta1_2_diff, alpha):
    """Return ``theta0 + alpha * theta1_2_diff``."""
    return theta0 + (alpha * theta1_2_diff)


def _inpainting_output_path(input_model):
    """Return *input_model* with ``-inpainting`` inserted before its extension."""
    root, extension = os.path.splitext(input_model)
    return f"{root}-inpainting{extension}"


def generate(args):
    """Build an inpainting variant of ``args.model`` and save it next to it.

    The function computes ``input - base`` for every matching key, adds that
    difference onto the inpainting checkpoint, and writes the result to
    ``<input root>-inpainting<input extension>``. The base and inpainting
    checkpoints are loaded from the hard-coded URLs below.

    Args:
        args: Namespace with a ``model`` attribute holding the input
            checkpoint path.

    Raises:
        AssertionError: If a layer differs between the inpainting and input
            tensors only in dimension 1, but the sizes there are not 9 and 4.
    """
    input_model = args.model
    inpainting_model_url = "https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt"
    inpainting_model_name = os.path.basename(inpainting_model_url)
    base_model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt"
    base_model_name = os.path.basename(base_model_url)
    # splitext keeps directory components and inner dots intact and copes
    # with names that have no extension, unlike splitting on ".".
    output_model = _inpainting_output_path(input_model)

    print(f"Loading {input_model}")
    input_state_dict = load_model(input_model)

    print(f"Loading {base_model_name}")
    base_state_dict = load_model(base_model_url)

    # Turn the input weights into a difference against the base model.
    # Only existing keys are reassigned, so iterating the dict is safe.
    for key in input_state_dict:
        if key in keys_to_skip_on_merge or "model" not in key:
            continue

        if key in base_state_dict:
            input_state_dict[key] = input_state_dict[key] - base_state_dict[key]
        else:
            # Nothing to compare against: contribute a zero difference.
            input_state_dict[key] = torch.zeros_like(input_state_dict[key])

    # Free the base weights before loading the next checkpoint.
    del base_state_dict

    print(f"Merging {inpainting_model_name} and the above difference")
    inpainting_state_dict = load_model(inpainting_model_url)
    for key in inpainting_state_dict:
        if "model" not in key or key not in input_state_dict:
            continue
        if key in keys_to_skip_on_merge:
            continue

        a = inpainting_state_dict[key]
        b = input_state_dict[key]

        differs_only_in_dim1 = (
            a.shape != b.shape
            and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]
        )
        if differs_only_in_dim1:
            # Raise explicitly so the check survives ``python -O``.
            if not (a.shape[1] == 9 and b.shape[1] == 4):
                raise AssertionError(
                    f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
                )
            # Only the first 4 entries along dimension 1 have a counterpart
            # in the input tensor; the remaining ones are kept as they are.
            inpainting_state_dict[key][:, 0:4, :, :] = _add_difference(
                a[:, 0:4, :, :], b, 1
            )
        else:
            inpainting_state_dict[key] = _add_difference(a, b, 1)

    del input_state_dict

    print("Saving the model")
    _, extension = os.path.splitext(output_model)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(inpainting_state_dict, output_model)
    else:
        torch.save(inpainting_state_dict, output_model)
|
|
|
|
|
|
def main():
    """Read the positional ``model`` argument from the command line and run ``generate``."""
    cli = argparse.ArgumentParser()
    cli.add_argument("model", type=str)
    generate(cli.parse_args())


# Run the command-line entry point only when executed as a script.
if __name__ == "__main__":
    main()
|