adds config and text2img barebone, python formats

This commit is contained in:
HappyZ 2023-04-26 23:07:24 -07:00
parent e40a7c02f6
commit d4e00c3377
10 changed files with 280 additions and 79 deletions

12
BUILD
View File

@ -1,13 +1,15 @@
load("@rules_python//python:defs.bzl", "py_binary")
load("@subpar//:subpar.bzl", "par_binary")
package(default_visibility = ["//visibility:public"])
package(default_visibility=["//visibility:public"])
par_binary(
name = 'main',
srcs = ["main.py"],
deps = [
name="main",
srcs=["main.py"],
deps=[
"//utilities:constants",
"//utilities:logger",
"//utilities:memory",
"//utilities:model",
"//utilities:text2img",
],
)

28
main.py
View File

@ -1,20 +1,38 @@
from utilities.model import Model
from utilities.constants import LOGGER_NAME
from utilities.logger import Logger
from utilities.model import Model
from utilities.config import Config
from utilities.text2img import Text2Img
def prepare(logger: Logger):
def prepare(logger: Logger) -> [Model, Config]:
# model candidates:
# "runwayml/stable-diffusion-v1-5"
# "CompVis/stable-diffusion-v1-4"
# "stabilityai/stable-diffusion-2-1"
# "SG161222/Realistic_Vision_V2.0"
# "darkstorm2150/Protogen_x3.4_Official_Release"
# "prompthero/openjourney"
# "naclbit/trinart_stable_diffusion_v2"
# "hakurei/waifu-diffusion"
model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
# inpainting model candidates:
# "runwayml/stable-diffusion-inpainting"
inpainting_model_name = "runwayml/stable-diffusion-inpainting"
model = Model(model_name, inpainting_model_name, logger)
model.reduce_memory()
model.load()
return model
config = Config()
return model, config
def main():
logger = Logger(name="rl_trader")
logger = Logger(name=LOGGER_NAME)
model = prepare(logger)
model, config = prepare(logger)
text2img = Text2Img(model, config)
input("confirm...")

View File

@ -1,28 +1,51 @@
load("@rules_python//python:defs.bzl", "py_library", "py_test")
package(default_visibility = ["//visibility:public"])
package(default_visibility=["//visibility:public"])
py_library(
name = "memory",
srcs = ["memory.py"],
name="memory",
srcs=["memory.py"],
)
py_library(
name = "model",
srcs = ["model.py"],
deps = [
name="text2img",
srcs=["text2img.py"],
deps=[
":config",
":model",
],
)
py_library(
name="config",
srcs=["config.py"],
deps=[
":logger",
":constants",
],
)
py_library(
name="constants",
srcs=["constants.py"],
)
py_library(
name="model",
srcs=["model.py"],
deps=[
":memory",
":logger",
],
)
py_library(
name = "logger",
srcs = ["logger.py"],
name="logger",
srcs=["logger.py"],
)
py_test(
name = "logger_test",
srcs = ["logger_test.py"],
deps = [":logger"],
name="logger_test",
srcs=["logger_test.py"],
deps=[":logger"],
)

93
utilities/config.py Normal file
View File

@ -0,0 +1,93 @@
import random
import time
from utilities.constants import KEY_GUIDANCE_SCALE
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
from utilities.constants import KEY_HEIGHT
from utilities.constants import VALUE_HEIGHT_DEFAULT
from utilities.constants import KEY_PREVIEW
from utilities.constants import VALUE_PREVIEW_DEFAULT
from utilities.constants import KEY_SCHEDULER
from utilities.constants import VALUE_SCHEDULER_DEFAULT
from utilities.constants import KEY_SEED
from utilities.constants import VALUE_SEED_DEFAULT
from utilities.constants import KEY_STEPS
from utilities.constants import VALUE_STEPS_DEFAULT
from utilities.constants import KEY_WIDTH
from utilities.constants import VALUE_WIDTH_DEFAULT
from utilities.logger import DummyLogger
class Config:
"""
Configuration.
"""
def __init__(self, logger: DummyLogger = DummyLogger()):
self.__logger = logger
self.__config = {}
def get_config(self) -> dict:
return self.__config
def get_guidance_scale(self) -> float:
return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
def set_guidance_scale(self, scale: float):
self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
self.__config[KEY_GUIDANCE_SCALE] = scale
def get_height(self) -> int:
return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)
def set_height(self, value: int):
self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
self.__config[KEY_HEIGHT] = value
def get_preview(self) -> bool:
return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)
def set_preview(self, boolean: bool):
self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
self.__config[KEY_PREVIEW] = boolean
def get_scheduler(self) -> str:
return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)
def set_scheduler(self, scheduler: str):
# choices:
# "Default"
# "DPMSolverMultistepScheduler"
# "LMSDiscreteScheduler"
# "EulerDiscreteScheduler"
# "PNDMScheduler"
# "DDIMScheduler"
if not scheduler:
scheduler = "Default"
self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
self.__config[KEY_SCHEDULER] = scheduler
def get_seed(self) -> int:
seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
if seed == 0:
random.seed(int(time.time_ns()))
seed = random.getrandbits(64)
return seed
def set_seed(self, seed: int):
self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
self.__config[KEY_SEED] = seed
def get_steps(self) -> int:
return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)
def set_steps(self, steps: int):
self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
self.__config[KEY_STEPS] = steps
def get_width(self) -> int:
return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)
def set_width(self, value: int):
self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
self.__config[KEY_WIDTH] = value

