adds a tool to generate inpainting model using a standard model. Method derived from AUTOMATIC1111
This commit is contained in:
parent
4a48ed43a7
commit
37a32ade71
|
|
@ -9,4 +9,6 @@ torch==2.0.0
|
||||||
transformers==4.28.1
|
transformers==4.28.1
|
||||||
sentencepiece==0.1.99
|
sentencepiece==0.1.99
|
||||||
Flask-Limiter==3.3.1
|
Flask-Limiter==3.3.1
|
||||||
protobuf==3.20
|
protobuf==3.20
|
||||||
|
safetensors==0.3.1
|
||||||
|
pytorch_lightning==2.0.2
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,149 @@
|
||||||
|
import os
|
||||||
|
import requests
|
||||||
|
import argparse
|
||||||
|
import safetensors
|
||||||
|
import torch
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# List of checkpoint dictionary keys to skip during merging
|
||||||
|
keys_to_skip_on_merge = [
|
||||||
|
"cond_stage_model.transformer.text_model.embeddings.position_ids"
|
||||||
|
]
|
||||||
|
|
||||||
|
# Dictionary for replacing keys in checkpoint dictionary
|
||||||
|
key_replacements = {
|
||||||
|
"cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.",
|
||||||
|
"cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.",
|
||||||
|
"cond_stage_model.transformer.final_layer_norm.": "cond_stage_model.transformer.text_model.final_layer_norm.",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def transform_checkpoint_key(key):
|
||||||
|
for text, replacement in key_replacements.items():
|
||||||
|
if key.startswith(text):
|
||||||
|
key = replacement + key[len(text) :]
|
||||||
|
return key
|
||||||
|
|
||||||
|
|
||||||
|
def get_state_dict_from_checkpoint(ckpt_dict):
|
||||||
|
ckpt_dict = ckpt_dict.pop("state_dict", ckpt_dict)
|
||||||
|
|
||||||
|
return {
|
||||||
|
transform_checkpoint_key(k): v
|
||||||
|
for k, v in ckpt_dict.items()
|
||||||
|
if transform_checkpoint_key(k) is not None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def download_model(url):
|
||||||
|
filepath = f"/tmp/{os.path.basename(url)}"
|
||||||
|
if os.path.isfile(filepath):
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
response = requests.get(url, stream=True)
|
||||||
|
total_size = int(response.headers.get("content-length", 0))
|
||||||
|
block_size = 1048576 # 1 MB
|
||||||
|
downloaded_size = 0
|
||||||
|
|
||||||
|
with open(filepath, "wb") as file:
|
||||||
|
for data in response.iter_content(block_size):
|
||||||
|
downloaded_size += len(data)
|
||||||
|
file.write(data)
|
||||||
|
# Calculate the progress
|
||||||
|
progress = downloaded_size / total_size * 100
|
||||||
|
print(f"Download progress: {progress:.2f}%")
|
||||||
|
return filepath
|
||||||
|
|
||||||
|
|
||||||
|
def load_model(checkpoint_file):
|
||||||
|
if "http" in checkpoint_file:
|
||||||
|
filepath = download_model(checkpoint_file)
|
||||||
|
else:
|
||||||
|
filepath = checkpoint_file
|
||||||
|
if not filepath:
|
||||||
|
raise ValueError(f"empty filepath for {checkpoint_file}")
|
||||||
|
|
||||||
|
_, extension = os.path.splitext(filepath)
|
||||||
|
if extension.lower() == ".safetensors":
|
||||||
|
model = safetensors.torch.load_file(filepath, device="cpu")
|
||||||
|
else:
|
||||||
|
model = torch.load(filepath, map_location="cpu")
|
||||||
|
|
||||||
|
return get_state_dict_from_checkpoint(model)
|
||||||
|
|
||||||
|
|
||||||
|
def generate(args):
|
||||||
|
input_model = args.model
|
||||||
|
inpainting_model_url = "https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt"
|
||||||
|
inpainting_model_name = os.path.basename(inpainting_model_url)
|
||||||
|
base_model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt"
|
||||||
|
base_model_name = os.path.basename(base_model_url)
|
||||||
|
output_model = f"{input_model.split('.')[0]}-inpainting.{input_model.split('.')[1]}"
|
||||||
|
|
||||||
|
print(f"Loading {input_model}")
|
||||||
|
input_state_dict = load_model(input_model)
|
||||||
|
|
||||||
|
print(f"Loading {base_model_name}")
|
||||||
|
base_state_dict = load_model(base_model_url)
|
||||||
|
|
||||||
|
for key in input_state_dict.keys():
|
||||||
|
if key in keys_to_skip_on_merge:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if "model" in key:
|
||||||
|
if key in base_state_dict:
|
||||||
|
base_value = base_state_dict.get(
|
||||||
|
key, torch.zeros_like(input_state_dict[key])
|
||||||
|
)
|
||||||
|
input_state_dict[key] = input_state_dict[key] - base_value
|
||||||
|
else:
|
||||||
|
input_state_dict[key] = torch.zeros_like(input_state_dict[key])
|
||||||
|
|
||||||
|
del base_state_dict
|
||||||
|
|
||||||
|
print(f"Merging {inpainting_model_name} and the above difference")
|
||||||
|
inpainting_state_dict = load_model(inpainting_model_url)
|
||||||
|
for key in inpainting_state_dict.keys():
|
||||||
|
if input_state_dict and "model" in key and key in input_state_dict:
|
||||||
|
if key in keys_to_skip_on_merge:
|
||||||
|
continue
|
||||||
|
|
||||||
|
a = inpainting_state_dict[key]
|
||||||
|
b = input_state_dict[key]
|
||||||
|
|
||||||
|
if (
|
||||||
|
a.shape != b.shape
|
||||||
|
and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
a.shape[1] == 9 and b.shape[1] == 4
|
||||||
|
), f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
|
||||||
|
inpainting_state_dict[key][:, 0:4, :, :] = add_difference(
|
||||||
|
a[:, 0:4, :, :], b, 1
|
||||||
|
)
|
||||||
|
result_is_inpainting_model = True
|
||||||
|
else:
|
||||||
|
inpainting_state_dict[key] = add_difference(a, b, 1)
|
||||||
|
|
||||||
|
del input_state_dict
|
||||||
|
|
||||||
|
print("Saving the model")
|
||||||
|
_, extension = os.path.splitext(output_model)
|
||||||
|
if extension.lower() == ".safetensors":
|
||||||
|
safetensors.torch.save_file(inpainting_state_dict, output_model)
|
||||||
|
else:
|
||||||
|
torch.save(inpainting_state_dict, output_model)
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
# Parse command-line arguments
|
||||||
|
parser = argparse.ArgumentParser()
|
||||||
|
parser.add_argument("model", type=str)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
generate(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
Loading…
Reference in New Issue