[BE] supports checkpoint loading for inpainting models in backend

HappyZ 2023-06-01 17:18:54 -07:00
parent 37a32ade71
commit 5de9b59cc3
4 changed files with 125 additions and 41 deletions

.gitignore (vendored):

@@ -9,6 +9,7 @@ __pycache__/
 # sqlite3 db
 *.db
 data/
+model/

 # Distribution / packaging
 .Python

backend entry script:

@@ -1,5 +1,6 @@
 import argparse
 import torch
+import os

 from utilities.constants import LOGGER_NAME_BACKEND
 from utilities.constants import LOGGER_NAME_TXT2IMG
@@ -41,7 +42,7 @@ database = Database(logger)
 def load_model(
-    logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool
+    logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool, model_caching_folder_path: str
 ) -> Model:
     # model candidates:
     # "runwayml/stable-diffusion-v1-5"
@@ -56,7 +57,7 @@ def load_model(
     model_name = "SG161222/Realistic_Vision_V2.0"
     # inpainting model candidates:
     # "runwayml/stable-diffusion-inpainting"
-    inpainting_model_name = "runwayml/stable-diffusion-inpainting"
+    inpainting_model_name = "https://huggingface.co/SG161222/Realistic_Vision_V2.0/resolve/main/Realistic_Vision_V2.0-inpainting.ckpt"
     model = Model(
         model_name,
@@ -64,6 +65,7 @@ def load_model(
         logger,
         use_gpu=use_gpu,
         gpu_device_name=gpu_device_name,
+        model_caching_folder_path=model_caching_folder_path,
     )
     if use_gpu and reduce_memory_usage:
         model.set_low_memory_mode()
@@ -180,7 +182,10 @@ def main(args):
     database.set_image_output_folder(args.image_output_folder)
     database.connect(args.db)
-    model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage)
+    if not os.path.isdir(args.model_caching_folder):
+        os.makedirs(args.model_caching_folder, exist_ok=True)
+    model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage, args.model_caching_folder)
     backend(model, args.gfpgan, args.debug)
     database.safe_disconnect()
@@ -205,6 +210,11 @@ if __name__ == "__main__":
         "--gpu-device", type=str, default="cuda", help="GPU device name"
     )
+    # Add an argument to set the model caching folder
+    parser.add_argument(
+        "--model-caching-folder", type=str, default="/tmp", help="Where to download models for caching"
+    )
     # Add an argument to reduce memory usage
     parser.add_argument(
         "--reduce-memory-usage",

requirements.txt:

@@ -1,6 +1,6 @@
 accelerate==0.18.0
 colorlog==6.7.0
-diffusers==0.15.1
+diffusers==0.16.1
 numpy==1.24.3
 Flask==2.3.1
 Pillow==9.0.1
@@ -12,3 +12,4 @@ Flask-Limiter==3.3.1
 protobuf==3.20
 safetensors==0.3.1
 pytorch_lightning==2.0.2
+omegaconf==2.3.0
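
The diffusers bump to 0.16.1 and the new omegaconf pin back the checkpoint-conversion path used in the model module below: download_from_original_stable_diffusion_ckpt ships with diffusers and, as of 0.16, parses the original v1 YAML config via omegaconf. A quick import check, offered as a sketch rather than part of the commit:

    from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
        download_from_original_stable_diffusion_ckpt,
    )
    from omegaconf import OmegaConf  # required by the conversion helper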

model module:

@@ -1,8 +1,14 @@
+import os
+from io import BytesIO
+import requests
 import diffusers
 import torch
 from diffusers import StableDiffusionPipeline
 from diffusers import StableDiffusionImg2ImgPipeline
 from diffusers import StableDiffusionInpaintPipeline
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    download_from_original_stable_diffusion_ckpt,
+)

 from utilities.constants import VALUE_SCHEDULER_DEFAULT
 from utilities.constants import VALUE_SCHEDULER_DDIM
@@ -15,6 +21,26 @@ from utilities.memory import empty_memory_cache
 from utilities.memory import tune_for_low_memory


+def download_model(url, output_folder):
+    filepath = f"{output_folder}/{os.path.basename(url)}"
+    if os.path.isfile(filepath):
+        return filepath
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get("content-length", 0))
+    block_size = 1048576  # 1 MB
+    downloaded_size = 0
+    with open(filepath, "wb") as file:
+        for data in response.iter_content(block_size):
+            downloaded_size += len(data)
+            file.write(data)
+            # Calculate the progress
+            progress = downloaded_size / total_size * 100
+            print(f"Download progress: {progress:.2f}%")
+    return filepath
+
+
 class Model:
     """Model class."""
@@ -25,6 +51,7 @@ class Model:
         logger: DummyLogger = DummyLogger(),
         use_gpu: bool = True,
         gpu_device_name: str = "cuda",
+        model_caching_folder_path: str = "/tmp",
     ):
         self.model_name = model_name
         self.inpainting_model_name = inpainting_model_name
@@ -32,11 +59,14 @@ class Model:
         self.__gpu_device = gpu_device_name
         if use_gpu and torch.cuda.is_available():
             self.__use_gpu = True
-            logger.info("running on {}".format(torch.cuda.get_device_name(self.__gpu_device)))
+            logger.info(
+                "running on {}".format(torch.cuda.get_device_name(self.__gpu_device))
+            )
         else:
             logger.info("running on CPU (expect it to be verrry sloooow)")
         self.__logger = logger
         self.__torch_dtype = torch.float64
+        self.__model_caching_folder_path = model_caching_folder_path

         # txt2img and img2img are always loaded together
         self.txt2img_pipeline = None
@@ -45,11 +75,11 @@ class Model:
     def use_gpu(self):
         return self.__use_gpu

     def get_gpu_device_name(self):
         return self.__gpu_device

-    def update_model_name(self, model_name:str):
+    def update_model_name(self, model_name: str):
         if not model_name or model_name == self.model_name:
             self.__logger.warn("model name empty or the same, not updated")
             return
@@ -60,8 +90,8 @@ class Model:
         self.__logger.info("reduces memory usage by using float16 dtype")
         tune_for_low_memory()
         self.__torch_dtype = torch.float16

-    def __set_scheduler(self, scheduler:str, pipeline, default_scheduler):
+    def __set_scheduler(self, scheduler: str, pipeline, default_scheduler):
         if scheduler == VALUE_SCHEDULER_DEFAULT:
             pipeline.scheduler = default_scheduler
             return
@@ -70,32 +100,40 @@ class Model:
         empty_memory_cache()

-    def set_img2img_scheduler(self, scheduler:str):
+    def set_img2img_scheduler(self, scheduler: str):
         # note the change here also affects txt2img scheduler
         if self.img2img_pipeline is None:
             self.__logger.error("no img2img pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.img2img_pipeline, self.__default_img2img_scheduler)
+        self.__set_scheduler(
+            scheduler, self.img2img_pipeline, self.__default_img2img_scheduler
+        )

-    def set_txt2img_scheduler(self, scheduler:str):
+    def set_txt2img_scheduler(self, scheduler: str):
         # note the change here also affects img2img scheduler
         if self.txt2img_pipeline is None:
             self.__logger.error("no txt2img pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler)
+        self.__set_scheduler(
+            scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler
+        )

-    def set_inpaint_scheduler(self, scheduler:str):
+    def set_inpaint_scheduler(self, scheduler: str):
         if self.inpaint_pipeline is None:
             self.__logger.error("no inpaint pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler)
+        self.__set_scheduler(
+            scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler
+        )

-    def load_txt2img_and_img2img_pipeline(self, force_reload:bool=False):
+    def load_txt2img_and_img2img_pipeline(self, force_reload: bool = False):
         if (not force_reload) and (self.txt2img_pipeline is not None):
             self.__logger.warn("txt2img and img2img pipelines already loaded")
             return
         if not self.model_name:
-            self.__logger.error("unable to load txt2img and img2img pipelines, model not set")
+            self.__logger.error(
+                "unable to load txt2img and img2img pipelines, model not set"
+            )
             return
         revision = get_revision_from_model_name(self.model_name)
         pipeline = None
@@ -119,49 +157,83 @@ class Model:
             )
         if pipeline and self.use_gpu():
             pipeline.to(self.get_gpu_device_name())
         self.txt2img_pipeline = pipeline
         self.__default_txt2img_scheduler = pipeline.scheduler
-        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
-            **pipeline.components
-        )
+        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**pipeline.components)
         self.__default_img2img_scheduler = self.__default_txt2img_scheduler
         empty_memory_cache()

-    def load_inpaint_pipeline(self, force_reload:bool=False):
+    def load_inpaint_pipeline(self, force_reload: bool = False):
         if (not force_reload) and (self.inpaint_pipeline is not None):
             self.__logger.warn("inpaint pipeline already loaded")
             return
         if not self.inpainting_model_name:
             self.__logger.error("unable to load inpaint pipeline, model not set")
             return
-        revision = get_revision_from_model_name(self.inpainting_model_name)
         pipeline = None
-        try:
-            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
-                model_name,
-                revision=revision,
-                torch_dtype=self.__torch_dtype,
-                safety_checker=None,
-            )
-        except:
-            try:
-                pipeline = StableDiffusionInpaintPipeline.from_pretrained(
-                    self.inpainting_model_name,
-                    torch_dtype=self.__torch_dtype,
-                    safety_checker=None,
-                )
-            except Exception as e:
-                self.__logger.error(
-                    "failed to load inpaint model %s: %s"
-                    % (self.inpainting_model_name, e)
-                )
-        if pipeline and self.use_gpu():
-            pipeline.to(self.get_gpu_device_name())
-        self.inpaint_pipeline = pipeline
-        self.__default_inpaint_scheduler = pipeline.scheduler
+
+        _, extension = os.path.splitext(self.inpainting_model_name)
+        if extension.lower() == ".ckpt":
+            if not os.path.isfile(self.inpainting_model_name):
+                model_filepath = download_model(
+                    self.inpainting_model_name, self.__model_caching_folder_path
+                )
+            else:
+                model_filepath = self.inpainting_model_name
+            original_config_file = BytesIO(requests.get("https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml").content)
+            pipeline = download_from_original_stable_diffusion_ckpt(
+                model_filepath,
+                original_config_file=original_config_file,
+                load_safety_checker=False,
+                pipeline_class=StableDiffusionInpaintPipeline,
+                device="cpu" if not self.use_gpu() else self.get_gpu_device_name(),
+            )
+        elif extension.lower() == ".safetensors":
+            if not os.path.isfile(self.inpainting_model_name):
+                model_filepath = download_model(
+                    self.inpainting_model_name, self.__model_caching_folder_path
+                )
+            else:
+                model_filepath = self.inpainting_model_name
+            original_config_file = BytesIO(requests.get("https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml").content)
+            pipeline = download_from_original_stable_diffusion_ckpt(
+                model_filepath,
+                original_config_file=original_config_file,
+                from_safetensors=True,
+                load_safety_checker=False,
+                pipeline_class=StableDiffusionInpaintPipeline,
+                device="cpu" if not self.use_gpu() else self.get_gpu_device_name(),
+            )
+        else:
+            revision = get_revision_from_model_name(self.inpainting_model_name)
+            try:
+                pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+                    self.inpainting_model_name,
+                    revision=revision,
+                    torch_dtype=self.__torch_dtype,
+                    safety_checker=None,
+                )
+            except:
+                try:
+                    pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+                        self.inpainting_model_name,
+                        torch_dtype=self.__torch_dtype,
+                        safety_checker=None,
+                    )
+                except Exception as e:
+                    self.__logger.error(
+                        "failed to load inpaint model %s: %s"
+                        % (self.inpainting_model_name, e)
+                    )
+
+        if pipeline:
+            if self.use_gpu():
+                pipeline.to(self.get_gpu_device_name())
+            self.inpaint_pipeline = pipeline
+            self.__default_inpaint_scheduler = pipeline.scheduler
         empty_memory_cache()

     def load_all(self):
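
Taken together, a minimal end-to-end sketch of the new checkpoint path; the positional constructor order (model_name, then inpainting_model_name) is inferred from load_model above and should be treated as an assumption:

    model = Model(
        "SG161222/Realistic_Vision_V2.0",  # base model for txt2img/img2img
        "https://huggingface.co/SG161222/Realistic_Vision_V2.0/resolve/main/Realistic_Vision_V2.0-inpainting.ckpt",
        use_gpu=False,  # CPU path; conversion then runs with device="cpu"
        model_caching_folder_path="/tmp",  # the .ckpt is downloaded and cached here
    )
    model.load_inpaint_pipeline()  # .ckpt suffix routes through download_from_original_stable_diffusion_ckpt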