adds scheduler config and barebone of text2img
This commit is contained in:
parent
d4e00c3377
commit
41664e2682
4
main.py
4
main.py
|
|
@ -21,8 +21,8 @@ def prepare(logger: Logger) -> [Model, Config]:
|
||||||
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
|
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
|
||||||
|
|
||||||
model = Model(model_name, inpainting_model_name, logger)
|
model = Model(model_name, inpainting_model_name, logger)
|
||||||
model.reduce_memory()
|
model.set_low_memory_mode()
|
||||||
model.load()
|
model.load_all()
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
return model, config
|
return model, config
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,7 @@ py_library(
|
||||||
name="model",
|
name="model",
|
||||||
srcs=["model.py"],
|
srcs=["model.py"],
|
||||||
deps=[
|
deps=[
|
||||||
|
":constants",
|
||||||
":memory",
|
":memory",
|
||||||
":logger",
|
":logger",
|
||||||
],
|
],
|
||||||
|
|
|
||||||
|
|
@ -9,6 +9,11 @@ from utilities.constants import KEY_PREVIEW
|
||||||
from utilities.constants import VALUE_PREVIEW_DEFAULT
|
from utilities.constants import VALUE_PREVIEW_DEFAULT
|
||||||
from utilities.constants import KEY_SCHEDULER
|
from utilities.constants import KEY_SCHEDULER
|
||||||
from utilities.constants import VALUE_SCHEDULER_DEFAULT
|
from utilities.constants import VALUE_SCHEDULER_DEFAULT
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_DDIM
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_LMS_DISCRETE
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_PNDM
|
||||||
from utilities.constants import KEY_SEED
|
from utilities.constants import KEY_SEED
|
||||||
from utilities.constants import VALUE_SEED_DEFAULT
|
from utilities.constants import VALUE_SEED_DEFAULT
|
||||||
from utilities.constants import KEY_STEPS
|
from utilities.constants import KEY_STEPS
|
||||||
|
|
@ -55,15 +60,8 @@ class Config:
|
||||||
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
|
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
|
||||||
|
|
||||||
def set_scheduler(self, scheduler: str):
|
def set_scheduler(self, scheduler: str):
|
||||||
# choices:
|
|
||||||
# "Default"
|
|
||||||
# "DPMSolverMultistepScheduler"
|
|
||||||
# "LMSDiscreteScheduler"
|
|
||||||
# "EulerDiscreteScheduler"
|
|
||||||
# "PNDMScheduler"
|
|
||||||
# "DDIMScheduler"
|
|
||||||
if not scheduler:
|
if not scheduler:
|
||||||
scheduler = "Default"
|
scheduler = VALUE_SCHEDULER_DEFAULT
|
||||||
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
|
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
|
||||||
self.__config[KEY_SCHEDULER] = scheduler
|
self.__config[KEY_SCHEDULER] = scheduler
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,11 @@ VALUE_STEPS_DEFAULT = 100
|
||||||
|
|
||||||
KEY_SCHEDULER = "SCHEDULER"
|
KEY_SCHEDULER = "SCHEDULER"
|
||||||
VALUE_SCHEDULER_DEFAULT = "Default"
|
VALUE_SCHEDULER_DEFAULT = "Default"
|
||||||
|
VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP = "DPMSolverMultistepScheduler"
|
||||||
|
VALUE_SCHEDULER_LMS_DISCRETE = "LMSDiscreteScheduler"
|
||||||
|
VALUE_SCHEDULER_EULER_DISCRETE = "EulerDiscreteScheduler"
|
||||||
|
VALUE_SCHEDULER_PNDM = "PNDMScheduler"
|
||||||
|
VALUE_SCHEDULER_DDIM = "DDIMScheduler"
|
||||||
|
|
||||||
KEY_PREVIEW = "PREVIEW"
|
KEY_PREVIEW = "PREVIEW"
|
||||||
VALUE_PREVIEW_DEFAULT = True
|
VALUE_PREVIEW_DEFAULT = True
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import gc
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
|
||||||
def empty_memory():
|
def empty_memory_cache():
|
||||||
"""
|
"""
|
||||||
Performs garbage collection and empty cache in cuda device.
|
Performs garbage collection and empty cache in cuda device.
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,17 @@
|
||||||
|
import diffusers
|
||||||
import torch
|
import torch
|
||||||
from diffusers import StableDiffusionPipeline
|
from diffusers import StableDiffusionPipeline
|
||||||
from diffusers import StableDiffusionImg2ImgPipeline
|
from diffusers import StableDiffusionImg2ImgPipeline
|
||||||
from diffusers import StableDiffusionInpaintPipeline
|
from diffusers import StableDiffusionInpaintPipeline
|
||||||
|
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_DEFAULT
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_DDIM
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_LMS_DISCRETE
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_PNDM
|
||||||
from utilities.logger import DummyLogger
|
from utilities.logger import DummyLogger
|
||||||
from utilities.memory import empty_memory
|
from utilities.memory import empty_memory_cache
|
||||||
from utilities.memory import tune_for_low_memory
|
from utilities.memory import tune_for_low_memory
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -27,26 +34,66 @@ class Model:
|
||||||
self.__logger = logger
|
self.__logger = logger
|
||||||
self.__torch_dtype = "auto"
|
self.__torch_dtype = "auto"
|
||||||
|
|
||||||
self.sd_pipeline = None
|
# txt2img and img2img are always loaded together
|
||||||
|
self.txt2img_pipeline = None
|
||||||
self.img2img_pipeline = None
|
self.img2img_pipeline = None
|
||||||
self.inpaint_pipeline = None
|
self.inpaint_pipeline = None
|
||||||
|
|
||||||
def use_gpu(self):
|
def use_gpu(self):
|
||||||
return self.__use_gpu
|
return self.__use_gpu
|
||||||
|
|
||||||
def reduce_memory(self):
|
def update_model_name(self, model_name:str):
|
||||||
|
if not model_name or model_name == self.model_name:
|
||||||
|
self.__logger.warn("model name empty or the same, not updated")
|
||||||
|
return
|
||||||
|
self.model_name = model_name
|
||||||
|
self.load_txt2img_and_img2img_pipeline(force_reload=True)
|
||||||
|
|
||||||
|
def set_low_memory_mode(self):
|
||||||
self.__logger.info("reduces memory usage by using float16 dtype")
|
self.__logger.info("reduces memory usage by using float16 dtype")
|
||||||
tune_for_low_memory()
|
tune_for_low_memory()
|
||||||
self.__torch_dtype = torch.float16
|
self.__torch_dtype = torch.float16
|
||||||
|
|
||||||
def load(self):
|
def __set_scheduler(self, scheduler:str, pipeline, default_scheduler):
|
||||||
empty_memory()
|
if scheduler == VALUE_SCHEDULER_DEFAULT:
|
||||||
|
pipeline.scheduler = default_scheduler
|
||||||
|
return
|
||||||
|
config = pipeline.scheduler.config
|
||||||
|
pipeline.scheduler = getattr(diffusers, scheduler).from_config(config)
|
||||||
|
|
||||||
if self.model_name:
|
empty_memory_cache()
|
||||||
|
|
||||||
|
def set_img2img_scheduler(self, scheduler:str):
|
||||||
|
# note the change here also affects txt2img scheduler
|
||||||
|
if self.img2img_pipeline is None:
|
||||||
|
self.__logger.error("no img2img pipeline loaded, unable to set scheduler")
|
||||||
|
return
|
||||||
|
self.__set_scheduler(scheduler, self.img2img_pipeline, self.__default_img2img_scheduler)
|
||||||
|
|
||||||
|
def set_txt2img_scheduler(self, scheduler:str):
|
||||||
|
# note the change here also affects img2img scheduler
|
||||||
|
if self.txt2img_pipeline is None:
|
||||||
|
self.__logger.error("no txt2img pipeline loaded, unable to set scheduler")
|
||||||
|
return
|
||||||
|
self.__set_scheduler(scheduler, self.txt2img_pipeline, self.__default_txt2img_scheduler)
|
||||||
|
|
||||||
|
def set_inpaint_scheduler(self, scheduler:str):
|
||||||
|
if self.inpaint_pipeline is None:
|
||||||
|
self.__logger.error("no inpaint pipeline loaded, unable to set scheduler")
|
||||||
|
return
|
||||||
|
self.__set_scheduler(scheduler, self.inpaint_pipeline, self.__default_inpaint_scheduler)
|
||||||
|
|
||||||
|
def load_txt2img_and_img2img_pipeline(self, force_reload:bool=False):
|
||||||
|
if (not force_reload) and (self.txt2img_pipeline is not None):
|
||||||
|
self.__logger.warn("txt2img and img2img pipelines already loaded")
|
||||||
|
return
|
||||||
|
if not self.model_name:
|
||||||
|
self.__logger.error("unable to load txt2img and img2img pipelines, model not set")
|
||||||
|
return
|
||||||
revision = get_revision_from_model_name(self.model_name)
|
revision = get_revision_from_model_name(self.model_name)
|
||||||
sd_pipeline = None
|
pipeline = None
|
||||||
try:
|
try:
|
||||||
sd_pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=self.__torch_dtype,
|
torch_dtype=self.__torch_dtype,
|
||||||
|
|
@ -54,7 +101,7 @@ class Model:
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
try:
|
try:
|
||||||
sd_pipeline = StableDiffusionPipeline.from_pretrained(
|
pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
torch_dtype=self.__torch_dtype,
|
torch_dtype=self.__torch_dtype,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
|
@ -63,18 +110,30 @@ class Model:
|
||||||
self.__logger.error(
|
self.__logger.error(
|
||||||
"failed to load model %s: %s" % (self.model_name, e)
|
"failed to load model %s: %s" % (self.model_name, e)
|
||||||
)
|
)
|
||||||
if sd_pipeline and self.use_gpu():
|
if pipeline and self.use_gpu():
|
||||||
sd_pipeline.to("cuda")
|
pipeline.to("cuda")
|
||||||
self.sd_pipeline = sd_pipeline
|
|
||||||
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
|
|
||||||
**sd_pipeline.components
|
|
||||||
)
|
|
||||||
|
|
||||||
if self.inpainting_model_name:
|
self.txt2img_pipeline = pipeline
|
||||||
|
self.__default_txt2img_scheduler = pipeline.scheduler
|
||||||
|
|
||||||
|
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
|
||||||
|
**pipeline.components
|
||||||
|
)
|
||||||
|
self.__default_img2img_scheduler = self.__default_txt2img_scheduler
|
||||||
|
|
||||||
|
empty_memory_cache()
|
||||||
|
|
||||||
|
def load_inpaint_pipeline(self, force_reload:bool=False):
|
||||||
|
if (not force_reload) and (self.inpaint_pipeline is not None):
|
||||||
|
self.__logger.warn("inpaint pipeline already loaded")
|
||||||
|
return
|
||||||
|
if not self.inpainting_model_name:
|
||||||
|
self.__logger.error("unable to load inpaint pipeline, model not set")
|
||||||
|
return
|
||||||
revision = get_revision_from_model_name(self.inpainting_model_name)
|
revision = get_revision_from_model_name(self.inpainting_model_name)
|
||||||
inpaint_pipeline = None
|
pipeline = None
|
||||||
try:
|
try:
|
||||||
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
model_name,
|
model_name,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=self.__torch_dtype,
|
torch_dtype=self.__torch_dtype,
|
||||||
|
|
@ -82,7 +141,7 @@ class Model:
|
||||||
)
|
)
|
||||||
except:
|
except:
|
||||||
try:
|
try:
|
||||||
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
self.inpainting_model_name,
|
self.inpainting_model_name,
|
||||||
torch_dtype=self.__torch_dtype,
|
torch_dtype=self.__torch_dtype,
|
||||||
safety_checker=None,
|
safety_checker=None,
|
||||||
|
|
@ -92,10 +151,16 @@ class Model:
|
||||||
"failed to load inpaint model %s: %s"
|
"failed to load inpaint model %s: %s"
|
||||||
% (self.inpainting_model_name, e)
|
% (self.inpainting_model_name, e)
|
||||||
)
|
)
|
||||||
if inpaint_pipeline and self.use_gpu():
|
if pipeline and self.use_gpu():
|
||||||
inpaint_pipeline.to("cuda")
|
pipeline.to("cuda")
|
||||||
self.inpaint_pipeline = inpaint_pipeline
|
self.inpaint_pipeline = pipeline
|
||||||
self.inpaint_pipeline_scheduler = inpaint_pipeline.scheduler
|
self.__default_inpaint_scheduler = pipeline.scheduler
|
||||||
|
|
||||||
|
empty_memory_cache()
|
||||||
|
|
||||||
|
def load_all(self):
|
||||||
|
self.load_txt2img_and_img2img_pipeline()
|
||||||
|
self.load_inpaint_pipeline()
|
||||||
|
|
||||||
|
|
||||||
def get_revision_from_model_name(model_name: str):
|
def get_revision_from_model_name(model_name: str):
|
||||||
|
|
|
||||||
|
|
@ -11,10 +11,5 @@ class Text2Img:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
def update_config(config: Config):
|
def breakfast(self):
|
||||||
self.config = config
|
self.model.set_txt2img_scheduler(config.get_scheduler())
|
||||||
|
|
||||||
def update_model(model, Model):
|
|
||||||
self.model = model
|
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue