From d4e00c33776f49e97966576a1458d5b47b9fac2e Mon Sep 17 00:00:00 2001 From: HappyZ Date: Wed, 26 Apr 2023 23:07:24 -0700 Subject: [PATCH] adds config and text2img barebone, python formats --- BUILD | 16 ++++--- main.py | 30 ++++++++++--- utilities/BUILD | 51 ++++++++++++++++------ utilities/config.py | 93 ++++++++++++++++++++++++++++++++++++++++ utilities/constants.py | 22 ++++++++++ utilities/logger.py | 60 +++++++++++++------------- utilities/logger_test.py | 3 +- utilities/memory.py | 8 ++-- utilities/model.py | 56 ++++++++++++++++-------- utilities/text2img.py | 20 +++++++++ 10 files changed, 280 insertions(+), 79 deletions(-) create mode 100644 utilities/config.py create mode 100644 utilities/constants.py create mode 100644 utilities/text2img.py diff --git a/BUILD b/BUILD index 003f75c..bf5105a 100644 --- a/BUILD +++ b/BUILD @@ -1,13 +1,15 @@ load("@rules_python//python:defs.bzl", "py_binary") load("@subpar//:subpar.bzl", "par_binary") -package(default_visibility = ["//visibility:public"]) +package(default_visibility=["//visibility:public"]) par_binary( - name = 'main', - srcs = ["main.py"], - deps = [ - "//utilities:logger", - "//utilities:memory", - ], + name="main", + srcs=["main.py"], + deps=[ + "//utilities:constants", + "//utilities:logger", + "//utilities:model", + "//utilities:text2img", + ], ) diff --git a/main.py b/main.py index 22b8c17..f2d9d31 100644 --- a/main.py +++ b/main.py @@ -1,21 +1,39 @@ -from utilities.model import Model +from utilities.constants import LOGGER_NAME from utilities.logger import Logger +from utilities.model import Model +from utilities.config import Config +from utilities.text2img import Text2Img -def prepare(logger: Logger): + +def prepare(logger: Logger) -> [Model, Config]: + # model candidates: + # "runwayml/stable-diffusion-v1-5" + # "CompVis/stable-diffusion-v1-4" + # "stabilityai/stable-diffusion-2-1" + # "SG161222/Realistic_Vision_V2.0" + # "darkstorm2150/Protogen_x3.4_Official_Release" + # "prompthero/openjourney" + # "naclbit/trinart_stable_diffusion_v2" + # "hakurei/waifu-diffusion" model_name = "darkstorm2150/Protogen_x3.4_Official_Release" + # inpainting model candidates: + # "runwayml/stable-diffusion-inpainting" inpainting_model_name = "runwayml/stable-diffusion-inpainting" model = Model(model_name, inpainting_model_name, logger) model.reduce_memory() model.load() - return model + + config = Config() + return model, config def main(): - logger = Logger(name="rl_trader") + logger = Logger(name=LOGGER_NAME) + + model, config = prepare(logger) + text2img = Text2Img(model, config) - model = prepare(logger) - input("confirm...") diff --git a/utilities/BUILD b/utilities/BUILD index 650114c..5cac05c 100644 --- a/utilities/BUILD +++ b/utilities/BUILD @@ -1,28 +1,51 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test") -package(default_visibility = ["//visibility:public"]) +package(default_visibility=["//visibility:public"]) py_library( - name = "memory", - srcs = ["memory.py"], + name="memory", + srcs=["memory.py"], ) py_library( - name = "model", - srcs = ["model.py"], - deps = [ - ":memory", - ":logger", - ], + name="text2img", + srcs=["text2img.py"], + deps=[ + ":config", + ":model", + ], ) py_library( - name = "logger", - srcs = ["logger.py"], + name="config", + srcs=["config.py"], + deps=[ + ":logger", + ":constants", + ], +) + +py_library( + name="constants", + srcs=["constants.py"], +) + +py_library( + name="model", + srcs=["model.py"], + deps=[ + ":memory", + ":logger", + ], +) + +py_library( + name="logger", + srcs=["logger.py"], ) py_test( - name = "logger_test", - srcs = ["logger_test.py"], - deps = [":logger"], + name="logger_test", + srcs=["logger_test.py"], + deps=[":logger"], ) diff --git a/utilities/config.py b/utilities/config.py new file mode 100644 index 0000000..33eef98 --- /dev/null +++ b/utilities/config.py @@ -0,0 +1,93 @@ +import random +import time + +from utilities.constants import KEY_GUIDANCE_SCALE +from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT +from utilities.constants import KEY_HEIGHT +from utilities.constants import VALUE_HEIGHT_DEFAULT +from utilities.constants import KEY_PREVIEW +from utilities.constants import VALUE_PREVIEW_DEFAULT +from utilities.constants import KEY_SCHEDULER +from utilities.constants import VALUE_SCHEDULER_DEFAULT +from utilities.constants import KEY_SEED +from utilities.constants import VALUE_SEED_DEFAULT +from utilities.constants import KEY_STEPS +from utilities.constants import VALUE_STEPS_DEFAULT +from utilities.constants import KEY_WIDTH +from utilities.constants import VALUE_WIDTH_DEFAULT +from utilities.logger import DummyLogger + + +class Config: + """ + Configuration. + """ + + def __init__(self, logger: DummyLogger = DummyLogger()): + self.__logger = logger + self.__config = {} + + def get_config(self) -> dict: + return self.__config + + def get_guidance_scale(self) -> float: + return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT) + + def set_guidance_scale(self, scale: float): + self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale)) + self.__config[KEY_GUIDANCE_SCALE] = scale + + def get_height(self) -> int: + return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT) + + def set_height(self, value: int): + self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value)) + self.__config[KEY_HEIGHT] = value + + def get_preview(self) -> bool: + return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT) + + def set_preview(self, boolean: bool): + self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean)) + self.__config[KEY_PREVIEW] = boolean + + def get_scheduler(self) -> str: + return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT) + + def set_scheduler(self, scheduler: str): + # choices: + # "Default" + # "DPMSolverMultistepScheduler" + # "LMSDiscreteScheduler" + # "EulerDiscreteScheduler" + # "PNDMScheduler" + # "DDIMScheduler" + if not scheduler: + scheduler = "Default" + self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler)) + self.__config[KEY_SCHEDULER] = scheduler + + def get_seed(self) -> int: + seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT) + if seed == 0: + random.seed(int(time.time_ns())) + seed = random.getrandbits(64) + return seed + + def set_seed(self, seed: int): + self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed)) + self.__config[KEY_SEED] = seed + + def get_steps(self) -> int: + return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT) + + def set_steps(self, steps: int): + self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps)) + self.__config[KEY_STEPS] = steps + + def get_width(self) -> int: + return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT) + + def set_width(self, value: int): + self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value)) + self.__config[KEY_WIDTH] = value diff --git a/utilities/constants.py b/utilities/constants.py new file mode 100644 index 0000000..346e983 --- /dev/null +++ b/utilities/constants.py @@ -0,0 +1,22 @@ +LOGGER_NAME = "main" + +KEY_SEED = "SEED" +VALUE_SEED_DEFAULT = 0 + +KEY_WIDTH = "WIDTH" +VALUE_WIDTH_DEFAULT = 512 + +KEY_HEIGHT = "HEIGHT" +VALUE_HEIGHT_DEFAULT = 512 + +KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE" +VALUE_GUIDANCE_SCALE_DEFAULT = 15.0 + +KEY_STEPS = "STEPS" +VALUE_STEPS_DEFAULT = 100 + +KEY_SCHEDULER = "SCHEDULER" +VALUE_SCHEDULER_DEFAULT = "Default" + +KEY_PREVIEW = "PREVIEW" +VALUE_PREVIEW_DEFAULT = True diff --git a/utilities/logger.py b/utilities/logger.py index 0aa5d58..dd87fda 100644 --- a/utilities/logger.py +++ b/utilities/logger.py @@ -6,7 +6,8 @@ from colorlog import ColoredFormatter FORMATTER_STREAM = ColoredFormatter( - "%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s") + "%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s" +) FORMATTER_FILE = logging.Formatter("%(asctime)s %(levelname)-.1s. %(message)s") VERBOSITY_V = 10 @@ -14,32 +15,32 @@ VERBOSITY_VV = 100 VERBOSITY_VVV = 1000 -def touch(filepath): - ''' +def touch(filepath: str): + """ Behaves similarly as `touch` command in Linux system. Creates an empty file. - ''' + """ try: os.makedirs(os.path.dirname(filepath), exist_ok=True) - open(filepath, 'a').close() + open(filepath, "a").close() return True except BaseException: pass return False -def get_func_name(offset=0): +def get_func_name(offset: int = 0): """ Returns the name of the caller function name. Offset counts the recursion - for example, offset=1 means the caller of the caller. """ - return sys._getframe(1+offset).f_code.co_name + return sys._getframe(1 + offset).f_code.co_name -class DummyLogger(): - ''' +class DummyLogger: + """ DummyLogger does not do anything. - ''' + """ def __init__(self): pass @@ -59,26 +60,26 @@ class DummyLogger(): def debugging_off(self): pass - def info(self, msg=""): + def info(self, msg: str = ""): pass - def error(self, msg=""): + def error(self, msg: str = ""): pass - def warn(self, msg=""): + def warn(self, msg: str = ""): pass - def debug(self, msg="", verbosity=VERBOSITY_V): + def debug(self, msg: str = "", verbosity: int = VERBOSITY_V): pass - def critical(self, msg=""): + def critical(self, msg: str = ""): pass class Logger(DummyLogger): - ''' + """ Logger with specific format - ''' + """ def __init__( self, @@ -86,7 +87,7 @@ class Logger(DummyLogger): filepath: str = "", verbosity: int = VERBOSITY_V, stream_lvl=logging.INFO, - file_lvl=logging.DEBUG + file_lvl=logging.DEBUG, ): self.verbosity = verbosity self.logger = logging.getLogger(name) @@ -129,53 +130,54 @@ class Logger(DummyLogger): def debugging_on(self): self.__streamHandler.setLevel(logging.DEBUG) self.logger.info( - "debug msg printing is on, verbosity: {}".format(self.verbosity)) + "debug msg printing is on, verbosity: {}".format(self.verbosity) + ) def debugging_off(self): self.__streamHandler.setLevel(logging.INFO) self.logger.info("debug msg printing is off") def info(self, msg="", verbosity: int = VERBOSITY_V): - ''' + """ Showing info message. @param msg: the message @param verbosity: the higher the number is, the less important it is. - ''' + """ if verbosity > self.verbosity: return self.logger.info("[{}] {}".format(get_func_name(offset=1), msg)) def warn(self, msg="", verbosity: int = VERBOSITY_V): - ''' + """ Showing warning message. @param msg: the message @param verbosity: the higher the number is, the less important it is. - ''' + """ if verbosity > self.verbosity: return self.logger.warning("[{}] {}".format(get_func_name(offset=1), msg)) def debug(self, msg: str = "", verbosity: int = VERBOSITY_V): - ''' + """ Showing debug message. @param msg: the message @param verbosity: the higher the number is, the less important it is. - ''' + """ if verbosity > self.verbosity: return self.logger.debug("[{}] {}".format(get_func_name(offset=1), msg)) def error(self, msg=""): - ''' + """ Showing error message. @param msg: the message - ''' + """ self.logger.error("[{}] {}".format(get_func_name(offset=1), msg)) self.logger.error(traceback.format_exc()) def critical(self, msg=""): - ''' + """ Showing critical message. @param msg: the message - ''' + """ self.logger.critical("[{}] {}".format(get_func_name(offset=1), msg)) diff --git a/utilities/logger_test.py b/utilities/logger_test.py index 55d1e84..39dc234 100644 --- a/utilities/logger_test.py +++ b/utilities/logger_test.py @@ -5,7 +5,6 @@ from utilities.logger import Logger class TestLogger(unittest.TestCase): - @classmethod def setUpClass(self): self.logger = Logger() @@ -31,5 +30,5 @@ class TestLogger(unittest.TestCase): os.remove(self.log_filepath) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/utilities/memory.py b/utilities/memory.py index 4a4a544..97542a9 100644 --- a/utilities/memory.py +++ b/utilities/memory.py @@ -3,15 +3,15 @@ import torch def empty_memory(): - ''' + """ Performs garbage collection and empty cache in cuda device. - ''' + """ gc.collect() torch.cuda.empty_cache() def tune_for_low_memory(): - ''' + """ Tunes PyTorch to use float16 to reduce memory footprint. - ''' + """ torch.set_default_dtype(torch.float16) diff --git a/utilities/model.py b/utilities/model.py index 26d9ded..96b452f 100644 --- a/utilities/model.py +++ b/utilities/model.py @@ -3,30 +3,39 @@ from diffusers import StableDiffusionPipeline from diffusers import StableDiffusionImg2ImgPipeline from diffusers import StableDiffusionInpaintPipeline -from utilities.logger import Logger +from utilities.logger import DummyLogger from utilities.memory import empty_memory from utilities.memory import tune_for_low_memory -class Model(): - '''Model class.''' - def __init__(self, model_name: str, inpainting_model_name: str, logger: Logger, use_gpu: bool=True): + +class Model: + """Model class.""" + + def __init__( + self, + model_name: str, + inpainting_model_name: str, + logger: DummyLogger = DummyLogger(), + use_gpu: bool = True, + ): self.model_name = model_name self.inpainting_model_name = inpainting_model_name self.__use_gpu = False if use_gpu and torch.cuda.is_available(): self.__use_gpu = True logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0"))) - self.logger = logger + self.__logger = logger + self.__torch_dtype = "auto" + self.sd_pipeline = None self.img2img_pipeline = None self.inpaint_pipeline = None - self.__torch_dtype = "auto" def use_gpu(self): return self.__use_gpu def reduce_memory(self): - self.logger.info("reduces memory usage by using float16 dtype") + self.__logger.info("reduces memory usage by using float16 dtype") tune_for_low_memory() self.__torch_dtype = torch.float16 @@ -41,21 +50,25 @@ class Model(): model_name, revision=revision, torch_dtype=self.__torch_dtype, - safety_checker=None) + safety_checker=None, + ) except: try: sd_pipeline = StableDiffusionPipeline.from_pretrained( self.model_name, torch_dtype=self.__torch_dtype, - safety_checker=None) + safety_checker=None, + ) except Exception as e: - self.logger.error("failed to load model %s: %s" % (self.model_name, e)) + self.__logger.error( + "failed to load model %s: %s" % (self.model_name, e) + ) if sd_pipeline and self.use_gpu(): sd_pipeline.to("cuda") self.sd_pipeline = sd_pipeline - self.sd_pipeline_scheduler = sd_pipeline.scheduler - - self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**sd_pipeline.components) + self.img2img_pipeline = StableDiffusionImg2ImgPipeline( + **sd_pipeline.components + ) if self.inpainting_model_name: revision = get_revision_from_model_name(self.inpainting_model_name) @@ -65,15 +78,20 @@ class Model(): model_name, revision=revision, torch_dtype=self.__torch_dtype, - safety_checker=None) + safety_checker=None, + ) except: try: inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained( self.inpainting_model_name, torch_dtype=self.__torch_dtype, - safety_checker=None) + safety_checker=None, + ) except Exception as e: - self.logger.error("failed to load inpaint model %s: %s" % (self.inpainting_model_name, e)) + self.__logger.error( + "failed to load inpaint model %s: %s" + % (self.inpainting_model_name, e) + ) if inpaint_pipeline and self.use_gpu(): inpaint_pipeline.to("cuda") self.inpaint_pipeline = inpaint_pipeline @@ -81,4 +99,8 @@ class Model(): def get_revision_from_model_name(model_name: str): - return "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16" \ No newline at end of file + return ( + "diffusers-115k" + if model_name == "naclbit/trinart_stable_diffusion_v2" + else "fp16" + ) diff --git a/utilities/text2img.py b/utilities/text2img.py new file mode 100644 index 0000000..d548ebf --- /dev/null +++ b/utilities/text2img.py @@ -0,0 +1,20 @@ +from utilities.config import Config +from utilities.model import Model + + +class Text2Img: + """ + Text2Img class. + """ + + def __init__(self, model: Model, config: Config): + self.model = model + self.config = config + + def update_config(config: Config): + self.config = config + + def update_model(model, Model): + self.model = model + + \ No newline at end of file