[BE] supports checkpoint loading for inpainting models in backend

HappyZ 2023-06-01 17:18:54 -07:00
parent 37a32ade71
commit 5de9b59cc3
4 changed files with 125 additions and 41 deletions

.gitignore

@@ -9,6 +9,7 @@ __pycache__/
 # sqlite3 db
 *.db
 data/
+model/
 # Distribution / packaging
 .Python


@@ -1,5 +1,6 @@
 import argparse
 import torch
+import os
 from utilities.constants import LOGGER_NAME_BACKEND
 from utilities.constants import LOGGER_NAME_TXT2IMG
@@ -41,7 +42,7 @@ database = Database(logger)
 def load_model(
-    logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool
+    logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool, model_caching_folder_path: str
 ) -> Model:
     # model candidates:
     #   "runwayml/stable-diffusion-v1-5"
@@ -56,7 +57,7 @@ def load_model(
     model_name = "SG161222/Realistic_Vision_V2.0"
     # inpainting model candidates:
     #   "runwayml/stable-diffusion-inpainting"
-    inpainting_model_name = "runwayml/stable-diffusion-inpainting"
+    inpainting_model_name = "https://huggingface.co/SG161222/Realistic_Vision_V2.0/resolve/main/Realistic_Vision_V2.0-inpainting.ckpt"
     model = Model(
         model_name,
@@ -64,6 +65,7 @@ def load_model(
         logger,
         use_gpu=use_gpu,
         gpu_device_name=gpu_device_name,
+        model_caching_folder_path=model_caching_folder_path,
     )
     if use_gpu and reduce_memory_usage:
         model.set_low_memory_mode()
@@ -180,7 +182,10 @@ def main(args):
     database.set_image_output_folder(args.image_output_folder)
     database.connect(args.db)
-    model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage)
+    if not os.path.isdir(args.model_caching_folder):
+        os.makedirs(args.model_caching_folder, exist_ok=True)
+    model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage, args.model_caching_folder)
     backend(model, args.gfpgan, args.debug)
     database.safe_disconnect()
@@ -205,6 +210,11 @@ if __name__ == "__main__":
         "--gpu-device", type=str, default="cuda", help="GPU device name"
     )
+    # Add an argument to set the model caching folder
+    parser.add_argument(
+        "--model-caching-folder", type=str, default="/tmp", help="Where to download models for caching"
+    )
     # Add an argument to reduce memory usage
     parser.add_argument(
         "--reduce-memory-usage",

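The backend wiring above is simple: the new --model-caching-folder flag (default /tmp) is created at startup and threaded through load_model into the Model constructor. Below is a minimal, self-contained sketch of that flow using the same flag; the /var/cache/models value is illustrative, not from the commit:

import argparse
import os

parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-caching-folder",
    type=str,
    default="/tmp",
    help="Where to download models for caching",
)
# Simulate passing the flag on the command line.
args = parser.parse_args(["--model-caching-folder", "/var/cache/models"])

# main() creates the folder before loading the model; since exist_ok=True
# already tolerates an existing folder, the isdir() guard in the diff is
# purely defensive.
os.makedirs(args.model_caching_folder, exist_ok=True)
print(args.model_caching_folder)  # /var/cache/models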
requirements.txt

@@ -1,6 +1,6 @@
 accelerate==0.18.0
 colorlog==6.7.0
-diffusers==0.15.1
+diffusers==0.16.1
 numpy==1.24.3
 Flask==2.3.1
 Pillow==9.0.1
@@ -12,3 +12,4 @@ Flask-Limiter==3.3.1
 protobuf==3.20
 safetensors==0.3.1
 pytorch_lightning==2.0.2
+omegaconf==2.3.0

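Both pin changes back the new checkpoint path: download_from_original_stable_diffusion_ckpt ships inside diffusers (hence the bump from 0.15.1 to 0.16.1, the version this commit targets), and it parses the original Stable Diffusion YAML config with OmegaConf, which is why omegaconf is newly pinned. These are the commit's pins, not verified minimum versions.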

@@ -1,8 +1,14 @@
+import os
+from io import BytesIO
+import requests
 import diffusers
 import torch
 from diffusers import StableDiffusionPipeline
 from diffusers import StableDiffusionImg2ImgPipeline
 from diffusers import StableDiffusionInpaintPipeline
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    download_from_original_stable_diffusion_ckpt,
+)
 from utilities.constants import VALUE_SCHEDULER_DEFAULT
 from utilities.constants import VALUE_SCHEDULER_DDIM
@@ -15,6 +21,26 @@ from utilities.memory import empty_memory_cache
 from utilities.memory import tune_for_low_memory
+
+
+def download_model(url, output_folder):
+    filepath = f"{output_folder}/{os.path.basename(url)}"
+    if os.path.isfile(filepath):
+        return filepath
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get("content-length", 0))
+    block_size = 1048576  # 1 MB
+    downloaded_size = 0
+    with open(filepath, "wb") as file:
+        for data in response.iter_content(block_size):
+            downloaded_size += len(data)
+            file.write(data)
+            # Calculate the progress
+            progress = downloaded_size / total_size * 100
+            print(f"Download progress: {progress:.2f}%")
+    return filepath
+
+
 class Model:
     """Model class."""
@@ -25,6 +51,7 @@ class Model:
         logger: DummyLogger = DummyLogger(),
         use_gpu: bool = True,
         gpu_device_name: str = "cuda",
+        model_caching_folder_path: str = "/tmp",
     ):
         self.model_name = model_name
         self.inpainting_model_name = inpainting_model_name
@@ -32,11 +59,14 @@ class Model:
         self.__gpu_device = gpu_device_name
         if use_gpu and torch.cuda.is_available():
             self.__use_gpu = True
-            logger.info("running on {}".format(torch.cuda.get_device_name(self.__gpu_device)))
+            logger.info(
+                "running on {}".format(torch.cuda.get_device_name(self.__gpu_device))
+            )
         else:
             logger.info("running on CPU (expect it to be verrry sloooow)")
         self.__logger = logger
         self.__torch_dtype = torch.float64
+        self.__model_caching_folder_path = model_caching_folder_path

         # txt2img and img2img are always loaded together
         self.txt2img_pipeline = None
@@ -49,7 +79,7 @@ class Model:
     def get_gpu_device_name(self):
         return self.__gpu_device

-    def update_model_name(self, model_name:str):
+    def update_model_name(self, model_name: str):
         if not model_name or model_name == self.model_name:
             self.__logger.warn("model name empty or the same, not updated")
             return
@@ -61,7 +91,7 @@ class Model:
         tune_for_low_memory()
         self.__torch_dtype = torch.float16

-    def __set_scheduler(self, scheduler:str, pipeline, default_scheduler):
+    def __set_scheduler(self, scheduler: str, pipeline, default_scheduler):
         if scheduler == VALUE_SCHEDULER_DEFAULT:
             pipeline.scheduler = default_scheduler
             return
@@ -70,32 +100,40 @@ class Model:
         empty_memory_cache()

-    def set_img2img_scheduler(self, scheduler:str):
+    def set_img2img_scheduler(self, scheduler: str):
         # note the change here also affects txt2img scheduler
         if self.img2img_pipeline is None:
             self.__logger.error("no img2img pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.img2img_pipeline, self.__default_img2img_scheduler)
+        self.__set_scheduler(
+            scheduler, self.img2img_pipeline, self.__default_img2img_scheduler
+        )

-    def set_txt2img_scheduler(self, scheduler:str):
+    def set_txt2img_scheduler(self, scheduler: str):
         # note the change here also affects img2img scheduler
         if self.txt2img_pipeline is None:
             self.__logger.error("no txt2img pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler)
+        self.__set_scheduler(
+            scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler
+        )

-    def set_inpaint_scheduler(self, scheduler:str):
+    def set_inpaint_scheduler(self, scheduler: str):
         if self.inpaint_pipeline is None:
             self.__logger.error("no inpaint pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler)
+        self.__set_scheduler(
+            scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler
+        )

-    def load_txt2img_and_img2img_pipeline(self, force_reload:bool=False):
+    def load_txt2img_and_img2img_pipeline(self, force_reload: bool = False):
         if (not force_reload) and (self.txt2img_pipeline is not None):
             self.__logger.warn("txt2img and img2img pipelines already loaded")
             return
         if not self.model_name:
-            self.__logger.error("unable to load txt2img and img2img pipelines, model not set")
+            self.__logger.error(
+                "unable to load txt2img and img2img pipelines, model not set"
+            )
             return
         revision = get_revision_from_model_name(self.model_name)
         pipeline = None
@@ -123,25 +161,58 @@ class Model:
         self.txt2img_pipeline = pipeline
         self.__default_txt2img_scheduler = pipeline.scheduler
-        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
-            **pipeline.components
-        )
+        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**pipeline.components)
         self.__default_img2img_scheduler = self.__default_txt2img_scheduler
         empty_memory_cache()

-    def load_inpaint_pipeline(self, force_reload:bool=False):
+    def load_inpaint_pipeline(self, force_reload: bool = False):
         if (not force_reload) and (self.inpaint_pipeline is not None):
             self.__logger.warn("inpaint pipeline already loaded")
             return
         if not self.inpainting_model_name:
             self.__logger.error("unable to load inpaint pipeline, model not set")
             return
-        revision = get_revision_from_model_name(self.inpainting_model_name)
         pipeline = None
+        _, extension = os.path.splitext(self.inpainting_model_name)
+        if extension.lower() == ".ckpt":
+            if not os.path.isfile(self.inpainting_model_name):
+                model_filepath = download_model(
+                    self.inpainting_model_name, self.__model_caching_folder_path
+                )
+            else:
+                model_filepath = self.inpainting_model_name
+            original_config_file = BytesIO(requests.get("https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml").content)
+            pipeline = download_from_original_stable_diffusion_ckpt(
+                model_filepath,
+                original_config_file=original_config_file,
+                load_safety_checker=False,
+                pipeline_class=StableDiffusionInpaintPipeline,
+                device="cpu" if not self.use_gpu() else self.get_gpu_device_name(),
+            )
+        elif extension.lower() == ".safetensors":
+            if not os.path.isfile(self.inpainting_model_name):
+                model_filepath = download_model(
+                    self.inpainting_model_name, self.__model_caching_folder_path
+                )
+            else:
+                model_filepath = self.inpainting_model_name
+            original_config_file = BytesIO(requests.get("https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml").content)
+            pipeline = download_from_original_stable_diffusion_ckpt(
+                model_filepath,
+                original_config_file=original_config_file,
+                from_safetensors=True,
+                load_safety_checker=False,
+                pipeline_class=StableDiffusionInpaintPipeline,
+                device="cpu" if not self.use_gpu() else self.get_gpu_device_name(),
+            )
+        else:
+            revision = get_revision_from_model_name(self.inpainting_model_name)
         try:
             pipeline = StableDiffusionInpaintPipeline.from_pretrained(
-                model_name,
+                self.inpainting_model_name,
                 revision=revision,
                 torch_dtype=self.__torch_dtype,
                 safety_checker=None,
@@ -158,7 +229,8 @@ class Model:
                 "failed to load inpaint model %s: %s"
                 % (self.inpainting_model_name, e)
             )
-        if pipeline and self.use_gpu():
+        if pipeline:
+            if self.use_gpu():
                 pipeline.to(self.get_gpu_device_name())
             self.inpaint_pipeline = pipeline
             self.__default_inpaint_scheduler = pipeline.scheduler