adds config and text2img barebone, python formats
This commit is contained in:
parent
e40a7c02f6
commit
d4e00c3377
6
BUILD
6
BUILD
|
|
@ -4,10 +4,12 @@ load("@subpar//:subpar.bzl", "par_binary")
|
||||||
package(default_visibility=["//visibility:public"])
|
package(default_visibility=["//visibility:public"])
|
||||||
|
|
||||||
par_binary(
|
par_binary(
|
||||||
name = 'main',
|
name="main",
|
||||||
srcs=["main.py"],
|
srcs=["main.py"],
|
||||||
deps=[
|
deps=[
|
||||||
|
"//utilities:constants",
|
||||||
"//utilities:logger",
|
"//utilities:logger",
|
||||||
"//utilities:memory",
|
"//utilities:model",
|
||||||
|
"//utilities:text2img",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
|
||||||
28
main.py
28
main.py
|
|
@ -1,20 +1,38 @@
|
||||||
from utilities.model import Model
|
from utilities.constants import LOGGER_NAME
|
||||||
from utilities.logger import Logger
|
from utilities.logger import Logger
|
||||||
|
from utilities.model import Model
|
||||||
|
from utilities.config import Config
|
||||||
|
from utilities.text2img import Text2Img
|
||||||
|
|
||||||
def prepare(logger: Logger):
|
|
||||||
|
def prepare(logger: Logger) -> [Model, Config]:
|
||||||
|
# model candidates:
|
||||||
|
# "runwayml/stable-diffusion-v1-5"
|
||||||
|
# "CompVis/stable-diffusion-v1-4"
|
||||||
|
# "stabilityai/stable-diffusion-2-1"
|
||||||
|
# "SG161222/Realistic_Vision_V2.0"
|
||||||
|
# "darkstorm2150/Protogen_x3.4_Official_Release"
|
||||||
|
# "prompthero/openjourney"
|
||||||
|
# "naclbit/trinart_stable_diffusion_v2"
|
||||||
|
# "hakurei/waifu-diffusion"
|
||||||
model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
|
model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
|
||||||
|
# inpainting model candidates:
|
||||||
|
# "runwayml/stable-diffusion-inpainting"
|
||||||
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
|
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
|
||||||
|
|
||||||
model = Model(model_name, inpainting_model_name, logger)
|
model = Model(model_name, inpainting_model_name, logger)
|
||||||
model.reduce_memory()
|
model.reduce_memory()
|
||||||
model.load()
|
model.load()
|
||||||
return model
|
|
||||||
|
config = Config()
|
||||||
|
return model, config
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
logger = Logger(name="rl_trader")
|
logger = Logger(name=LOGGER_NAME)
|
||||||
|
|
||||||
model = prepare(logger)
|
model, config = prepare(logger)
|
||||||
|
text2img = Text2Img(model, config)
|
||||||
|
|
||||||
input("confirm...")
|
input("confirm...")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,29 @@ py_library(
|
||||||
srcs=["memory.py"],
|
srcs=["memory.py"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="text2img",
|
||||||
|
srcs=["text2img.py"],
|
||||||
|
deps=[
|
||||||
|
":config",
|
||||||
|
":model",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="config",
|
||||||
|
srcs=["config.py"],
|
||||||
|
deps=[
|
||||||
|
":logger",
|
||||||
|
":constants",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="constants",
|
||||||
|
srcs=["constants.py"],
|
||||||
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name="model",
|
name="model",
|
||||||
srcs=["model.py"],
|
srcs=["model.py"],
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,93 @@
|
||||||
|
import random
|
||||||
|
import time
|
||||||
|
|
||||||
|
from utilities.constants import KEY_GUIDANCE_SCALE
|
||||||
|
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
|
||||||
|
from utilities.constants import KEY_HEIGHT
|
||||||
|
from utilities.constants import VALUE_HEIGHT_DEFAULT
|
||||||
|
from utilities.constants import KEY_PREVIEW
|
||||||
|
from utilities.constants import VALUE_PREVIEW_DEFAULT
|
||||||
|
from utilities.constants import KEY_SCHEDULER
|
||||||
|
from utilities.constants import VALUE_SCHEDULER_DEFAULT
|
||||||
|
from utilities.constants import KEY_SEED
|
||||||
|
from utilities.constants import VALUE_SEED_DEFAULT
|
||||||
|
from utilities.constants import KEY_STEPS
|
||||||
|
from utilities.constants import VALUE_STEPS_DEFAULT
|
||||||
|
from utilities.constants import KEY_WIDTH
|
||||||
|
from utilities.constants import VALUE_WIDTH_DEFAULT
|
||||||
|
from utilities.logger import DummyLogger
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""
|
||||||
|
Configuration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, logger: DummyLogger = DummyLogger()):
|
||||||
|
self.__logger = logger
|
||||||
|
self.__config = {}
|
||||||
|
|
||||||
|
def get_config(self) -> dict:
|
||||||
|
return self.__config
|
||||||
|
|
||||||
|
def get_guidance_scale(self) -> float:
|
||||||
|
return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
|
||||||
|
|
||||||
|
def set_guidance_scale(self, scale: float):
|
||||||
|
self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
|
||||||
|
self.__config[KEY_GUIDANCE_SCALE] = scale
|
||||||
|
|
||||||
|
def get_height(self) -> int:
|
||||||
|
return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)
|
||||||
|
|
||||||
|
def set_height(self, value: int):
|
||||||
|
self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
|
||||||
|
self.__config[KEY_HEIGHT] = value
|
||||||
|
|
||||||
|
def get_preview(self) -> bool:
|
||||||
|
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
|
||||||
|
|
||||||
|
def set_preview(self, boolean: bool):
|
||||||
|
self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
|
||||||
|
self.__config[KEY_PREVIEW] = boolean
|
||||||
|
|
||||||
|
def get_scheduler(self) -> str:
|
||||||
|
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
|
||||||
|
|
||||||
|
def set_scheduler(self, scheduler: str):
|
||||||
|
# choices:
|
||||||
|
# "Default"
|
||||||
|
# "DPMSolverMultistepScheduler"
|
||||||
|
# "LMSDiscreteScheduler"
|
||||||
|
# "EulerDiscreteScheduler"
|
||||||
|
# "PNDMScheduler"
|
||||||
|
# "DDIMScheduler"
|
||||||
|
if not scheduler:
|
||||||
|
scheduler = "Default"
|
||||||
|
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
|
||||||
|
self.__config[KEY_SCHEDULER] = scheduler
|
||||||
|
|
||||||
|
def get_seed(self) -> int:
|
||||||
|
seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
|
||||||
|
if seed == 0:
|
||||||
|
random.seed(int(time.time_ns()))
|
||||||
|
seed = random.getrandbits(64)
|
||||||
|
return seed
|
||||||
|
|
||||||
|
def set_seed(self, seed: int):
|
||||||
|
self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
|
||||||
|
self.__config[KEY_SEED] = seed
|
||||||
|
|
||||||
|
def get_steps(self) -> int:
|
||||||
|
return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)
|
||||||
|
|
||||||
|
def set_steps(self, steps: int):
|
||||||
|
self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
|
||||||
|
self.__config[KEY_STEPS] = steps
|
||||||
|
|
||||||
|
def get_width(self) -> int:
|
||||||
|
return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)
|
||||||
|
|
||||||
|
def set_width(self, value: int):
|
||||||
|
self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
|
||||||
|
self.__config[KEY_WIDTH] = value
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
LOGGER_NAME = "main"
|
||||||
|
|
||||||
|
KEY_SEED = "SEED"
|
||||||
|
VALUE_SEED_DEFAULT = 0
|
||||||
|
|
||||||
|
KEY_WIDTH = "WIDTH"
|
||||||
|
VALUE_WIDTH_DEFAULT = 512
|
||||||
|
|
||||||
|
KEY_HEIGHT = "HEIGHT"
|
||||||
|
VALUE_HEIGHT_DEFAULT = 512
|
||||||
|
|
||||||
|
KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
|
||||||
|
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
|
||||||
|
|
||||||
|
KEY_STEPS = "STEPS"
|
||||||
|
VALUE_STEPS_DEFAULT = 100
|
||||||
|
|
||||||
|
KEY_SCHEDULER = "SCHEDULER"
|
||||||
|
VALUE_SCHEDULER_DEFAULT = "Default"
|
||||||
|
|
||||||
|
KEY_PREVIEW = "PREVIEW"
|
||||||
|
VALUE_PREVIEW_DEFAULT = True
|
||||||
|
|
@ -6,7 +6,8 @@ from colorlog import ColoredFormatter
|
||||||
|
|
||||||
|
|
||||||
FORMATTER_STREAM = ColoredFormatter(
|
FORMATTER_STREAM = ColoredFormatter(
|
||||||
"%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s")
|
"%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s"
|
||||||
|
)
|
||||||
FORMATTER_FILE = logging.Formatter("%(asctime)s %(levelname)-.1s. %(message)s")
|
FORMATTER_FILE = logging.Formatter("%(asctime)s %(levelname)-.1s. %(message)s")
|
||||||
|
|
||||||
VERBOSITY_V = 10
|
VERBOSITY_V = 10
|
||||||
|
|
@ -14,21 +15,21 @@ VERBOSITY_VV = 100
|
||||||
VERBOSITY_VVV = 1000
|
VERBOSITY_VVV = 1000
|
||||||
|
|
||||||
|
|
||||||
def touch(filepath):
|
def touch(filepath: str):
|
||||||
'''
|
"""
|
||||||
Behaves similarly as `touch` command in Linux system.
|
Behaves similarly as `touch` command in Linux system.
|
||||||
Creates an empty file.
|
Creates an empty file.
|
||||||
'''
|
"""
|
||||||
try:
|
try:
|
||||||
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
os.makedirs(os.path.dirname(filepath), exist_ok=True)
|
||||||
open(filepath, 'a').close()
|
open(filepath, "a").close()
|
||||||
return True
|
return True
|
||||||
except BaseException:
|
except BaseException:
|
||||||
pass
|
pass
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_func_name(offset=0):
|
def get_func_name(offset: int = 0):
|
||||||
"""
|
"""
|
||||||
Returns the name of the caller function name.
|
Returns the name of the caller function name.
|
||||||
Offset counts the recursion - for example, offset=1 means the caller of the caller.
|
Offset counts the recursion - for example, offset=1 means the caller of the caller.
|
||||||
|
|
@ -36,10 +37,10 @@ def get_func_name(offset=0):
|
||||||
return sys._getframe(1 + offset).f_code.co_name
|
return sys._getframe(1 + offset).f_code.co_name
|
||||||
|
|
||||||
|
|
||||||
class DummyLogger():
|
class DummyLogger:
|
||||||
'''
|
"""
|
||||||
DummyLogger does not do anything.
|
DummyLogger does not do anything.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
pass
|
||||||
|
|
@ -59,26 +60,26 @@ class DummyLogger():
|
||||||
def debugging_off(self):
|
def debugging_off(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def info(self, msg=""):
|
def info(self, msg: str = ""):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def error(self, msg=""):
|
def error(self, msg: str = ""):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def warn(self, msg=""):
|
def warn(self, msg: str = ""):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def debug(self, msg="", verbosity=VERBOSITY_V):
|
def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def critical(self, msg=""):
|
def critical(self, msg: str = ""):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
class Logger(DummyLogger):
|
class Logger(DummyLogger):
|
||||||
'''
|
"""
|
||||||
Logger with specific format
|
Logger with specific format
|
||||||
'''
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
@ -86,7 +87,7 @@ class Logger(DummyLogger):
|
||||||
filepath: str = "",
|
filepath: str = "",
|
||||||
verbosity: int = VERBOSITY_V,
|
verbosity: int = VERBOSITY_V,
|
||||||
stream_lvl=logging.INFO,
|
stream_lvl=logging.INFO,
|
||||||
file_lvl=logging.DEBUG
|
file_lvl=logging.DEBUG,
|
||||||
):
|
):
|
||||||
self.verbosity = verbosity
|
self.verbosity = verbosity
|
||||||
self.logger = logging.getLogger(name)
|
self.logger = logging.getLogger(name)
|
||||||
|
|
@ -129,53 +130,54 @@ class Logger(DummyLogger):
|
||||||
def debugging_on(self):
|
def debugging_on(self):
|
||||||
self.__streamHandler.setLevel(logging.DEBUG)
|
self.__streamHandler.setLevel(logging.DEBUG)
|
||||||
self.logger.info(
|
self.logger.info(
|
||||||
"debug msg printing is on, verbosity: {}".format(self.verbosity))
|
"debug msg printing is on, verbosity: {}".format(self.verbosity)
|
||||||
|
)
|
||||||
|
|
||||||
def debugging_off(self):
|
def debugging_off(self):
|
||||||
self.__streamHandler.setLevel(logging.INFO)
|
self.__streamHandler.setLevel(logging.INFO)
|
||||||
self.logger.info("debug msg printing is off")
|
self.logger.info("debug msg printing is off")
|
||||||
|
|
||||||
def info(self, msg="", verbosity: int = VERBOSITY_V):
|
def info(self, msg="", verbosity: int = VERBOSITY_V):
|
||||||
'''
|
"""
|
||||||
Showing info message.
|
Showing info message.
|
||||||
@param msg: the message
|
@param msg: the message
|
||||||
@param verbosity: the higher the number is, the less important it is.
|
@param verbosity: the higher the number is, the less important it is.
|
||||||
'''
|
"""
|
||||||
if verbosity > self.verbosity:
|
if verbosity > self.verbosity:
|
||||||
return
|
return
|
||||||
self.logger.info("[{}] {}".format(get_func_name(offset=1), msg))
|
self.logger.info("[{}] {}".format(get_func_name(offset=1), msg))
|
||||||
|
|
||||||
def warn(self, msg="", verbosity: int = VERBOSITY_V):
|
def warn(self, msg="", verbosity: int = VERBOSITY_V):
|
||||||
'''
|
"""
|
||||||
Showing warning message.
|
Showing warning message.
|
||||||
@param msg: the message
|
@param msg: the message
|
||||||
@param verbosity: the higher the number is, the less important it is.
|
@param verbosity: the higher the number is, the less important it is.
|
||||||
'''
|
"""
|
||||||
if verbosity > self.verbosity:
|
if verbosity > self.verbosity:
|
||||||
return
|
return
|
||||||
self.logger.warning("[{}] {}".format(get_func_name(offset=1), msg))
|
self.logger.warning("[{}] {}".format(get_func_name(offset=1), msg))
|
||||||
|
|
||||||
def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
|
def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
|
||||||
'''
|
"""
|
||||||
Showing debug message.
|
Showing debug message.
|
||||||
@param msg: the message
|
@param msg: the message
|
||||||
@param verbosity: the higher the number is, the less important it is.
|
@param verbosity: the higher the number is, the less important it is.
|
||||||
'''
|
"""
|
||||||
if verbosity > self.verbosity:
|
if verbosity > self.verbosity:
|
||||||
return
|
return
|
||||||
self.logger.debug("[{}] {}".format(get_func_name(offset=1), msg))
|
self.logger.debug("[{}] {}".format(get_func_name(offset=1), msg))
|
||||||
|
|
||||||
def error(self, msg=""):
|
def error(self, msg=""):
|
||||||
'''
|
"""
|
||||||
Showing error message.
|
Showing error message.
|
||||||
@param msg: the message
|
@param msg: the message
|
||||||
'''
|
"""
|
||||||
self.logger.error("[{}] {}".format(get_func_name(offset=1), msg))
|
self.logger.error("[{}] {}".format(get_func_name(offset=1), msg))
|
||||||
self.logger.error(traceback.format_exc())
|
self.logger.error(traceback.format_exc())
|
||||||
|
|
||||||
def critical(self, msg=""):
|
def critical(self, msg=""):
|
||||||
'''
|
"""
|
||||||
Showing critical message.
|
Showing critical message.
|
||||||
@param msg: the message
|
@param msg: the message
|
||||||
'''
|
"""
|
||||||
self.logger.critical("[{}] {}".format(get_func_name(offset=1), msg))
|
self.logger.critical("[{}] {}".format(get_func_name(offset=1), msg))
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ from utilities.logger import Logger
|
||||||
|
|
||||||
|
|
||||||
class TestLogger(unittest.TestCase):
|
class TestLogger(unittest.TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(self):
|
def setUpClass(self):
|
||||||
self.logger = Logger()
|
self.logger = Logger()
|
||||||
|
|
@ -31,5 +30,5 @@ class TestLogger(unittest.TestCase):
|
||||||
os.remove(self.log_filepath)
|
os.remove(self.log_filepath)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
|
||||||
|
|
@ -3,15 +3,15 @@ import torch
|
||||||
|
|
||||||
|
|
||||||
def empty_memory():
|
def empty_memory():
|
||||||
'''
|
"""
|
||||||
Performs garbage collection and empty cache in cuda device.
|
Performs garbage collection and empty cache in cuda device.
|
||||||
'''
|
"""
|
||||||
gc.collect()
|
gc.collect()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
|
|
||||||
def tune_for_low_memory():
|
def tune_for_low_memory():
|
||||||
'''
|
"""
|
||||||
Tunes PyTorch to use float16 to reduce memory footprint.
|
Tunes PyTorch to use float16 to reduce memory footprint.
|
||||||
'''
|
"""
|
||||||
torch.set_default_dtype(torch.float16)
|
torch.set_default_dtype(torch.float16)
|
||||||
|
|
|
||||||
|
|
@ -3,30 +3,39 @@ from diffusers import StableDiffusionPipeline
|
||||||
from diffusers import StableDiffusionImg2ImgPipeline
|
from diffusers import StableDiffusionImg2ImgPipeline
|
||||||
from diffusers import StableDiffusionInpaintPipeline
|
from diffusers import StableDiffusionInpaintPipeline
|
||||||
|
|
||||||
from utilities.logger import Logger
|
from utilities.logger import DummyLogger
|
||||||
from utilities.memory import empty_memory
|
from utilities.memory import empty_memory
|
||||||
from utilities.memory import tune_for_low_memory
|
from utilities.memory import tune_for_low_memory
|
||||||
|
|
||||||
class Model():
|
|
||||||
'''Model class.'''
|
class Model:
|
||||||
def __init__(self, model_name: str, inpainting_model_name: str, logger: Logger, use_gpu: bool=True):
|
"""Model class."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_name: str,
|
||||||
|
inpainting_model_name: str,
|
||||||
|
logger: DummyLogger = DummyLogger(),
|
||||||
|
use_gpu: bool = True,
|
||||||
|
):
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.inpainting_model_name = inpainting_model_name
|
self.inpainting_model_name = inpainting_model_name
|
||||||
self.__use_gpu = False
|
self.__use_gpu = False
|
||||||
if use_gpu and torch.cuda.is_available():
|
if use_gpu and torch.cuda.is_available():
|
||||||
self.__use_gpu = True
|
self.__use_gpu = True
|
||||||
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
|
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
|
||||||
self.logger = logger
|
self.__logger = logger
|
||||||
|
self.__torch_dtype = "auto"
|
||||||
|
|
||||||
self.sd_pipeline = None
|
self.sd_pipeline = None
|
||||||
self.img2img_pipeline = None
|
self.img2img_pipeline = None
|
||||||
self.inpaint_pipeline = None
|
self.inpaint_pipeline = None
|
||||||
self.__torch_dtype = "auto"
|
|
||||||
|
|
||||||
def use_gpu(self):
|
def use_gpu(self):
|
||||||
return self.__use_gpu
|
return self.__use_gpu
|
||||||
|
|
||||||
def reduce_memory(self):
|
def reduce_memory(self):
|
||||||
self.logger.info("reduces memory usage by using float16 dtype")
|
self.__logger.info("reduces memory usage by using float16 dtype")
|
||||||
tune_for_low_memory()
|
tune_for_low_memory()
|
||||||
self.__torch_dtype = torch.float16
|
self.__torch_dtype = torch.float16
|
||||||
|
|
||||||
|
|
@ -41,21 +50,25 @@ class Model():
|
||||||
model_name,
|
model_name,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=self.__torch_dtype,
|
torch_dtype=self.__torch_dtype,
|
||||||
safety_checker=None)
|
safety_checker=None,
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
try:
|
try:
|
||||||
sd_pipeline = StableDiffusionPipeline.from_pretrained(
|
sd_pipeline = StableDiffusionPipeline.from_pretrained(
|
||||||
self.model_name,
|
self.model_name,
|
||||||
torch_dtype=self.__torch_dtype,
|
torch_dtype=self.__torch_dtype,
|
||||||
safety_checker=None)
|
safety_checker=None,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error("failed to load model %s: %s" % (self.model_name, e))
|
self.__logger.error(
|
||||||
|
"failed to load model %s: %s" % (self.model_name, e)
|
||||||
|
)
|
||||||
if sd_pipeline and self.use_gpu():
|
if sd_pipeline and self.use_gpu():
|
||||||
sd_pipeline.to("cuda")
|
sd_pipeline.to("cuda")
|
||||||
self.sd_pipeline = sd_pipeline
|
self.sd_pipeline = sd_pipeline
|
||||||
self.sd_pipeline_scheduler = sd_pipeline.scheduler
|
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
|
||||||
|
**sd_pipeline.components
|
||||||
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**sd_pipeline.components)
|
)
|
||||||
|
|
||||||
if self.inpainting_model_name:
|
if self.inpainting_model_name:
|
||||||
revision = get_revision_from_model_name(self.inpainting_model_name)
|
revision = get_revision_from_model_name(self.inpainting_model_name)
|
||||||
|
|
@ -65,15 +78,20 @@ class Model():
|
||||||
model_name,
|
model_name,
|
||||||
revision=revision,
|
revision=revision,
|
||||||
torch_dtype=self.__torch_dtype,
|
torch_dtype=self.__torch_dtype,
|
||||||
safety_checker=None)
|
safety_checker=None,
|
||||||
|
)
|
||||||
except:
|
except:
|
||||||
try:
|
try:
|
||||||
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
|
||||||
self.inpainting_model_name,
|
self.inpainting_model_name,
|
||||||
torch_dtype=self.__torch_dtype,
|
torch_dtype=self.__torch_dtype,
|
||||||
safety_checker=None)
|
safety_checker=None,
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error("failed to load inpaint model %s: %s" % (self.inpainting_model_name, e))
|
self.__logger.error(
|
||||||
|
"failed to load inpaint model %s: %s"
|
||||||
|
% (self.inpainting_model_name, e)
|
||||||
|
)
|
||||||
if inpaint_pipeline and self.use_gpu():
|
if inpaint_pipeline and self.use_gpu():
|
||||||
inpaint_pipeline.to("cuda")
|
inpaint_pipeline.to("cuda")
|
||||||
self.inpaint_pipeline = inpaint_pipeline
|
self.inpaint_pipeline = inpaint_pipeline
|
||||||
|
|
@ -81,4 +99,8 @@ class Model():
|
||||||
|
|
||||||
|
|
||||||
def get_revision_from_model_name(model_name: str):
|
def get_revision_from_model_name(model_name: str):
|
||||||
return "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16"
|
return (
|
||||||
|
"diffusers-115k"
|
||||||
|
if model_name == "naclbit/trinart_stable_diffusion_v2"
|
||||||
|
else "fp16"
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
from utilities.config import Config
|
||||||
|
from utilities.model import Model
|
||||||
|
|
||||||
|
|
||||||
|
class Text2Img:
|
||||||
|
"""
|
||||||
|
Text2Img class.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, model: Model, config: Config):
|
||||||
|
self.model = model
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def update_config(config: Config):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def update_model(model, Model):
|
||||||
|
self.model = model
|
||||||
|
|
||||||
|
|
||||||
Loading…
Reference in New Issue