Adds a tool to generate an inpainting model from a standard model. Method derived from AUTOMATIC1111.
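
The merge follows the add-difference recipe from AUTOMATIC1111's checkpoint merger: every matching tensor becomes inpainting + (custom - base), and for the UNet's first conv layer, where the inpainting model has 9 input channels against the standard model's 4, only the first 4 channels are merged.

Usage sketch, assuming the new script is saved as generate_inpainting.py (the diff below does not show its filename):

    python generate_inpainting.py ./models/my-model.safetensors
    # expected to write ./models/my-model-inpainting.safetensors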

HappyZ 2023-06-01 15:13:45 -07:00
parent 4a48ed43a7
commit 37a32ade71
2 changed files with 152 additions and 1 deletion


@@ -10,3 +10,5 @@ transformers==4.28.1
sentencepiece==0.1.99
Flask-Limiter==3.3.1
protobuf==3.20
safetensors==0.3.1
pytorch_lightning==2.0.2


@@ -0,0 +1,149 @@
import argparse
import os

import requests
import safetensors.torch
import torch

# List of checkpoint dictionary keys to skip during merging
keys_to_skip_on_merge = [
    "cond_stage_model.transformer.text_model.embeddings.position_ids"
]

# Dictionary for replacing keys in checkpoint dictionary
key_replacements = {
    "cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.",
    "cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.",
    "cond_stage_model.transformer.final_layer_norm.": "cond_stage_model.transformer.text_model.final_layer_norm.",
}


def transform_checkpoint_key(key):
    # Rewrite legacy CLIP text-encoder key prefixes to the current layout
    for text, replacement in key_replacements.items():
        if key.startswith(text):
            key = replacement + key[len(text):]
    return key


def get_state_dict_from_checkpoint(ckpt_dict):
    # Checkpoints may nest the weights under a "state_dict" key
    ckpt_dict = ckpt_dict.pop("state_dict", ckpt_dict)
    return {transform_checkpoint_key(k): v for k, v in ckpt_dict.items()}


def download_model(url):
    filepath = f"/tmp/{os.path.basename(url)}"
    if os.path.isfile(filepath):
        return filepath
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size = int(response.headers.get("content-length", 0))
    block_size = 1048576  # 1 MB
    downloaded_size = 0
    with open(filepath, "wb") as file:
        for data in response.iter_content(block_size):
            downloaded_size += len(data)
            file.write(data)
            # Report progress after each chunk (skip if size is unknown)
            if total_size:
                progress = downloaded_size / total_size * 100
                print(f"Download progress: {progress:.2f}%")
    return filepath


def load_model(checkpoint_file):
    # Accept either a local path or an http(s) URL
    if checkpoint_file.startswith(("http://", "https://")):
        filepath = download_model(checkpoint_file)
    else:
        filepath = checkpoint_file
    if not filepath:
        raise ValueError(f"empty filepath for {checkpoint_file}")
    _, extension = os.path.splitext(filepath)
    if extension.lower() == ".safetensors":
        model = safetensors.torch.load_file(filepath, device="cpu")
    else:
        model = torch.load(filepath, map_location="cpu")
    return get_state_dict_from_checkpoint(model)


def add_difference(theta0, theta1, alpha):
    # AUTOMATIC1111-style "add difference": theta1 is already a
    # difference (custom - base), scaled by alpha and added to theta0
    return theta0 + alpha * theta1


def generate(args):
    input_model = args.model
    inpainting_model_url = "https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt"
    inpainting_model_name = os.path.basename(inpainting_model_url)
    base_model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt"
    base_model_name = os.path.basename(base_model_url)
    # splitext keeps dots elsewhere in the path from breaking the name
    root, extension = os.path.splitext(input_model)
    output_model = f"{root}-inpainting{extension}"

    print(f"Loading {input_model}")
    input_state_dict = load_model(input_model)

    print(f"Loading {base_model_name}")
    base_state_dict = load_model(base_model_url)
    # Turn the custom model into a per-tensor difference against the base:
    # input = custom - base
    for key in input_state_dict.keys():
        if key in keys_to_skip_on_merge:
            continue
        if "model" in key:
            if key in base_state_dict:
                input_state_dict[key] = input_state_dict[key] - base_state_dict[key]
            else:
                input_state_dict[key] = torch.zeros_like(input_state_dict[key])
    del base_state_dict
print(f"Merging {inpainting_model_name} and the above difference")
inpainting_state_dict = load_model(inpainting_model_url)
for key in inpainting_state_dict.keys():
if input_state_dict and "model" in key and key in input_state_dict:
if key in keys_to_skip_on_merge:
continue
a = inpainting_state_dict[key]
b = input_state_dict[key]
if (
a.shape != b.shape
and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]
):
assert (
a.shape[1] == 9 and b.shape[1] == 4
), f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
inpainting_state_dict[key][:, 0:4, :, :] = add_difference(
a[:, 0:4, :, :], b, 1
)
result_is_inpainting_model = True
else:
inpainting_state_dict[key] = add_difference(a, b, 1)
del input_state_dict
print("Saving the model")
_, extension = os.path.splitext(output_model)
if extension.lower() == ".safetensors":
safetensors.torch.save_file(inpainting_state_dict, output_model)
else:
torch.save(inpainting_state_dict, output_model)


def main():
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("model", type=str)
    args = parser.parse_args()
    generate(args)


if __name__ == "__main__":
    main()