stable-diffusion-for-fun/utilities/model.py

107 lines
3.6 KiB
Python

from typing import Optional

import torch
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline

from utilities.logger import DummyLogger
from utilities.memory import empty_memory
from utilities.memory import tune_for_low_memory
class Model:
    """Holds the Stable Diffusion pipelines used for image generation.

    Manages a text-to-image pipeline, an img2img pipeline built from the
    same components, and an optionally configured inpainting pipeline.
    Nothing is loaded until ``load()`` is called.
    """

    def __init__(
        self,
        model_name: str,
        inpainting_model_name: str,
        logger: Optional[DummyLogger] = None,
        use_gpu: bool = True,
    ):
        """Store configuration; weights are fetched later by ``load()``.

        Args:
            model_name: Name/path given to ``StableDiffusionPipeline.from_pretrained``.
                A falsy value skips the text-to-image and img2img pipelines.
            inpainting_model_name: Name/path given to
                ``StableDiffusionInpaintPipeline.from_pretrained``. A falsy value
                skips the inpainting pipeline.
            logger: Receives status and error messages. When omitted, a new
                ``DummyLogger`` is created for this instance.
            use_gpu: Request CUDA. It is only enabled when
                ``torch.cuda.is_available()`` is also true.
        """
        self.model_name = model_name
        self.inpainting_model_name = inpainting_model_name
        # Build the default per instance rather than sharing one logger that
        # would be created once at function-definition time.
        if logger is None:
            logger = DummyLogger()
        self.__use_gpu = False
        if use_gpu and torch.cuda.is_available():
            self.__use_gpu = True
            logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
        self.__logger = logger
        # Passed as torch_dtype to from_pretrained(); reduce_memory() switches
        # it to torch.float16.
        self.__torch_dtype = "auto"
        self.sd_pipeline = None
        self.img2img_pipeline = None
        self.inpaint_pipeline = None
        self.inpaint_pipeline_scheduler = None

    def use_gpu(self) -> bool:
        """Return True when CUDA was requested and is available."""
        return self.__use_gpu

    def reduce_memory(self):
        """Load future pipelines in float16 and apply the low-memory tuning.

        Only affects pipelines loaded by a later ``load()`` call, because the
        dtype is read there.
        """
        self.__logger.info("reduces memory usage by using float16 dtype")
        tune_for_low_memory()
        self.__torch_dtype = torch.float16

    def load(self):
        """Free memory, then (re)load every configured pipeline.

        A pipeline that cannot be loaded is logged by ``_load_pipeline`` and
        stored as ``None``. Pipelines derived from it (img2img, the inpaint
        scheduler) are then ``None`` as well, instead of raising.
        """
        empty_memory()
        if self.model_name:
            sd_pipeline = self._load_pipeline(
                StableDiffusionPipeline, self.model_name, "model"
            )
            self.sd_pipeline = sd_pipeline
            # img2img is built from the already-loaded components rather than
            # calling from_pretrained a second time.
            self.img2img_pipeline = (
                StableDiffusionImg2ImgPipeline(**sd_pipeline.components)
                if sd_pipeline is not None
                else None
            )
        if self.inpainting_model_name:
            inpaint_pipeline = self._load_pipeline(
                StableDiffusionInpaintPipeline,
                self.inpainting_model_name,
                "inpaint model",
            )
            self.inpaint_pipeline = inpaint_pipeline
            self.inpaint_pipeline_scheduler = (
                inpaint_pipeline.scheduler if inpaint_pipeline is not None else None
            )

    def _load_pipeline(self, pipeline_class, model_name, description):
        """Load ``model_name`` with ``pipeline_class``; return None on failure.

        First tries the revision chosen by ``get_revision_from_model_name``.
        If that raises, it retries without a revision. If the retry also
        fails, the error is logged as "failed to load <description> ..." and
        None is returned. A loaded pipeline is moved to CUDA when
        ``use_gpu()`` is true.
        """
        load_kwargs = {
            "torch_dtype": self.__torch_dtype,
            "safety_checker": None,
        }
        revision = get_revision_from_model_name(model_name)
        try:
            pipeline = pipeline_class.from_pretrained(
                model_name, revision=revision, **load_kwargs
            )
        except Exception as revision_error:
            # NOTE(review): presumably the repo lacks this revision branch;
            # fall back to the default revision rather than giving up.
            self.__logger.info(
                "could not load %s %s at revision %s (%s); retrying without revision"
                % (description, model_name, revision, revision_error)
            )
            try:
                pipeline = pipeline_class.from_pretrained(model_name, **load_kwargs)
            except Exception as e:
                self.__logger.error(
                    "failed to load %s %s: %s" % (description, model_name, e)
                )
                return None
        if self.use_gpu():
            pipeline.to("cuda")
        return pipeline
def get_revision_from_model_name(model_name: str) -> str:
    """Pick the repository revision to request for ``model_name``.

    The trinart v2 model is pinned to its "diffusers-115k" revision. Every
    other model name gets "fp16".
    """
    if model_name == "naclbit/trinart_stable_diffusion_v2":
        return "diffusers-115k"
    return "fp16"