107 lines
3.6 KiB
Python
107 lines
3.6 KiB
Python
import torch
|
|
from diffusers import StableDiffusionPipeline
|
|
from diffusers import StableDiffusionImg2ImgPipeline
|
|
from diffusers import StableDiffusionInpaintPipeline
|
|
|
|
from utilities.logger import DummyLogger
|
|
from utilities.memory import empty_memory
|
|
from utilities.memory import tune_for_low_memory
|
|
|
|
|
|
class Model:
    """Owns the Stable Diffusion pipelines used by the application.

    Three pipelines are managed: text-to-image (``sd_pipeline``), image-to-image
    (``img2img_pipeline``, built from the text-to-image components) and
    inpainting (``inpaint_pipeline``). Nothing is loaded at construction time;
    call :meth:`load` to instantiate them. Call :meth:`reduce_memory` before
    :meth:`load` to have the weights loaded as ``torch.float16``.
    """

    def __init__(
        self,
        model_name: str,
        inpainting_model_name: str,
        logger: DummyLogger = DummyLogger(),
        use_gpu: bool = True,
    ) -> None:
        """Configure the wrapper without loading any weights.

        Args:
            model_name: Model id/path for the text-to-image pipeline. A falsy
                value makes :meth:`load` skip the text-to-image and img2img
                pipelines.
            inpainting_model_name: Model id/path for the inpainting pipeline.
                A falsy value makes :meth:`load` skip it.
            logger: Receives progress and error messages. NOTE: the default
                instance is created once, when the class is defined, and is
                shared by every Model constructed without an explicit logger.
            use_gpu: Request CUDA. Honoured only when
                ``torch.cuda.is_available()`` is true.
        """
        self.model_name = model_name
        self.inpainting_model_name = inpainting_model_name
        self.__logger = logger

        # Use the GPU only when it was requested *and* CUDA is actually present.
        self.__use_gpu = False
        if use_gpu and torch.cuda.is_available():
            self.__use_gpu = True
            logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))

        # Passed as ``torch_dtype`` to every from_pretrained call in load().
        # reduce_memory() replaces it with torch.float16.
        self.__torch_dtype = "auto"

        # Populated by load(). They stay None until then, or if loading fails.
        self.sd_pipeline = None
        self.img2img_pipeline = None
        self.inpaint_pipeline = None
        self.inpaint_pipeline_scheduler = None

    def use_gpu(self) -> bool:
        """Return True if loaded pipelines are moved to CUDA."""
        return self.__use_gpu

    def reduce_memory(self) -> None:
        """Apply low-memory tuning and switch the load dtype to float16.

        The dtype change only affects pipelines loaded by a later call to
        :meth:`load`.
        """
        self.__logger.info("reduces memory usage by using float16 dtype")
        tune_for_low_memory()
        self.__torch_dtype = torch.float16

    def _load_pipeline(self, pipeline_class, model_name, error_format):
        """Load ``pipeline_class`` from ``model_name``. Return None on failure.

        The revision chosen by ``get_revision_from_model_name`` is tried first.
        If that attempt raises, the load is retried without an explicit
        revision. If the retry also fails, the error is logged with
        ``error_format % (model_name, error)`` and None is returned instead of
        raising.
        """
        load_kwargs = {"torch_dtype": self.__torch_dtype, "safety_checker": None}
        revision = get_revision_from_model_name(model_name)
        try:
            return pipeline_class.from_pretrained(
                model_name, revision=revision, **load_kwargs
            )
        except Exception:
            # The preferred revision may not exist for this model. Deliberately
            # fall through and retry with the repository's default revision.
            pass
        try:
            return pipeline_class.from_pretrained(model_name, **load_kwargs)
        except Exception as e:
            self.__logger.error(error_format % (model_name, e))
            return None

    def load(self) -> None:
        """Load every configured pipeline.

        Memory is freed first via ``empty_memory()``. A pipeline whose model
        name is falsy is skipped and its attributes are left untouched. A load
        failure is logged and sets the corresponding attributes to None rather
        than raising.
        """
        empty_memory()

        if self.model_name:
            sd_pipeline = self._load_pipeline(
                StableDiffusionPipeline,
                self.model_name,
                "failed to load model %s: %s",
            )
            if sd_pipeline is not None and self.use_gpu():
                sd_pipeline.to("cuda")
            self.sd_pipeline = sd_pipeline
            # img2img is built from the text-to-image pipeline's components,
            # so no second from_pretrained call is made for it.
            self.img2img_pipeline = (
                StableDiffusionImg2ImgPipeline(**sd_pipeline.components)
                if sd_pipeline is not None
                else None
            )

        if self.inpainting_model_name:
            inpaint_pipeline = self._load_pipeline(
                StableDiffusionInpaintPipeline,
                self.inpainting_model_name,
                "failed to load inpaint model %s: %s",
            )
            if inpaint_pipeline is not None and self.use_gpu():
                inpaint_pipeline.to("cuda")
            self.inpaint_pipeline = inpaint_pipeline
            # Keep a reference to the scheduler the inpaint pipeline was loaded with.
            self.inpaint_pipeline_scheduler = (
                inpaint_pipeline.scheduler if inpaint_pipeline is not None else None
            )
|
|
|
|
|
|
def get_revision_from_model_name(model_name: str) -> str:
    """Return the repository revision to request first for ``model_name``.

    ``naclbit/trinart_stable_diffusion_v2`` maps to ``"diffusers-115k"``;
    every other name maps to ``"fp16"``.
    """
    if model_name == "naclbit/trinart_stable_diffusion_v2":
        return "diffusers-115k"
    return "fp16"
|