Adds scheduler config and a bare-bones text2img implementation

This commit is contained in:
HappyZ 2023-04-27 00:39:02 -07:00
parent d4e00c3377
commit 41664e2682
7 changed files with 136 additions and 72 deletions

View File

@ -21,8 +21,8 @@ def prepare(logger: Logger) -> [Model, Config]:
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
model = Model(model_name, inpainting_model_name, logger)
model.reduce_memory()
model.load()
model.set_low_memory_mode()
model.load_all()
config = Config()
return model, config

View File

@ -34,6 +34,7 @@ py_library(
name="model",
srcs=["model.py"],
deps=[
":constants",
":memory",
":logger",
],

View File

@ -9,6 +9,11 @@ from utilities.constants import KEY_PREVIEW
from utilities.constants import VALUE_PREVIEW_DEFAULT
from utilities.constants import KEY_SCHEDULER
from utilities.constants import VALUE_SCHEDULER_DEFAULT
from utilities.constants import VALUE_SCHEDULER_DDIM
from utilities.constants import VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP
from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE
from utilities.constants import VALUE_SCHEDULER_LMS_DISCRETE
from utilities.constants import VALUE_SCHEDULER_PNDM
from utilities.constants import KEY_SEED
from utilities.constants import VALUE_SEED_DEFAULT
from utilities.constants import KEY_STEPS
@ -55,15 +60,8 @@ class Config:
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
def set_scheduler(self, scheduler: str):
    """Record the scheduler choice in the config.

    Empty or unrecognized names fall back to the default. Validation uses
    the scheduler constants imported at the top of this module; without it,
    a typo would be stored and later crash getattr(diffusers, name) in the
    model's scheduler setter.
    """
    valid_schedulers = (
        VALUE_SCHEDULER_DEFAULT,
        VALUE_SCHEDULER_DDIM,
        VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP,
        VALUE_SCHEDULER_EULER_DISCRETE,
        VALUE_SCHEDULER_LMS_DISCRETE,
        VALUE_SCHEDULER_PNDM,
    )
    if not scheduler:
        scheduler = VALUE_SCHEDULER_DEFAULT
    elif scheduler not in valid_schedulers:
        self.__logger.warn(
            "unknown scheduler {}, falling back to {}".format(
                scheduler, VALUE_SCHEDULER_DEFAULT
            )
        )
        scheduler = VALUE_SCHEDULER_DEFAULT
    self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
    self.__config[KEY_SCHEDULER] = scheduler

View File

@ -17,6 +17,11 @@ VALUE_STEPS_DEFAULT = 100
KEY_SCHEDULER = "SCHEDULER"
VALUE_SCHEDULER_DEFAULT = "Default"
VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler"
VALUE_SCHEDULER_LMS_DISCRETE = "LMSDiscreteScheduler"
VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
VALUE_SCHEDULER_PNDM = "PNDMScheduler"
VALUE_SCHEDULER_DDIM = "DDIMScheduler"
KEY_PREVIEW = "PREVIEW"
VALUE_PREVIEW_DEFAULT = True

View File

@ -2,7 +2,7 @@ import gc
import torch
def empty_memory():
def empty_memory_cache():
"""
Performs garbage collection and empty cache in cuda device.
"""

View File

@ -1,10 +1,17 @@
import diffusers
import torch
from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline
from utilities.constants import VALUE_SCHEDULER_DEFAULT
from utilities.constants import VALUE_SCHEDULER_DDIM
from utilities.constants import VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP
from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE
from utilities.constants import VALUE_SCHEDULER_LMS_DISCRETE
from utilities.constants import VALUE_SCHEDULER_PNDM
from utilities.logger import DummyLogger
from utilities.memory import empty_memory
from utilities.memory import empty_memory_cache
from utilities.memory import tune_for_low_memory
@ -27,75 +34,133 @@ class Model:
self.__logger = logger
self.__torch_dtype = "auto"
self.sd_pipeline = None
# txt2img and img2img are always loaded together
self.txt2img_pipeline = None
self.img2img_pipeline = None
self.inpaint_pipeline = None
def use_gpu(self):
    # Whether pipelines should be moved to CUDA; the backing flag is set
    # in __init__ (not visible in this view — confirm how it is decided).
    return self.__use_gpu
def update_model_name(self, model_name: str):
    """Switch to a new base model and force-reload the txt2img/img2img pipelines."""
    unchanged = (not model_name) or (model_name == self.model_name)
    if unchanged:
        self.__logger.warn("model name empty or the same, not updated")
        return
    self.model_name = model_name
    # inpaint pipeline uses its own model name, so only these two reload
    self.load_txt2img_and_img2img_pipeline(force_reload=True)
def reduce_memory(self):
def set_low_memory_mode(self):
    """Trade precision for memory: half-precision weights plus global low-memory tuning.

    Must be called before load_*_pipeline for the dtype to take effect,
    since __torch_dtype is read at from_pretrained time.
    """
    self.__logger.info("reduces memory usage by using float16 dtype")
    tune_for_low_memory()
    self.__torch_dtype = torch.float16
def __set_scheduler(self, scheduler: str, pipeline, default_scheduler):
    """Swap the pipeline's scheduler.

    The "Default" sentinel restores the scheduler captured when the
    pipeline was loaded; any other name is resolved as a diffusers
    scheduler class and rebuilt from the current scheduler's config.
    """
    if scheduler != VALUE_SCHEDULER_DEFAULT:
        current_config = pipeline.scheduler.config
        scheduler_cls = getattr(diffusers, scheduler)
        pipeline.scheduler = scheduler_cls.from_config(current_config)
    else:
        pipeline.scheduler = default_scheduler
def load(self):
empty_memory()
empty_memory_cache()
if self.model_name:
revision = get_revision_from_model_name(self.model_name)
sd_pipeline = None
try:
sd_pipeline = StableDiffusionPipeline.from_pretrained(
model_name,
revision=revision,
torch_dtype=self.__torch_dtype,
safety_checker=None,
)
except:
try:
sd_pipeline = StableDiffusionPipeline.from_pretrained(
self.model_name,
torch_dtype=self.__torch_dtype,
safety_checker=None,
)
except Exception as e:
self.__logger.error(
"failed to load model %s: %s" % (self.model_name, e)
)
if sd_pipeline and self.use_gpu():
sd_pipeline.to("cuda")
self.sd_pipeline = sd_pipeline
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
**sd_pipeline.components
def set_img2img_scheduler(self, scheduler: str):
    """Apply *scheduler* to the img2img pipeline.

    txt2img and img2img share components, so this also changes the
    txt2img scheduler.
    """
    if self.img2img_pipeline is not None:
        self.__set_scheduler(scheduler, self.img2img_pipeline, self.__default_img2img_scheduler)
    else:
        self.__logger.error("no img2img pipeline loaded, unable to set scheduler")
def set_txt2img_scheduler(self, scheduler: str):
    """Apply *scheduler* to the txt2img pipeline.

    txt2img and img2img share components, so this also changes the
    img2img scheduler.
    """
    if self.txt2img_pipeline is not None:
        self.__set_scheduler(scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler)
    else:
        self.__logger.error("no txt2img pipeline loaded, unable to set scheduler")
def set_inpaint_scheduler(self, scheduler: str):
    """Apply *scheduler* to the inpaint pipeline (independent of txt2img/img2img)."""
    if self.inpaint_pipeline is not None:
        self.__set_scheduler(scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler)
    else:
        self.__logger.error("no inpaint pipeline loaded, unable to set scheduler")
def load_txt2img_and_img2img_pipeline(self, force_reload: bool = False):
    """Load the txt2img pipeline and build the img2img pipeline from its components.

    The two pipelines share weights (img2img is constructed from the
    txt2img pipeline's components), so they are always loaded together
    and share one scheduler object. No-op when already loaded unless
    ``force_reload`` is True.
    """
    if (not force_reload) and (self.txt2img_pipeline is not None):
        self.__logger.warn("txt2img and img2img pipelines already loaded")
        return
    if not self.model_name:
        self.__logger.error("unable to load txt2img and img2img pipelines, model not set")
        return
    revision = get_revision_from_model_name(self.model_name)
    pipeline = None
    try:
        pipeline = StableDiffusionPipeline.from_pretrained(
            # BUG FIX: was bare `model_name` (undefined), which always raised
            # NameError and silently forced the no-revision fallback below.
            self.model_name,
            revision=revision,
            torch_dtype=self.__torch_dtype,
            safety_checker=None,
        )
    except Exception:  # narrowed from bare except: don't swallow SystemExit/KeyboardInterrupt
        # retry without pinning a revision (some repos lack the fp16 branch)
        try:
            pipeline = StableDiffusionPipeline.from_pretrained(
                self.model_name,
                torch_dtype=self.__torch_dtype,
                safety_checker=None,
            )
        except Exception as e:
            self.__logger.error(
                "failed to load model %s: %s" % (self.model_name, e)
            )
    if pipeline is None:
        # BUG FIX: the original dereferenced pipeline.scheduler even after both
        # load attempts failed, raising AttributeError on None. Keep whatever
        # pipelines were previously loaded and bail out.
        empty_memory_cache()
        return
    if self.use_gpu():
        pipeline.to("cuda")
    self.txt2img_pipeline = pipeline
    self.__default_txt2img_scheduler = pipeline.scheduler
    self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
        **pipeline.components
    )
    # img2img reuses txt2img's scheduler object, hence the same default.
    self.__default_img2img_scheduler = self.__default_txt2img_scheduler
    empty_memory_cache()
