Adds a tool to generate an inpainting model from a standard model. Method derived from AUTOMATIC1111
This commit is contained in:
parent
4a48ed43a7
commit
37a32ade71
|
|
@ -9,4 +9,6 @@ torch==2.0.0
|
|||
transformers==4.28.1
|
||||
sentencepiece==0.1.99
|
||||
Flask-Limiter==3.3.1
|
||||
protobuf==3.20
|
||||
protobuf==3.20
|
||||
safetensors==0.3.1
|
||||
pytorch_lightning==2.0.2
|
||||
|
|
|
|||
|
|
@ -0,0 +1,149 @@
|
|||
import os
|
||||
import requests
|
||||
import argparse
|
||||
import safetensors
|
||||
import torch
|
||||
import sys
|
||||
|
||||
# Checkpoint dictionary keys that must never be copied or merged between models.
keys_to_skip_on_merge = [
    "cond_stage_model.transformer.text_model.embeddings.position_ids"
]

# Legacy -> current key prefixes: older CLIP text-encoder checkpoints lack the
# "text_model." path segment that newer loaders expect.
key_replacements = {
    "cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.",
    "cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.",
    "cond_stage_model.transformer.final_layer_norm.": "cond_stage_model.transformer.text_model.final_layer_norm.",
}


def transform_checkpoint_key(key):
    """Return *key* with any known legacy prefix rewritten to its modern form.

    Keys that match no prefix in ``key_replacements`` are returned unchanged.
    """
    for text, replacement in key_replacements.items():
        if key.startswith(text):
            key = replacement + key[len(text) :]
    return key


def get_state_dict_from_checkpoint(ckpt_dict):
    """Extract the state dict from a loaded checkpoint and normalize its keys.

    Pops the nested ``"state_dict"`` entry when present (note: this mutates
    the caller's *ckpt_dict*); otherwise the whole mapping is treated as the
    state dict.
    """
    ckpt_dict = ckpt_dict.pop("state_dict", ckpt_dict)

    # Compute the transformed key once per entry (the original called
    # transform_checkpoint_key twice for every key).
    state_dict = {}
    for key, value in ckpt_dict.items():
        new_key = transform_checkpoint_key(key)
        if new_key is not None:
            state_dict[new_key] = value
    return state_dict
|
||||
|
||||
|
||||
def download_model(url):
    """Download *url* into /tmp and return the local path.

    A previously downloaded file with the same basename is reused as a cache
    hit without any network access. Progress is printed per 1 MB chunk.

    Raises requests.HTTPError if the server responds with an error status.
    """
    filepath = f"/tmp/{os.path.basename(url)}"
    if os.path.isfile(filepath):
        return filepath

    response = requests.get(url, stream=True)
    # Fail early rather than caching an HTML error page as a model file.
    response.raise_for_status()
    total_size = int(response.headers.get("content-length", 0))
    block_size = 1048576  # 1 MB
    downloaded_size = 0

    with open(filepath, "wb") as file:
        for data in response.iter_content(block_size):
            downloaded_size += len(data)
            file.write(data)
            # Servers may omit Content-Length; guard against dividing by zero.
            if total_size > 0:
                progress = downloaded_size / total_size * 100
                print(f"Download progress: {progress:.2f}%")
    return filepath
|
||||
|
||||
|
||||
def load_model(checkpoint_file):
    """Load a checkpoint from a local path or an http(s) URL.

    URLs are fetched via download_model; .safetensors files are loaded with
    safetensors, anything else with torch.load. Returns the normalized state
    dict produced by get_state_dict_from_checkpoint.

    Raises ValueError when *checkpoint_file* resolves to an empty path.
    """
    # Only treat genuine URLs as remote; the old `"http" in checkpoint_file`
    # check misclassified local paths that merely contain "http".
    if checkpoint_file.startswith(("http://", "https://")):
        filepath = download_model(checkpoint_file)
    else:
        filepath = checkpoint_file
    if not filepath:
        raise ValueError(f"empty filepath for {checkpoint_file}")

    _, extension = os.path.splitext(filepath)
    if extension.lower() == ".safetensors":
        model = safetensors.torch.load_file(filepath, device="cpu")
    else:
        # NOTE(review): torch.load unpickles arbitrary objects; consider
        # weights_only=True for checkpoints from untrusted sources.
        model = torch.load(filepath, map_location="cpu")

    return get_state_dict_from_checkpoint(model)
|
||||
|
||||
|
||||
def generate(args):
    """Create an inpainting variant of ``args.model``.

    Implements the AUTOMATIC1111 "add difference" recipe:

        result = inpainting + (input - base)

    The RunwayML SD-1.5 inpainting and base checkpoints are downloaded from
    Hugging Face; the merged model is written next to the input model with an
    ``-inpainting`` suffix (same extension, so .safetensors stays safetensors).
    """
    input_model = args.model
    inpainting_model_url = "https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt"
    inpainting_model_name = os.path.basename(inpainting_model_url)
    base_model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt"
    base_model_name = os.path.basename(base_model_url)
    # os.path.splitext is robust for dotted directories and multi-dot
    # filenames; the previous split('.') truncated e.g. "v1.5-model.ckpt"
    # and raised IndexError for extension-less paths.
    model_root, model_ext = os.path.splitext(input_model)
    output_model = f"{model_root}-inpainting{model_ext}"

    print(f"Loading {input_model}")
    input_state_dict = load_model(input_model)

    print(f"Loading {base_model_name}")
    base_state_dict = load_model(base_model_url)

    # Turn the input model into a difference against the base model, in place.
    for key in input_state_dict.keys():
        if key in keys_to_skip_on_merge:
            continue

        if "model" in key:
            if key in base_state_dict:
                base_value = base_state_dict.get(
                    key, torch.zeros_like(input_state_dict[key])
                )
                input_state_dict[key] = input_state_dict[key] - base_value
            else:
                input_state_dict[key] = torch.zeros_like(input_state_dict[key])

    # Free several GB of tensors before loading the next checkpoint.
    del base_state_dict

    print(f"Merging {inpainting_model_name} and the above difference")
    inpainting_state_dict = load_model(inpainting_model_url)
    # NOTE(review): add_difference is not defined in this chunk — presumably
    # a + b * multiplier, defined elsewhere in this module; confirm.
    for key in inpainting_state_dict.keys():
        if input_state_dict and "model" in key and key in input_state_dict:
            if key in keys_to_skip_on_merge:
                continue

            a = inpainting_state_dict[key]
            b = input_state_dict[key]

            if (
                a.shape != b.shape
                and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]
            ):
                # Shapes differ only in dim 1: the inpainting UNet input conv
                # has extra channels, so merge only the first 4 shared ones.
                assert (
                    a.shape[1] == 9 and b.shape[1] == 4
                ), f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
                inpainting_state_dict[key][:, 0:4, :, :] = add_difference(
                    a[:, 0:4, :, :], b, 1
                )
            else:
                inpainting_state_dict[key] = add_difference(a, b, 1)

    del input_state_dict

    print("Saving the model")
    _, extension = os.path.splitext(output_model)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(inpainting_state_dict, output_model)
    else:
        torch.save(inpainting_state_dict, output_model)
|
||||
|
||||
|
||||
def main():
    """Command-line entry point: read the model path and run the conversion."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("model", type=str)
    generate(arg_parser.parse_args())


if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue