adds config and text2img barebone, python formats

This commit is contained in:
HappyZ 2023-04-26 23:07:24 -07:00
parent e40a7c02f6
commit d4e00c3377
10 changed files with 280 additions and 79 deletions

16
BUILD
View File

@ -1,13 +1,15 @@
load("@rules_python//python:defs.bzl", "py_binary") load("@rules_python//python:defs.bzl", "py_binary")
load("@subpar//:subpar.bzl", "par_binary") load("@subpar//:subpar.bzl", "par_binary")
package(default_visibility = ["//visibility:public"]) package(default_visibility=["//visibility:public"])
par_binary( par_binary(
name = 'main', name="main",
srcs = ["main.py"], srcs=["main.py"],
deps = [ deps=[
"//utilities:logger", "//utilities:constants",
"//utilities:memory", "//utilities:logger",
], "//utilities:model",
"//utilities:text2img",
],
) )

30
main.py
View File

@ -1,21 +1,39 @@
from utilities.model import Model from utilities.constants import LOGGER_NAME
from utilities.logger import Logger from utilities.logger import Logger
from utilities.model import Model
from utilities.config import Config
from utilities.text2img import Text2Img
def prepare(logger: Logger):
def prepare(logger: Logger) -> [Model, Config]:
# model candidates:
# "runwayml/stable-diffusion-v1-5"
# "CompVis/stable-diffusion-v1-4"
# "stabilityai/stable-diffusion-2-1"
# "SG161222/Realistic_Vision_V2.0"
# "darkstorm2150/Protogen_x3.4_Official_Release"
# "prompthero/openjourney"
# "naclbit/trinart_stable_diffusion_v2"
# "hakurei/waifu-diffusion"
model_name = "darkstorm2150/Protogen_x3.4_Official_Release" model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
# inpainting model candidates:
# "runwayml/stable-diffusion-inpainting"
inpainting_model_name = "runwayml/stable-diffusion-inpainting" inpainting_model_name = "runwayml/stable-diffusion-inpainting"
model = Model(model_name, inpainting_model_name, logger) model = Model(model_name, inpainting_model_name, logger)
model.reduce_memory() model.reduce_memory()
model.load() model.load()
return model
config = Config()
return model, config
def main(): def main():
logger = Logger(name="rl_trader") logger = Logger(name=LOGGER_NAME)
model, config = prepare(logger)
text2img = Text2Img(model, config)
model = prepare(logger)
input("confirm...") input("confirm...")

View File

@ -1,28 +1,51 @@
load("@rules_python//python:defs.bzl", "py_library", "py_test") load("@rules_python//python:defs.bzl", "py_library", "py_test")
package(default_visibility = ["//visibility:public"]) package(default_visibility=["//visibility:public"])
py_library( py_library(
name = "memory", name="memory",
srcs = ["memory.py"], srcs=["memory.py"],
) )
py_library( py_library(
name = "model", name="text2img",
srcs = ["model.py"], srcs=["text2img.py"],
deps = [ deps=[
":memory", ":config",
":logger", ":model",
], ],
) )
py_library( py_library(
name = "logger", name="config",
srcs = ["logger.py"], srcs=["config.py"],
deps=[
":logger",
":constants",
],
)
py_library(
name="constants",
srcs=["constants.py"],
)
py_library(
name="model",
srcs=["model.py"],
deps=[
":memory",
":logger",
],
)
py_library(
name="logger",
srcs=["logger.py"],
) )
py_test( py_test(
name = "logger_test", name="logger_test",
srcs = ["logger_test.py"], srcs=["logger_test.py"],
deps = [":logger"], deps=[":logger"],
) )

93
utilities/config.py Normal file
View File

@ -0,0 +1,93 @@
import random
import time
from utilities.constants import KEY_GUIDANCE_SCALE
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
from utilities.constants import KEY_HEIGHT
from utilities.constants import VALUE_HEIGHT_DEFAULT
from utilities.constants import KEY_PREVIEW
from utilities.constants import VALUE_PREVIEW_DEFAULT
from utilities.constants import KEY_SCHEDULER
from utilities.constants import VALUE_SCHEDULER_DEFAULT
from utilities.constants import KEY_SEED
from utilities.constants import VALUE_SEED_DEFAULT
from utilities.constants import KEY_STEPS
from utilities.constants import VALUE_STEPS_DEFAULT
from utilities.constants import KEY_WIDTH
from utilities.constants import VALUE_WIDTH_DEFAULT
from utilities.logger import DummyLogger
class Config:
"""
Configuration.
"""
def __init__(self, logger: DummyLogger = DummyLogger()):
self.__logger = logger
self.__config = {}
def get_config(self) -> dict:
return self.__config
def get_guidance_scale(self) -> float:
return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
def set_guidance_scale(self, scale: float):
self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
self.__config[KEY_GUIDANCE_SCALE] = scale
def get_height(self) -> int:
return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)
def set_height(self, value: int):
self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
self.__config[KEY_HEIGHT] = value
def get_preview(self) -> bool:
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
def set_preview(self, boolean: bool):
self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
self.__config[KEY_PREVIEW] = boolean
def get_scheduler(self) -> str:
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
def set_scheduler(self, scheduler: str):
# choices:
# "Default"
# "DPMSolverMultistepScheduler"
# "LMSDiscreteScheduler"
# "EulerDiscreteScheduler"
# "PNDMScheduler"
# "DDIMScheduler"
if not scheduler:
scheduler = "Default"
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
self.__config[KEY_SCHEDULER] = scheduler
def get_seed(self) -> int:
seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
if seed == 0:
random.seed(int(time.time_ns()))
seed = random.getrandbits(64)
return seed
def set_seed(self, seed: int):
self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
self.__config[KEY_SEED] = seed
def get_steps(self) -> int:
return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)
def set_steps(self, steps: int):
self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
self.__config[KEY_STEPS] = steps
def get_width(self) -> int:
return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)
def set_width(self, value: int):
self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
self.__config[KEY_WIDTH] = value

22
utilities/constants.py Normal file
View File

@ -0,0 +1,22 @@
LOGGER_NAME = "main"
KEY_SEED = "SEED"
VALUE_SEED_DEFAULT = 0
KEY_WIDTH = "WIDTH"
VALUE_WIDTH_DEFAULT = 512
KEY_HEIGHT = "HEIGHT"
VALUE_HEIGHT_DEFAULT = 512
KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
KEY_STEPS = "STEPS"
VALUE_STEPS_DEFAULT = 100
KEY_SCHEDULER = "SCHEDULER"
VALUE_SCHEDULER_DEFAULT = "Default"
KEY_PREVIEW = "PREVIEW"
VALUE_PREVIEW_DEFAULT = True

View File

@ -6,7 +6,8 @@ from colorlog import ColoredFormatter
FORMATTER_STREAM = ColoredFormatter( FORMATTER_STREAM = ColoredFormatter(
"%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s") "%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s"
)
FORMATTER_FILE = logging.Formatter("%(asctime)s %(levelname)-.1s. %(message)s") FORMATTER_FILE = logging.Formatter("%(asctime)s %(levelname)-.1s. %(message)s")
VERBOSITY_V = 10 VERBOSITY_V = 10
@ -14,32 +15,32 @@ VERBOSITY_VV = 100
VERBOSITY_VVV = 1000 VERBOSITY_VVV = 1000
def touch(filepath): def touch(filepath: str):
''' """
Behaves similarly as `touch` command in Linux system. Behaves similarly as `touch` command in Linux system.
Creates an empty file. Creates an empty file.
''' """
try: try:
os.makedirs(os.path.dirname(filepath), exist_ok=True) os.makedirs(os.path.dirname(filepath), exist_ok=True)
open(filepath, 'a').close() open(filepath, "a").close()
return True return True
except BaseException: except BaseException:
pass pass
return False return False
def get_func_name(offset=0): def get_func_name(offset: int = 0):
""" """
Returns the name of the caller function name. Returns the name of the caller function name.
Offset counts the recursion - for example, offset=1 means the caller of the caller. Offset counts the recursion - for example, offset=1 means the caller of the caller.
""" """
return sys._getframe(1+offset).f_code.co_name return sys._getframe(1 + offset).f_code.co_name
class DummyLogger(): class DummyLogger:
''' """
DummyLogger does not do anything. DummyLogger does not do anything.
''' """
def __init__(self): def __init__(self):
pass pass
@ -59,26 +60,26 @@ class DummyLogger():
def debugging_off(self): def debugging_off(self):
pass pass
def info(self, msg=""): def info(self, msg: str = ""):
pass pass
def error(self, msg=""): def error(self, msg: str = ""):
pass pass
def warn(self, msg=""): def warn(self, msg: str = ""):
pass pass
def debug(self, msg="", verbosity=VERBOSITY_V): def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
pass pass
def critical(self, msg=""): def critical(self, msg: str = ""):
pass pass
class Logger(DummyLogger): class Logger(DummyLogger):
''' """
Logger with specific format Logger with specific format
''' """
def __init__( def __init__(
self, self,
@ -86,7 +87,7 @@ class Logger(DummyLogger):
filepath: str = "", filepath: str = "",
verbosity: int = VERBOSITY_V, verbosity: int = VERBOSITY_V,
stream_lvl=logging.INFO, stream_lvl=logging.INFO,
file_lvl=logging.DEBUG file_lvl=logging.DEBUG,
): ):
self.verbosity = verbosity self.verbosity = verbosity
self.logger = logging.getLogger(name) self.logger = logging.getLogger(name)
@ -129,53 +130,54 @@ class Logger(DummyLogger):
def debugging_on(self): def debugging_on(self):
self.__streamHandler.setLevel(logging.DEBUG) self.__streamHandler.setLevel(logging.DEBUG)
self.logger.info( self.logger.info(
"debug msg printing is on, verbosity: {}".format(self.verbosity)) "debug msg printing is on, verbosity: {}".format(self.verbosity)
)
def debugging_off(self): def debugging_off(self):
self.__streamHandler.setLevel(logging.INFO) self.__streamHandler.setLevel(logging.INFO)
self.logger.info("debug msg printing is off") self.logger.info("debug msg printing is off")
def info(self, msg="", verbosity: int = VERBOSITY_V): def info(self, msg="", verbosity: int = VERBOSITY_V):
''' """
Showing info message. Showing info message.
@param msg: the message @param msg: the message
@param verbosity: the higher the number is, the less important it is. @param verbosity: the higher the number is, the less important it is.
''' """
if verbosity > self.verbosity: if verbosity > self.verbosity:
return return
self.logger.info("[{}] {}".format(get_func_name(offset=1), msg)) self.logger.info("[{}] {}".format(get_func_name(offset=1), msg))
def warn(self, msg="", verbosity: int = VERBOSITY_V): def warn(self, msg="", verbosity: int = VERBOSITY_V):
''' """
Showing warning message. Showing warning message.
@param msg: the message @param msg: the message
@param verbosity: the higher the number is, the less important it is. @param verbosity: the higher the number is, the less important it is.
''' """
if verbosity > self.verbosity: if verbosity > self.verbosity:
return return
self.logger.warning("[{}] {}".format(get_func_name(offset=1), msg)) self.logger.warning("[{}] {}".format(get_func_name(offset=1), msg))
def debug(self, msg: str = "", verbosity: int = VERBOSITY_V): def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
''' """
Showing debug message. Showing debug message.
@param msg: the message @param msg: the message
@param verbosity: the higher the number is, the less important it is. @param verbosity: the higher the number is, the less important it is.
''' """
if verbosity > self.verbosity: if verbosity > self.verbosity:
return return
self.logger.debug("[{}] {}".format(get_func_name(offset=1), msg)) self.logger.debug("[{}] {}".format(get_func_name(offset=1), msg))
def error(self, msg=""): def error(self, msg=""):
''' """
Showing error message. Showing error message.
@param msg: the message @param msg: the message
''' """
self.logger.error("[{}] {}".format(get_func_name(offset=1), msg)) self.logger.error("[{}] {}".format(get_func_name(offset=1), msg))
self.logger.error(traceback.format_exc()) self.logger.error(traceback.format_exc())
def critical(self, msg=""): def critical(self, msg=""):
''' """
Showing critical message. Showing critical message.
@param msg: the message @param msg: the message
''' """
self.logger.critical("[{}] {}".format(get_func_name(offset=1), msg)) self.logger.critical("[{}] {}".format(get_func_name(offset=1), msg))

View File

@ -5,7 +5,6 @@ from utilities.logger import Logger
class TestLogger(unittest.TestCase): class TestLogger(unittest.TestCase):
@classmethod @classmethod
def setUpClass(self): def setUpClass(self):
self.logger = Logger() self.logger = Logger()
@ -31,5 +30,5 @@ class TestLogger(unittest.TestCase):
os.remove(self.log_filepath) os.remove(self.log_filepath)
if __name__ == '__main__': if __name__ == "__main__":
unittest.main() unittest.main()

View File

@ -3,15 +3,15 @@ import torch
def empty_memory(): def empty_memory():
''' """
Performs garbage collection and empty cache in cuda device. Performs garbage collection and empty cache in cuda device.
''' """
gc.collect() gc.collect()
torch.cuda.empty_cache() torch.cuda.empty_cache()
def tune_for_low_memory(): def tune_for_low_memory():
''' """
Tunes PyTorch to use float16 to reduce memory footprint. Tunes PyTorch to use float16 to reduce memory footprint.
''' """
torch.set_default_dtype(torch.float16) torch.set_default_dtype(torch.float16)

View File

@ -3,30 +3,39 @@ from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline from diffusers import StableDiffusionInpaintPipeline
from utilities.logger import Logger from utilities.logger import DummyLogger
from utilities.memory import empty_memory from utilities.memory import empty_memory
from utilities.memory import tune_for_low_memory from utilities.memory import tune_for_low_memory
class Model():
'''Model class.''' class Model:
def __init__(self, model_name: str, inpainting_model_name: str, logger: Logger, use_gpu: bool=True): """Model class."""
def __init__(
self,
model_name: str,
inpainting_model_name: str,
logger: DummyLogger = DummyLogger(),
use_gpu: bool = True,
):
self.model_name = model_name self.model_name = model_name
self.inpainting_model_name = inpainting_model_name self.inpainting_model_name = inpainting_model_name
self.__use_gpu = False self.__use_gpu = False
if use_gpu and torch.cuda.is_available(): if use_gpu and torch.cuda.is_available():
self.__use_gpu = True self.__use_gpu = True
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0"))) logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
self.logger = logger self.__logger = logger
self.__torch_dtype = "auto"
self.sd_pipeline = None self.sd_pipeline = None
self.img2img_pipeline = None self.img2img_pipeline = None
self.inpaint_pipeline = None self.inpaint_pipeline = None
self.__torch_dtype = "auto"
def use_gpu(self): def use_gpu(self):
return self.__use_gpu return self.__use_gpu
def reduce_memory(self): def reduce_memory(self):
self.logger.info("reduces memory usage by using float16 dtype") self.__logger.info("reduces memory usage by using float16 dtype")
tune_for_low_memory() tune_for_low_memory()
self.__torch_dtype = torch.float16 self.__torch_dtype = torch.float16
@ -41,21 +50,25 @@ class Model():
model_name, model_name,
revision=revision, revision=revision,
torch_dtype=self.__torch_dtype, torch_dtype=self.__torch_dtype,
safety_checker=None) safety_checker=None,
)
except: except:
try: try:
sd_pipeline = StableDiffusionPipeline.from_pretrained( sd_pipeline = StableDiffusionPipeline.from_pretrained(
self.model_name, self.model_name,
torch_dtype=self.__torch_dtype, torch_dtype=self.__torch_dtype,
safety_checker=None) safety_checker=None,
)
except Exception as e: except Exception as e:
self.logger.error("failed to load model %s: %s" % (self.model_name, e)) self.__logger.error(
"failed to load model %s: %s" % (self.model_name, e)
)
if sd_pipeline and self.use_gpu(): if sd_pipeline and self.use_gpu():
sd_pipeline.to("cuda") sd_pipeline.to("cuda")
self.sd_pipeline = sd_pipeline self.sd_pipeline = sd_pipeline
self.sd_pipeline_scheduler = sd_pipeline.scheduler self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
**sd_pipeline.components
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**sd_pipeline.components) )
if self.inpainting_model_name: if self.inpainting_model_name:
revision = get_revision_from_model_name(self.inpainting_model_name) revision = get_revision_from_model_name(self.inpainting_model_name)
@ -65,15 +78,20 @@ class Model():
model_name, model_name,
revision=revision, revision=revision,
torch_dtype=self.__torch_dtype, torch_dtype=self.__torch_dtype,
safety_checker=None) safety_checker=None,
)
except: except:
try: try:
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained( inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
self.inpainting_model_name, self.inpainting_model_name,
torch_dtype=self.__torch_dtype, torch_dtype=self.__torch_dtype,
safety_checker=None) safety_checker=None,
)
except Exception as e: except Exception as e:
self.logger.error("failed to load inpaint model %s: %s" % (self.inpainting_model_name, e)) self.__logger.error(
"failed to load inpaint model %s: %s"
% (self.inpainting_model_name, e)
)
if inpaint_pipeline and self.use_gpu(): if inpaint_pipeline and self.use_gpu():
inpaint_pipeline.to("cuda") inpaint_pipeline.to("cuda")
self.inpaint_pipeline = inpaint_pipeline self.inpaint_pipeline = inpaint_pipeline
@ -81,4 +99,8 @@ class Model():
def get_revision_from_model_name(model_name: str): def get_revision_from_model_name(model_name: str):
return "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16" return (
"diffusers-115k"
if model_name == "naclbit/trinart_stable_diffusion_v2"
else "fp16"
)

20
utilities/text2img.py Normal file
View File

@ -0,0 +1,20 @@
from utilities.config import Config
from utilities.model import Model
class Text2Img:
"""
Text2Img class.
"""
def __init__(self, model: Model, config: Config):
self.model = model
self.config = config
def update_config(config: Config):
self.config = config
def update_model(model, Model):
self.model = model