# NOTE(review): reconstructed from a whitespace-mangled `git diff`.  Besides this
# new file (tools/sd_inpainting_model_generator.py) the original patch also
# appended to requirements.txt: safetensors==0.3.1 and pytorch_lightning==2.0.2.
"""Generate an inpainting variant of a custom Stable Diffusion checkpoint.

Performs the standard "add difference" merge

    inpainting_v1.5 + (custom_model - base_v1.5)

so the result keeps the custom model's learned style while gaining the
9-channel inpainting UNet input layer.  Usage::

    python tools/sd_inpainting_model_generator.py path/to/model.ckpt
"""
import argparse
import os

import torch

# Checkpoint keys that must never be merged (non-learned buffer).
keys_to_skip_on_merge = [
    "cond_stage_model.transformer.text_model.embeddings.position_ids"
]

# Legacy CLIP key prefixes -> modern prefixes (old transformers checkpoints
# lack the intermediate "text_model." level).
key_replacements = {
    "cond_stage_model.transformer.embeddings.": "cond_stage_model.transformer.text_model.embeddings.",
    "cond_stage_model.transformer.encoder.": "cond_stage_model.transformer.text_model.encoder.",
    "cond_stage_model.transformer.final_layer_norm.": "cond_stage_model.transformer.text_model.final_layer_norm.",
}


def transform_checkpoint_key(key):
    """Return *key* with any legacy CLIP prefix rewritten to the modern form."""
    for text, replacement in key_replacements.items():
        if key.startswith(text):
            key = replacement + key[len(text):]
    return key


def get_state_dict_from_checkpoint(ckpt_dict):
    """Extract and key-normalize the state dict from a loaded checkpoint.

    Lightning-style checkpoints nest weights under ``"state_dict"``; plain
    state dicts are used as-is.  NOTE: pops ``"state_dict"`` in place on the
    caller's dict (original behavior, kept).
    """
    ckpt_dict = ckpt_dict.pop("state_dict", ckpt_dict)

    return {
        transform_checkpoint_key(k): v
        for k, v in ckpt_dict.items()
        if transform_checkpoint_key(k) is not None
    }


def add_difference(theta0, theta1, alpha):
    """Classic "add difference" merge: ``theta0 + alpha * theta1``.

    ``theta1`` is expected to already be a difference (custom - base).
    BUGFIX: the original patch called this helper without ever defining it,
    which made ``generate`` raise ``NameError`` at merge time.
    """
    return theta0 + alpha * theta1


def download_model(url):
    """Download *url* into /tmp (reusing a cached copy) and return the path."""
    import requests  # local import: only needed when a URL is given

    filepath = f"/tmp/{os.path.basename(url)}"
    if os.path.isfile(filepath):
        return filepath

    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size = int(response.headers.get("content-length", 0))
    block_size = 1048576  # 1 MB
    downloaded_size = 0

    with open(filepath, "wb") as file:
        for data in response.iter_content(block_size):
            downloaded_size += len(data)
            file.write(data)
            if total_size:  # BUGFIX: guard div-by-zero when header is absent
                progress = downloaded_size / total_size * 100
                print(f"Download progress: {progress:.2f}%")
    return filepath


def load_model(checkpoint_file):
    """Load a checkpoint (local path or http(s) URL); return its state dict."""
    # BUGFIX: the original tested `"http" in checkpoint_file`, which would
    # misclassify a local file whose *name* contains "http" as a URL.
    if checkpoint_file.startswith(("http://", "https://")):
        filepath = download_model(checkpoint_file)
    else:
        filepath = checkpoint_file
    if not filepath:
        raise ValueError(f"empty filepath for {checkpoint_file}")

    _, extension = os.path.splitext(filepath)
    if extension.lower() == ".safetensors":
        # BUGFIX: the torch backend is a submodule and must be imported
        # explicitly; a bare `import safetensors` does not bind `.torch`.
        import safetensors.torch

        model = safetensors.torch.load_file(filepath, device="cpu")
    else:
        model = torch.load(filepath, map_location="cpu")

    return get_state_dict_from_checkpoint(model)


def generate(args):
    """Merge ``inpainting + (args.model - v1.5 base)`` into *-inpainting.*."""
    input_model = args.model
    inpainting_model_url = "https://huggingface.co/runwayml/stable-diffusion-inpainting/resolve/main/sd-v1-5-inpainting.ckpt"
    inpainting_model_name = os.path.basename(inpainting_model_url)
    base_model_url = "https://huggingface.co/runwayml/stable-diffusion-v1-5/resolve/main/v1-5-pruned.ckpt"
    base_model_name = os.path.basename(base_model_url)
    # BUGFIX: os.path.splitext instead of str.split('.'): the original broke
    # on paths with extra dots (e.g. "./models/v1.5-custom.ckpt").
    root, ext = os.path.splitext(input_model)
    output_model = f"{root}-inpainting{ext}"

    print(f"Loading {input_model}")
    input_state_dict = load_model(input_model)

    print(f"Loading {base_model_name}")
    base_state_dict = load_model(base_model_url)

    # Turn the custom model into a difference against the v1.5 base.
    for key in input_state_dict.keys():
        if key in keys_to_skip_on_merge:
            continue

        if "model" in key:
            if key in base_state_dict:
                input_state_dict[key] = input_state_dict[key] - base_state_dict[key]
            else:
                input_state_dict[key] = torch.zeros_like(input_state_dict[key])

    del base_state_dict  # release several GB before loading the next checkpoint

    print(f"Merging {inpainting_model_name} and the above difference")
    inpainting_state_dict = load_model(inpainting_model_url)
    for key in inpainting_state_dict.keys():
        if input_state_dict and "model" in key and key in input_state_dict:
            if key in keys_to_skip_on_merge:
                continue

            a = inpainting_state_dict[key]
            b = input_state_dict[key]

            if (
                a.shape != b.shape
                and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]
            ):
                # UNet input conv: the inpainting model has 9 input channels
                # (4 latent + 4 masked latent + 1 mask); merge only the first 4.
                assert (
                    a.shape[1] == 9 and b.shape[1] == 4
                ), f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"
                inpainting_state_dict[key][:, 0:4, :, :] = add_difference(
                    a[:, 0:4, :, :], b, 1
                )
            else:
                inpainting_state_dict[key] = add_difference(a, b, 1)

    del input_state_dict

    print("Saving the model")
    _, extension = os.path.splitext(output_model)
    if extension.lower() == ".safetensors":
        import safetensors.torch  # see note in load_model

        safetensors.torch.save_file(inpainting_state_dict, output_model)
    else:
        torch.save(inpainting_state_dict, output_model)


def main():
    """Parse the single positional argument (model path/URL) and run the merge."""
    parser = argparse.ArgumentParser()
    parser.add_argument("model", type=str)

    args = parser.parse_args()

    generate(args)


if __name__ == "__main__":
    main()