[BE] supports checkpoint loading for inpainting models in backend
This commit is contained in:
parent
37a32ade71
commit
5de9b59cc3
.gitignore
@@ -9,6 +9,7 @@ __pycache__/
 # sqlite3 db
 *.db
 data/
+model/
 
 # Distribution / packaging
 .Python
backend.py
@@ -1,5 +1,6 @@
 import argparse
 import torch
+import os
 
 from utilities.constants import LOGGER_NAME_BACKEND
 from utilities.constants import LOGGER_NAME_TXT2IMG
@@ -41,7 +42,7 @@ database = Database(logger)
 
 
 def load_model(
-    logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool
+    logger: Logger, use_gpu: bool, gpu_device_name: str, reduce_memory_usage: bool, model_caching_folder_path: str
 ) -> Model:
     # model candidates:
     # "runwayml/stable-diffusion-v1-5"
@@ -56,7 +57,7 @@ def load_model(
     model_name = "SG161222/Realistic_Vision_V2.0"
     # inpainting model candidates:
     # "runwayml/stable-diffusion-inpainting"
-    inpainting_model_name = "runwayml/stable-diffusion-inpainting"
+    inpainting_model_name = "https://huggingface.co/SG161222/Realistic_Vision_V2.0/resolve/main/Realistic_Vision_V2.0-inpainting.ckpt"
 
     model = Model(
         model_name,
@@ -64,6 +65,7 @@ def load_model(
         logger,
         use_gpu=use_gpu,
         gpu_device_name=gpu_device_name,
+        model_caching_folder_path=model_caching_folder_path,
     )
     if use_gpu and reduce_memory_usage:
         model.set_low_memory_mode()
@@ -180,7 +182,10 @@ def main(args):
     database.set_image_output_folder(args.image_output_folder)
     database.connect(args.db)
 
-    model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage)
+    if not os.path.isdir(args.model_caching_folder):
+        os.makedirs(args.model_caching_folder, exist_ok=True)
+
+    model = load_model(logger, args.gpu, args.gpu_device, args.reduce_memory_usage, args.model_caching_folder)
     backend(model, args.gfpgan, args.debug)
 
     database.safe_disconnect()
@@ -205,6 +210,11 @@ if __name__ == "__main__":
         "--gpu-device", type=str, default="cuda", help="GPU device name"
     )
 
+    # Add an argument to set the model caching folder
+    parser.add_argument(
+        "--model-caching-folder", type=str, default="/tmp", help="Where to download models for caching"
+    )
+
     # Add an argument to reduce memory usage
     parser.add_argument(
         "--reduce-memory-usage",
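Since the cache location is now wired through the CLI, here is a minimal stand-alone sketch of the new flag (the folder value is illustrative, and `parse_args` is fed an explicit list instead of `sys.argv`):

```python
import argparse

# Mirror of the new flag added above; argparse exposes it as
# args.model_caching_folder (dashes become underscores).
parser = argparse.ArgumentParser()
parser.add_argument(
    "--model-caching-folder",
    type=str,
    default="/tmp",
    help="Where to download models for caching",
)
args = parser.parse_args(["--model-caching-folder", "/var/cache/models"])
print(args.model_caching_folder)  # -> /var/cache/models (defaults to /tmp)
```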
requirements.txt
@@ -1,6 +1,6 @@
 accelerate==0.18.0
 colorlog==6.7.0
-diffusers==0.15.1
+diffusers==0.16.1
 numpy==1.24.3
 Flask==2.3.1
 Pillow==9.0.1
@@ -12,3 +12,4 @@ Flask-Limiter==3.3.1
 protobuf==3.20
 safetensors==0.3.1
+pytorch_lightning==2.0.2
 omegaconf==2.3.0
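The diffusers bump matters for this commit: the single-file checkpoint converter imported below lives at `diffusers.pipelines.stable_diffusion.convert_from_ckpt`, and the conversion path needs omegaconf for the YAML config, which is presumably also why `pytorch_lightning` joins the pins. A quick sanity check, assuming the pinned versions are installed:

```python
import diffusers
# Same import the model module uses below; fails on diffusers releases
# that predate the convert_from_ckpt module.
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

print(diffusers.__version__)  # expect 0.16.1 per the pin above
```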
model.py
@@ -1,8 +1,14 @@
+import os
+from io import BytesIO
+import requests
 import diffusers
 import torch
 from diffusers import StableDiffusionPipeline
 from diffusers import StableDiffusionImg2ImgPipeline
 from diffusers import StableDiffusionInpaintPipeline
+from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
+    download_from_original_stable_diffusion_ckpt,
+)
 
 from utilities.constants import VALUE_SCHEDULER_DEFAULT
 from utilities.constants import VALUE_SCHEDULER_DDIM
@@ -15,6 +21,26 @@ from utilities.memory import empty_memory_cache
 from utilities.memory import tune_for_low_memory
 
 
+def download_model(url, output_folder):
+    filepath = f"{output_folder}/{os.path.basename(url)}"
+    if os.path.isfile(filepath):
+        return filepath
+
+    response = requests.get(url, stream=True)
+    total_size = int(response.headers.get("content-length", 0))
+    block_size = 1048576  # 1 MB
+    downloaded_size = 0
+
+    with open(filepath, "wb") as file:
+        for data in response.iter_content(block_size):
+            downloaded_size += len(data)
+            file.write(data)
+            # Calculate the progress
+            progress = downloaded_size / total_size * 100
+            print(f"Download progress: {progress:.2f}%")
+    return filepath
+
+
 class Model:
     """Model class."""
 
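A usage sketch for the new helper, with the URL taken from the backend configuration above. One caveat: the progress line divides by `total_size`, which is 0 when the server omits the Content-Length header, so the print would raise ZeroDivisionError in that case.

```python
# Second call returns immediately: the helper checks os.path.isfile()
# and reuses the cached download.
path = download_model(
    "https://huggingface.co/SG161222/Realistic_Vision_V2.0/resolve/main/Realistic_Vision_V2.0-inpainting.ckpt",
    "/tmp",
)
print(path)  # -> /tmp/Realistic_Vision_V2.0-inpainting.ckpt
```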
@@ -25,6 +51,7 @@ class Model:
         logger: DummyLogger = DummyLogger(),
         use_gpu: bool = True,
         gpu_device_name: str = "cuda",
+        model_caching_folder_path: str = "/tmp",
     ):
         self.model_name = model_name
         self.inpainting_model_name = inpainting_model_name
@@ -32,11 +59,14 @@ class Model:
         self.__gpu_device = gpu_device_name
         if use_gpu and torch.cuda.is_available():
             self.__use_gpu = True
-            logger.info("running on {}".format(torch.cuda.get_device_name(self.__gpu_device)))
+            logger.info(
+                "running on {}".format(torch.cuda.get_device_name(self.__gpu_device))
+            )
         else:
             logger.info("running on CPU (expect it to be verrry sloooow)")
         self.__logger = logger
         self.__torch_dtype = torch.float64
+        self.__model_caching_folder_path = model_caching_folder_path
 
         # txt2img and img2img are always loaded together
         self.txt2img_pipeline = None
@@ -49,7 +79,7 @@ class Model:
     def get_gpu_device_name(self):
         return self.__gpu_device
 
-    def update_model_name(self, model_name:str):
+    def update_model_name(self, model_name: str):
         if not model_name or model_name == self.model_name:
             self.__logger.warn("model name empty or the same, not updated")
             return
@@ -61,7 +91,7 @@ class Model:
             tune_for_low_memory()
         self.__torch_dtype = torch.float16
 
-    def __set_scheduler(self, scheduler:str, pipeline, default_scheduler):
+    def __set_scheduler(self, scheduler: str, pipeline, default_scheduler):
         if scheduler == VALUE_SCHEDULER_DEFAULT:
             pipeline.scheduler = default_scheduler
             return
@@ -70,32 +100,40 @@ class Model:
 
         empty_memory_cache()
 
-    def set_img2img_scheduler(self, scheduler:str):
+    def set_img2img_scheduler(self, scheduler: str):
         # note the change here also affects txt2img scheduler
         if self.img2img_pipeline is None:
             self.__logger.error("no img2img pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.img2img_pipeline, self.__default_img2img_scheduler)
+        self.__set_scheduler(
+            scheduler, self.img2img_pipeline, self.__default_img2img_scheduler
+        )
 
-    def set_txt2img_scheduler(self, scheduler:str):
+    def set_txt2img_scheduler(self, scheduler: str):
         # note the change here also affects img2img scheduler
         if self.txt2img_pipeline is None:
             self.__logger.error("no txt2img pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler)
+        self.__set_scheduler(
+            scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler
+        )
 
-    def set_inpaint_scheduler(self, scheduler:str):
+    def set_inpaint_scheduler(self, scheduler: str):
         if self.inpaint_pipeline is None:
             self.__logger.error("no inpaint pipeline loaded, unable to set scheduler")
             return
-        self.__set_scheduler(scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler)
+        self.__set_scheduler(
+            scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler
+        )
 
-    def load_txt2img_and_img2img_pipeline(self, force_reload:bool=False):
+    def load_txt2img_and_img2img_pipeline(self, force_reload: bool = False):
         if (not force_reload) and (self.txt2img_pipeline is not None):
             self.__logger.warn("txt2img and img2img pipelines already loaded")
             return
         if not self.model_name:
-            self.__logger.error("unable to load txt2img and img2img pipelines, model not set")
+            self.__logger.error(
+                "unable to load txt2img and img2img pipelines, model not set"
+            )
             return
         revision = get_revision_from_model_name(self.model_name)
         pipeline = None
@@ -123,25 +161,58 @@ class Model:
         self.txt2img_pipeline = pipeline
         self.__default_txt2img_scheduler = pipeline.scheduler
 
-        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
-            **pipeline.components
-        )
+        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**pipeline.components)
         self.__default_img2img_scheduler = self.__default_txt2img_scheduler
 
         empty_memory_cache()
 
-    def load_inpaint_pipeline(self, force_reload:bool=False):
+    def load_inpaint_pipeline(self, force_reload: bool = False):
         if (not force_reload) and (self.inpaint_pipeline is not None):
             self.__logger.warn("inpaint pipeline already loaded")
             return
         if not self.inpainting_model_name:
             self.__logger.error("unable to load inpaint pipeline, model not set")
             return
-        revision = get_revision_from_model_name(self.inpainting_model_name)
+
         pipeline = None
+
+        _, extension = os.path.splitext(self.inpainting_model_name)
+        if extension.lower() == ".ckpt":
+            if not os.path.isfile(self.inpainting_model_name):
+                model_filepath = download_model(
+                    self.inpainting_model_name, self.__model_caching_folder_path
+                )
+            else:
+                model_filepath = self.inpainting_model_name
+            original_config_file = BytesIO(requests.get("https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml").content)
+            pipeline = download_from_original_stable_diffusion_ckpt(
+                model_filepath,
+                original_config_file=original_config_file,
+                load_safety_checker=False,
+                pipeline_class=StableDiffusionInpaintPipeline,
+                device="cpu" if not self.use_gpu() else self.get_gpu_device_name(),
+            )
+        elif extension.lower() == ".safetensors":
+            if not os.path.isfile(self.inpainting_model_name):
+                model_filepath = download_model(
+                    self.inpainting_model_name, self.__model_caching_folder_path
+                )
+            else:
+                model_filepath = self.inpainting_model_name
+            original_config_file = BytesIO(requests.get("https://raw.githubusercontent.com/runwayml/stable-diffusion/main/configs/stable-diffusion/v1-inpainting-inference.yaml").content)
+            pipeline = download_from_original_stable_diffusion_ckpt(
+                model_filepath,
+                original_config_file=original_config_file,
+                from_safetensors=True,
+                load_safety_checker=False,
+                pipeline_class=StableDiffusionInpaintPipeline,
+                device="cpu" if not self.use_gpu() else self.get_gpu_device_name(),
+            )
+        else:
+            revision = get_revision_from_model_name(self.inpainting_model_name)
             try:
                 pipeline = StableDiffusionInpaintPipeline.from_pretrained(
-                    model_name,
+                    self.inpainting_model_name,
                     revision=revision,
                     torch_dtype=self.__torch_dtype,
                     safety_checker=None,
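Dispatch is by file extension: `.ckpt` and `.safetensors` sources go through the diffusers single-file converter (downloading to the cache folder first when the value is not a local file), while anything else is treated as a Hugging Face repo id for `from_pretrained`. An end-to-end sketch, assuming the constructor takes the two model names as its first positional parameters, as the `__init__` assignments above suggest:

```python
# Hypothetical wiring; both names are candidates listed in backend.py.
model = Model(
    "SG161222/Realistic_Vision_V2.0",
    "https://huggingface.co/SG161222/Realistic_Vision_V2.0/resolve/main/Realistic_Vision_V2.0-inpainting.ckpt",
    model_caching_folder_path="/tmp",
)
# First run downloads the .ckpt into /tmp and converts it into an
# inpainting pipeline; later runs reuse the cached file.
model.load_inpaint_pipeline()
```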
@@ -158,7 +229,8 @@ class Model:
                 "failed to load inpaint model %s: %s"
                 % (self.inpainting_model_name, e)
             )
-        if pipeline and self.use_gpu():
-            pipeline.to(self.get_gpu_device_name())
+        if pipeline:
+            if self.use_gpu():
+                pipeline.to(self.get_gpu_device_name())
             self.inpaint_pipeline = pipeline
             self.__default_inpaint_scheduler = pipeline.scheduler
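This last hunk also fixes a CPU-only bug: `self.inpaint_pipeline` used to be assigned inside `if pipeline and self.use_gpu():`, so a successfully built pipeline was dropped whenever no GPU was in use. Now the assignment always happens and only the `.to(device)` move is gated on the GPU. A hypothetical smoke test of the corrected behavior:

```python
# Hypothetical check; the hub id is the candidate kept in the comments above.
model = Model(
    "SG161222/Realistic_Vision_V2.0",
    "runwayml/stable-diffusion-inpainting",
    use_gpu=False,
)
model.load_inpaint_pipeline()
assert model.inpaint_pipeline is not None  # stayed None on CPU before this fix
```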