22
utilities/constants.py Normal file
View File

@ -0,0 +1,22 @@
LOGGER_NAME = "main"
KEY_SEED = "SEED"
VALUE_SEED_DEFAULT = 0
KEY_WIDTH = "WIDTH"
VALUE_WIDTH_DEFAULT = 512
KEY_HEIGHT = "HEIGHT"
VALUE_HEIGHT_DEFAULT = 512
KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0
KEY_STEPS = "STEPS"
VALUE_STEPS_DEFAULT = 100
KEY_SCHEDULER = "SCHEDULER"
VALUE_SCHEDULER_DEFAULT = "Default"
KEY_PREVIEW = "PREVIEW"
VALUE_PREVIEW_DEFAULT = True

View File

@ -6,7 +6,8 @@ from colorlog import ColoredFormatter
FORMATTER_STREAM = ColoredFormatter(
"%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s")
"%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s"
)
FORMATTER_FILE = logging.Formatter("%(asctime)s %(levelname)-.1s. %(message)s")
VERBOSITY_V = 10
@ -14,32 +15,32 @@ VERBOSITY_VV = 100
VERBOSITY_VVV = 1000
def touch(filepath):
'''
def touch(filepath: str):
"""
Behaves similarly as `touch` command in Linux system.
Creates an empty file.
'''
"""
try:
os.makedirs(os.path.dirname(filepath), exist_ok=True)
open(filepath, 'a').close()
open(filepath, "a").close()
return True
except BaseException:
pass
return False
def get_func_name(offset=0):
def get_func_name(offset: int = 0):
"""
Returns the name of the caller function name.
Offset counts the recursion - for example, offset=1 means the caller of the caller.
"""
return sys._getframe(1+offset).f_code.co_name
return sys._getframe(1 + offset).f_code.co_name
class DummyLogger():
'''
class DummyLogger:
"""
DummyLogger does not do anything.
'''
"""
def __init__(self):
pass
@ -59,26 +60,26 @@ class DummyLogger():
def debugging_off(self):
pass
def info(self, msg=""):
def info(self, msg: str = ""):
pass
def error(self, msg=""):
def error(self, msg: str = ""):
pass
def warn(self, msg=""):
def warn(self, msg: str = ""):
pass
def debug(self, msg="", verbosity=VERBOSITY_V):
def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
pass
def critical(self, msg=""):
def critical(self, msg: str = ""):
pass
class Logger(DummyLogger):
'''
"""
Logger with specific format
'''
"""
def __init__(
self,
@ -86,7 +87,7 @@ class Logger(DummyLogger):
filepath: str = "",
verbosity: int = VERBOSITY_V,
stream_lvl=logging.INFO,
file_lvl=logging.DEBUG
file_lvl=logging.DEBUG,
):
self.verbosity = verbosity
self.logger = logging.getLogger(name)
@ -129,53 +130,54 @@ class Logger(DummyLogger):
def debugging_on(self):
self.__streamHandler.setLevel(logging.DEBUG)
self.logger.info(
"debug msg printing is on, verbosity: {}".format(self.verbosity))
"debug msg printing is on, verbosity: {}".format(self.verbosity)
)
def debugging_off(self):
self.__streamHandler.setLevel(logging.INFO)
self.logger.info("debug msg printing is off")
def info(self, msg="", verbosity: int = VERBOSITY_V):
'''
"""
Showing info message.
@param msg: the message
@param verbosity: the higher the number is, the less important it is.
'''
"""
if verbosity > self.verbosity:
return
self.logger.info("[{}] {}".format(get_func_name(offset=1), msg))
def warn(self, msg="", verbosity: int = VERBOSITY_V):
'''
"""
Showing warning message.
@param msg: the message
@param verbosity: the higher the number is, the less important it is.
'''
"""
if verbosity > self.verbosity:
return
self.logger.warning("[{}] {}".format(get_func_name(offset=1), msg))
def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
'''
"""
Showing debug message.
@param msg: the message
@param verbosity: the higher the number is, the less important it is.
'''
"""
if verbosity > self.verbosity:
return
self.logger.debug("[{}] {}".format(get_func_name(offset=1), msg))
def error(self, msg=""):
'''
"""
Showing error message.
@param msg: the message
'''
"""
self.logger.error("[{}] {}".format(get_func_name(offset=1), msg))
self.logger.error(traceback.format_exc())
def critical(self, msg=""):
'''
"""
Showing critical message.
@param msg: the message
'''
"""
self.logger.critical("[{}] {}".format(get_func_name(offset=1), msg))

View File

@ -5,7 +5,6 @@ from utilities.logger import Logger
class TestLogger(unittest.TestCase):
@classmethod
def setUpClass(self):
self.logger = Logger()
@ -31,5 +30,5 @@ class TestLogger(unittest.TestCase):
os.remove(self.log_filepath)
if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()

View File

@ -3,15 +3,15 @@ import torch
def empty_memory():
'''
"""
Performs garbage collection and empty cache in cuda device.
'''
"""
gc.collect()
torch.cuda.empty_cache()
def tune_for_low_memory():
'''
"""
Tunes PyTorch to use float16 to reduce memory footprint.
'''
"""
torch.set_default_dtype(torch.float16)

View File

@ -3,30 +3,39 @@ from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline
from utilities.logger import Logger
from utilities.logger import DummyLogger
from utilities.memory import empty_memory
from utilities.memory import tune_for_low_memory
class Model():
'''Model class.'''
def __init__(self, model_name: str, inpainting_model_name: str, logger: Logger, use_gpu: bool=True):
class Model:
"""Model class."""
def __init__(
self,
model_name: str,
inpainting_model_name: str,
logger: DummyLogger = DummyLogger(),
use_gpu: bool = True,
):
self.model_name = model_name
self.inpainting_model_name = inpainting_model_name
self.__use_gpu = False
if use_gpu and torch.cuda.is_available():
self.__use_gpu = True
logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
self.logger = logger
self.__logger = logger
self.__torch_dtype = "auto"
self.sd_pipeline = None
self.img2img_pipeline = None
self.inpaint_pipeline = None
self.__torch_dtype = "auto"
def use_gpu(self):
return self.__use_gpu
def reduce_memory(self):
self.logger.info("reduces memory usage by using float16 dtype")
self.__logger.info("reduces memory usage by using float16 dtype")
tune_for_low_memory()
self.__torch_dtype = torch.float16
@ -41,21 +50,25 @@ class Model():
model_name,
revision=revision,
torch_dtype=self.__torch_dtype,
safety_checker=None)
safety_checker=None,
)
except:
try:
sd_pipeline = StableDiffusionPipeline.from_pretrained(
self.model_name,
torch_dtype=self.__torch_dtype,
safety_checker=None)
safety_checker=None,
)
except Exception as e:
self.logger.error("failed to load model %s: %s" % (self.model_name, e))
self.__logger.error(
"failed to load model %s: %s" % (self.model_name, e)
)
if sd_pipeline and self.use_gpu():
sd_pipeline.to("cuda")
self.sd_pipeline = sd_pipeline
self.sd_pipeline_scheduler = sd_pipeline.scheduler
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**sd_pipeline.components)
self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
**sd_pipeline.components
)
if self.inpainting_model_name:
revision = get_revision_from_model_name(self.inpainting_model_name)
@ -65,15 +78,20 @@ class Model():
model_name,
revision=revision,
torch_dtype=self.__torch_dtype,
safety_checker=None)
safety_checker=None,
)
except:
try:
inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
self.inpainting_model_name,
torch_dtype=self.__torch_dtype,
safety_checker=None)
safety_checker=None,
)
except Exception as e:
self.logger.error("failed to load inpaint model %s: %s" % (self.inpainting_model_name, e))
self.__logger.error(
"failed to load inpaint model %s: %s"
% (self.inpainting_model_name, e)
)
if inpaint_pipeline and self.use_gpu():
inpaint_pipeline.to("cuda")
self.inpaint_pipeline = inpaint_pipeline
@ -81,4 +99,8 @@ class Model():
def get_revision_from_model_name(model_name: str):
return "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16"
return (
"diffusers-115k"
if model_name == "naclbit/trinart_stable_diffusion_v2"
else "fp16"
)

20
utilities/text2img.py Normal file
View File

@ -0,0 +1,20 @@
from utilities.config import Config
from utilities.model import Model
class Text2Img:
"""
Text2Img class.
"""
def __init__(self, model: Model, config: Config):
self.model = model
self.config = config
def update_config(config: Config):
self.config = config
def update_model(model, Model):
self.model = model