adds scheduler config and barebone of text2img

HappyZ 2023-04-27 00:39:02 -07:00
parent d4e00c3377
commit 41664e2682
7 changed files with 136 additions and 72 deletions

View File

@@ -21,8 +21,8 @@ def prepare(logger: Logger) -> [Model, Config]:
     inpainting_model_name = "runwayml/stable-diffusion-inpainting"
     model = Model(model_name, inpainting_model_name, logger)
-    model.reduce_memory()
-    model.load()
+    model.set_low_memory_mode()
+    model.load_all()
     config = Config()
     return model, config
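For orientation (not part of the diff): a minimal sketch of how the renamed entry points are meant to be used together. It assumes the Logger constructor takes no arguments; everything outside prepare() shown here is a guess.

from utilities.logger import Logger

logger = Logger()  # assumed no-arg constructor, not shown in this commit
model, config = prepare(logger)  # runs set_low_memory_mode() and load_all() internally
model.set_txt2img_scheduler(config.get_scheduler())  # apply the configured scheduler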

View File

@@ -34,6 +34,7 @@ py_library(
     name="model",
     srcs=["model.py"],
     deps=[
+        ":constants",
         ":memory",
         ":logger",
     ],

View File

@@ -9,6 +9,11 @@ from utilities.constants import KEY_PREVIEW
 from utilities.constants import VALUE_PREVIEW_DEFAULT
 from utilities.constants import KEY_SCHEDULER
 from utilities.constants import VALUE_SCHEDULER_DEFAULT
+from utilities.constants import VALUE_SCHEDULER_DDIM
+from utilities.constants import VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP
+from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE
+from utilities.constants import VALUE_SCHEDULER_LMS_DISCRETE
+from utilities.constants import VALUE_SCHEDULER_PNDM
 from utilities.constants import KEY_SEED
 from utilities.constants import VALUE_SEED_DEFAULT
 from utilities.constants import KEY_STEPS
@@ -55,15 +60,8 @@ class Config:
         return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
 
     def set_scheduler(self, scheduler: str):
-        # choices:
-        # "Default"
-        # "DPMSolverMultistepScheduler"
-        # "LMSDiscreteScheduler"
-        # "EulerDiscreteScheduler"
-        # "PNDMScheduler"
-        # "DDIMScheduler"
         if not scheduler:
-            scheduler = "Default"
+            scheduler = VALUE_SCHEDULER_DEFAULT
         self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
         self.__config[KEY_SCHEDULER] = scheduler
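A quick sketch of the intended round trip through Config with the new constants (Config() with no arguments matches its use in prepare() above; nothing else is assumed):

from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE

config = Config()
config.set_scheduler(VALUE_SCHEDULER_EULER_DISCRETE)
assert config.get_scheduler() == "EulerDiscreteScheduler"
config.set_scheduler("")  # empty input falls back to VALUE_SCHEDULER_DEFAULT
assert config.get_scheduler() == "Default"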

View File

@@ -17,6 +17,11 @@ VALUE_STEPS_DEFAULT = 100
 KEY_SCHEDULER = "SCHEDULER"
 VALUE_SCHEDULER_DEFAULT = "Default"
+VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler"
+VALUE_SCHEDULER_LMS_DISCRETE = "LMSDiscreteScheduler"
+VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
+VALUE_SCHEDULER_PNDM = "PNDMScheduler"
+VALUE_SCHEDULER_DDIM = "DDIMScheduler"
 KEY_PREVIEW = "PREVIEW"
 VALUE_PREVIEW_DEFAULT = True
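These values intentionally match the diffusers class names, since model.py below resolves a choice with getattr(diffusers, scheduler). A quick sanity check, assuming diffusers is installed:

import diffusers

from utilities.constants import VALUE_SCHEDULER_DDIM
from utilities.constants import VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP
from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE
from utilities.constants import VALUE_SCHEDULER_LMS_DISCRETE
from utilities.constants import VALUE_SCHEDULER_PNDM

# every non-default choice should resolve to a scheduler class exported by diffusers
for name in (VALUE_SCHEDULER_DDIM, VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP,
             VALUE_SCHEDULER_EULER_DISCRETE, VALUE_SCHEDULER_LMS_DISCRETE,
             VALUE_SCHEDULER_PNDM):
    assert hasattr(diffusers, name), name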

View File

@@ -2,7 +2,7 @@ import gc
 import torch
 
-def empty_memory():
+def empty_memory_cache():
     """
     Performs garbage collection and empty cache in cuda device.
     """

View File

@@ -1,10 +1,17 @@
+import diffusers
 import torch
 from diffusers import StableDiffusionPipeline
 from diffusers import StableDiffusionImg2ImgPipeline
 from diffusers import StableDiffusionInpaintPipeline
+from utilities.constants import VALUE_SCHEDULER_DEFAULT
+from utilities.constants import VALUE_SCHEDULER_DDIM
+from utilities.constants import VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP
+from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE
+from utilities.constants import VALUE_SCHEDULER_LMS_DISCRETE
+from utilities.constants import VALUE_SCHEDULER_PNDM
 from utilities.logger import DummyLogger
-from utilities.memory import empty_memory
+from utilities.memory import empty_memory_cache
 from utilities.memory import tune_for_low_memory
@@ -27,26 +34,66 @@ class Model:
         self.__logger = logger
         self.__torch_dtype = "auto"
-        self.sd_pipeline = None
+        # txt2img and img2img are always loaded together
+        self.txt2img_pipeline = None
         self.img2img_pipeline = None
         self.inpaint_pipeline = None
 
     def use_gpu(self):
         return self.__use_gpu
 
-    def reduce_memory(self):
+    def update_model_name(self, model_name: str):
+        if not model_name or model_name == self.model_name:
+            self.__logger.warn("model name empty or the same, not updated")
+            return
+        self.model_name = model_name
+        self.load_txt2img_and_img2img_pipeline(force_reload=True)
+
+    def set_low_memory_mode(self):
         self.__logger.info("reduces memory usage by using float16 dtype")
         tune_for_low_memory()
         self.__torch_dtype = torch.float16
 
-    def load(self):
-        empty_memory()
-        if self.model_name:
+    def __set_scheduler(self, scheduler: str, pipeline, default_scheduler):
+        if scheduler == VALUE_SCHEDULER_DEFAULT:
+            pipeline.scheduler = default_scheduler
+            return
+        config = pipeline.scheduler.config
+        pipeline.scheduler = getattr(diffusers, scheduler).from_config(config)
+        empty_memory_cache()
+
+    def set_img2img_scheduler(self, scheduler: str):
+        # note the change here also affects txt2img scheduler
+        if self.img2img_pipeline is None:
+            self.__logger.error("no img2img pipeline loaded, unable to set scheduler")
+            return
+        self.__set_scheduler(scheduler, self.img2img_pipeline, self.__default_img2img_scheduler)
+
+    def set_txt2img_scheduler(self, scheduler: str):
+        # note the change here also affects img2img scheduler
+        if self.txt2img_pipeline is None:
+            self.__logger.error("no txt2img pipeline loaded, unable to set scheduler")
+            return
+        self.__set_scheduler(scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler)
+
+    def set_inpaint_scheduler(self, scheduler: str):
+        if self.inpaint_pipeline is None:
+            self.__logger.error("no inpaint pipeline loaded, unable to set scheduler")
+            return
+        self.__set_scheduler(scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler)
+
+    def load_txt2img_and_img2img_pipeline(self, force_reload: bool = False):
+        if (not force_reload) and (self.txt2img_pipeline is not None):
+            self.__logger.warn("txt2img and img2img pipelines already loaded")
+            return
+        if not self.model_name:
+            self.__logger.error("unable to load txt2img and img2img pipelines, model not set")
+            return
         revision = get_revision_from_model_name(self.model_name)
-        sd_pipeline = None
+        pipeline = None
         try:
-            sd_pipeline = StableDiffusionPipeline.from_pretrained(
+            pipeline = StableDiffusionPipeline.from_pretrained(
                 model_name,
                 revision=revision,
                 torch_dtype=self.__torch_dtype,
@@ -54,7 +101,7 @@ class Model:
             )
         except:
             try:
-                sd_pipeline = StableDiffusionPipeline.from_pretrained(
+                pipeline = StableDiffusionPipeline.from_pretrained(
                     self.model_name,
                     torch_dtype=self.__torch_dtype,
                     safety_checker=None,
@@ -63,18 +110,30 @@ class Model:
                 self.__logger.error(
                     "failed to load model %s: %s" % (self.model_name, e)
                 )
-        if sd_pipeline and self.use_gpu():
-            sd_pipeline.to("cuda")
-        self.sd_pipeline = sd_pipeline
-        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
-            **sd_pipeline.components
-        )
-        if self.inpainting_model_name:
+        if pipeline and self.use_gpu():
+            pipeline.to("cuda")
+        self.txt2img_pipeline = pipeline
+        self.__default_txt2img_scheduler = pipeline.scheduler
+        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
+            **pipeline.components
+        )
+        self.__default_img2img_scheduler = self.__default_txt2img_scheduler
+        empty_memory_cache()
+
+    def load_inpaint_pipeline(self, force_reload: bool = False):
+        if (not force_reload) and (self.inpaint_pipeline is not None):
+            self.__logger.warn("inpaint pipeline already loaded")
+            return
+        if not self.inpainting_model_name:
+            self.__logger.error("unable to load inpaint pipeline, model not set")
+            return
         revision = get_revision_from_model_name(self.inpainting_model_name)
-        inpaint_pipeline = None
+        pipeline = None
         try:
-            inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+            pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                 model_name,
                 revision=revision,
                 torch_dtype=self.__torch_dtype,
@@ -82,7 +141,7 @@ class Model:
             )
         except:
             try:
-                inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
+                pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                     self.inpainting_model_name,
                     torch_dtype=self.__torch_dtype,
                     safety_checker=None,
@@ -92,10 +151,16 @@ class Model:
                     "failed to load inpaint model %s: %s"
                     % (self.inpainting_model_name, e)
                 )
-        if inpaint_pipeline and self.use_gpu():
-            inpaint_pipeline.to("cuda")
-        self.inpaint_pipeline = inpaint_pipeline
-        self.inpaint_pipeline_scheduler = inpaint_pipeline.scheduler
+        if pipeline and self.use_gpu():
+            pipeline.to("cuda")
+        self.inpaint_pipeline = pipeline
+        self.__default_inpaint_scheduler = pipeline.scheduler
+        empty_memory_cache()
+
+    def load_all(self):
+        self.load_txt2img_and_img2img_pipeline()
+        self.load_inpaint_pipeline()
 
 def get_revision_from_model_name(model_name: str):
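The scheduler swap in __set_scheduler relies on every diffusers scheduler class exposing from_config(), so any of the named choices can be rebuilt from the current pipeline's scheduler config, and restoring "Default" is just reassigning the saved instance. A standalone sketch of the same mechanism (the model id and scheduler name here are placeholders):

import diffusers
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")  # placeholder model id
default_scheduler = pipe.scheduler  # remember the pipeline's default scheduler

# swap in a named scheduler, reusing the current scheduler's config
pipe.scheduler = getattr(diffusers, "EulerDiscreteScheduler").from_config(pipe.scheduler.config)

# going back to "Default" is just reassigning the saved instance
pipe.scheduler = default_scheduler

Since the img2img pipeline is built from **pipeline.components, both pipelines start out sharing the same modules, including the scheduler instance, which is what the warnings in the two setters refer to.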

View File

@@ -11,10 +11,5 @@ class Text2Img:
         self.model = model
         self.config = config
 
-    def update_config(config: Config):
-        self.config = config
-
-    def update_model(model, Model):
-        self.model = model
+    def breakfast(self):
+        self.model.set_txt2img_scheduler(config.get_scheduler())
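Text2Img is still a bare stub here; a hypothetical fleshed-out version, only to illustrate how the pieces above are meant to connect (the generate() method and its flow are assumptions, not part of this commit):

class Text2Img:
    def __init__(self, model, config):
        self.model = model
        self.config = config

    def generate(self, prompt: str):
        # assumed flow: apply the configured scheduler, then run the txt2img pipeline
        self.model.set_txt2img_scheduler(self.config.get_scheduler())
        pipeline = self.model.txt2img_pipeline
        if pipeline is None:
            return None
        return pipeline(prompt).images[0]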