def load_inpaint_pipeline(self, force_reload: bool = False):
    """Load the inpainting pipeline from the dedicated inpainting model.

    No-op when already loaded unless ``force_reload`` is True.
    """
    if (not force_reload) and (self.inpaint_pipeline is not None):
        self.__logger.warn("inpaint pipeline already loaded")
        return
    if not self.inpainting_model_name:
        self.__logger.error("unable to load inpaint pipeline, model not set")
        return
    revision = get_revision_from_model_name(self.inpainting_model_name)
    pipeline = None
    try:
        pipeline = StableDiffusionInpaintPipeline.from_pretrained(
            # BUG FIX: was bare `model_name` (undefined), which always raised
            # NameError and silently forced the no-revision fallback below.
            self.inpainting_model_name,
            revision=revision,
            torch_dtype=self.__torch_dtype,
            safety_checker=None,
        )
    except Exception:  # narrowed from bare except: don't swallow SystemExit/KeyboardInterrupt
        # retry without pinning a revision
        try:
            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                self.inpainting_model_name,
                torch_dtype=self.__torch_dtype,
                safety_checker=None,
            )
        except Exception as e:
            self.__logger.error(
                "failed to load inpaint model %s: %s"
                % (self.inpainting_model_name, e)
            )
    if pipeline is None:
        # BUG FIX: the original assigned None to self.inpaint_pipeline and then
        # read pipeline.scheduler, raising AttributeError. Keep any previously
        # loaded pipeline and bail out.
        empty_memory_cache()
        return
    if self.use_gpu():
        pipeline.to("cuda")
    self.inpaint_pipeline = pipeline
    self.__default_inpaint_scheduler = pipeline.scheduler
    empty_memory_cache()
def load_all(self):
    """Load every pipeline: txt2img + img2img first, then inpaint."""
    self.load_txt2img_and_img2img_pipeline()
    self.load_inpaint_pipeline()
def get_revision_from_model_name(model_name: str):

View File

@ -11,10 +11,5 @@ class Text2Img:
self.model = model
self.config = config
def update_config(self, config: Config):
    """Replace the active generation config.

    BUG FIX: the original signature omitted ``self``, so any instance
    call (text2img.update_config(cfg)) raised TypeError.
    """
    self.config = config
def update_model(self, model: Model):
    """Replace the active model.

    BUG FIX: the original signature ``def update_model(model, Model)``
    omitted ``self`` and used a comma where a type annotation was
    clearly intended (compare update_config's ``config: Config``).
    """
    self.model = model
def breakfast(self):
    """Push the configured scheduler onto the txt2img pipeline.

    NOTE(review): the name suggests a one-time warm-up step — confirm
    intent with the author.
    """
    # BUG FIX: `config` was an undefined bare name; use the instance's config.
    self.model.set_txt2img_scheduler(self.config.get_scheduler())