adds config and text2img barebone, python formats
parent e40a7c02f6
commit d4e00c3377
BUILD (12 changed lines)
@@ -1,13 +1,15 @@
load("@rules_python//python:defs.bzl", "py_binary")
|
||||
load("@subpar//:subpar.bzl", "par_binary")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
package(default_visibility=["//visibility:public"])
|
||||
|
||||
par_binary(
|
||||
name = 'main',
|
||||
srcs = ["main.py"],
|
||||
deps = [
|
||||
name="main",
|
||||
srcs=["main.py"],
|
||||
deps=[
|
||||
"//utilities:constants",
|
||||
"//utilities:logger",
|
||||
"//utilities:memory",
|
||||
"//utilities:model",
|
||||
"//utilities:text2img",
|
||||
],
|
||||
)
|
||||
|
|
|
|||
main.py (28 changed lines)
@@ -1,20 +1,38 @@
from utilities.model import Model
from utilities.constants import LOGGER_NAME
from utilities.logger import Logger
from utilities.model import Model
from utilities.config import Config
from utilities.text2img import Text2Img


def prepare(logger: Logger):
def prepare(logger: Logger) -> [Model, Config]:
    # model candidates:
    # "runwayml/stable-diffusion-v1-5"
    # "CompVis/stable-diffusion-v1-4"
    # "stabilityai/stable-diffusion-2-1"
    # "SG161222/Realistic_Vision_V2.0"
    # "darkstorm2150/Protogen_x3.4_Official_Release"
    # "prompthero/openjourney"
    # "naclbit/trinart_stable_diffusion_v2"
    # "hakurei/waifu-diffusion"
    model_name = "darkstorm2150/Protogen_x3.4_Official_Release"
    # inpainting model candidates:
    # "runwayml/stable-diffusion-inpainting"
    inpainting_model_name = "runwayml/stable-diffusion-inpainting"

    model = Model(model_name, inpainting_model_name, logger)
    model.reduce_memory()
    model.load()
    return model

    config = Config()
    return model, config


def main():
    logger = Logger(name="rl_trader")
    logger = Logger(name=LOGGER_NAME)

    model = prepare(logger)
    model, config = prepare(logger)
    text2img = Text2Img(model, config)

    input("confirm...")
utilities/BUILD
@@ -1,28 +1,51 @@
load("@rules_python//python:defs.bzl", "py_library", "py_test")
|
||||
|
||||
package(default_visibility = ["//visibility:public"])
|
||||
package(default_visibility=["//visibility:public"])
|
||||
|
||||
py_library(
|
||||
name = "memory",
|
||||
srcs = ["memory.py"],
|
||||
name="memory",
|
||||
srcs=["memory.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "model",
|
||||
srcs = ["model.py"],
|
||||
deps = [
|
||||
name="text2img",
|
||||
srcs=["text2img.py"],
|
||||
deps=[
|
||||
":config",
|
||||
":model",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="config",
|
||||
srcs=["config.py"],
|
||||
deps=[
|
||||
":logger",
|
||||
":constants",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="constants",
|
||||
srcs=["constants.py"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name="model",
|
||||
srcs=["model.py"],
|
||||
deps=[
|
||||
":memory",
|
||||
":logger",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "logger",
|
||||
srcs = ["logger.py"],
|
||||
name="logger",
|
||||
srcs=["logger.py"],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "logger_test",
|
||||
srcs = ["logger_test.py"],
|
||||
deps = [":logger"],
|
||||
name="logger_test",
|
||||
srcs=["logger_test.py"],
|
||||
deps=[":logger"],
|
||||
)
utilities/config.py (new file)
@@ -0,0 +1,93 @@
import random
import time

from utilities.constants import KEY_GUIDANCE_SCALE
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
from utilities.constants import KEY_HEIGHT
from utilities.constants import VALUE_HEIGHT_DEFAULT
from utilities.constants import KEY_PREVIEW
from utilities.constants import VALUE_PREVIEW_DEFAULT
from utilities.constants import KEY_SCHEDULER
from utilities.constants import VALUE_SCHEDULER_DEFAULT
from utilities.constants import KEY_SEED
from utilities.constants import VALUE_SEED_DEFAULT
from utilities.constants import KEY_STEPS
from utilities.constants import VALUE_STEPS_DEFAULT
from utilities.constants import KEY_WIDTH
from utilities.constants import VALUE_WIDTH_DEFAULT
from utilities.logger import DummyLogger


class Config:
    """
    Configuration.
    """

    def __init__(self, logger: DummyLogger = DummyLogger()):
        self.__logger = logger
        self.__config = {}

    def get_config(self) -> dict:
        return self.__config

    def get_guidance_scale(self) -> float:
        return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)

    def set_guidance_scale(self, scale: float):
        self.__logger.info("{} changed from {} to {}".format(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale))
        self.__config[KEY_GUIDANCE_SCALE] = scale

    def get_height(self) -> int:
        return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)

    def set_height(self, value: int):
        self.__logger.info("{} changed from {} to {}".format(KEY_HEIGHT, self.get_height(), value))
        self.__config[KEY_HEIGHT] = value

    def get_preview(self) -> bool:
        return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)

    def set_preview(self, boolean: bool):
        self.__logger.info("{} changed from {} to {}".format(KEY_PREVIEW, self.get_preview(), boolean))
        self.__config[KEY_PREVIEW] = boolean

    def get_scheduler(self) -> str:
        return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)

    def set_scheduler(self, scheduler: str):
        # choices:
        # "Default"
        # "DPMSolverMultistepScheduler"
        # "LMSDiscreteScheduler"
        # "EulerDiscreteScheduler"
        # "PNDMScheduler"
        # "DDIMScheduler"
        if not scheduler:
            scheduler = "Default"
        self.__logger.info("{} changed from {} to {}".format(KEY_SCHEDULER, self.get_scheduler(), scheduler))
        self.__config[KEY_SCHEDULER] = scheduler

    def get_seed(self) -> int:
        seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
        if seed == 0:
            random.seed(int(time.time_ns()))
            seed = random.getrandbits(64)
        return seed

    def set_seed(self, seed: int):
        self.__logger.info("{} changed from {} to {}".format(KEY_SEED, self.get_seed(), seed))
        self.__config[KEY_SEED] = seed

    def get_steps(self) -> int:
        return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)

    def set_steps(self, steps: int):
        self.__logger.info("{} changed from {} to {}".format(KEY_STEPS, self.get_steps(), steps))
        self.__config[KEY_STEPS] = steps

    def get_width(self) -> int:
        return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)

    def set_width(self, value: int):
        self.__logger.info("{} changed from {} to {}".format(KEY_WIDTH, self.get_width(), value))
        self.__config[KEY_WIDTH] = value
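A minimal usage sketch for the new Config class (not part of the commit; the logger name is made up, and the defaults referenced in the comments come from utilities/constants.py below):

from utilities.config import Config
from utilities.logger import Logger

config = Config(Logger(name="config_demo"))  # DummyLogger() is the silent default
config.set_width(768)                        # logged as "WIDTH changed from 512 to 768"
config.set_steps(50)
print(config.get_guidance_scale())           # 15.0, falling back to VALUE_GUIDANCE_SCALE_DEFAULT
print(config.get_seed())                     # seed stays 0 by default, so a fresh 64-bit value each call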
utilities/constants.py (new file)
@@ -0,0 +1,22 @@
LOGGER_NAME = "main"

KEY_SEED = "SEED"
VALUE_SEED_DEFAULT = 0

KEY_WIDTH = "WIDTH"
VALUE_WIDTH_DEFAULT = 512

KEY_HEIGHT = "HEIGHT"
VALUE_HEIGHT_DEFAULT = 512

KEY_GUIDANCE_SCALE = "GUIDANCE_SCALE"
VALUE_GUIDANCE_SCALE_DEFAULT = 15.0

KEY_STEPS = "STEPS"
VALUE_STEPS_DEFAULT = 100

KEY_SCHEDULER = "SCHEDULER"
VALUE_SCHEDULER_DEFAULT = "Default"

KEY_PREVIEW = "PREVIEW"
VALUE_PREVIEW_DEFAULT = True
utilities/logger.py
@@ -6,7 +6,8 @@ from colorlog import ColoredFormatter

FORMATTER_STREAM = ColoredFormatter(
    "%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s")
    "%(asctime)s %(log_color)s%(levelname)-.1s[%(name)s] %(message)s%(reset)s"
)
FORMATTER_FILE = logging.Formatter("%(asctime)s %(levelname)-.1s. %(message)s")

VERBOSITY_V = 10

@@ -14,32 +15,32 @@ VERBOSITY_VV = 100
VERBOSITY_VVV = 1000


def touch(filepath):
    '''
def touch(filepath: str):
    """
    Behaves similarly as `touch` command in Linux system.
    Creates an empty file.
    '''
    """
    try:
        os.makedirs(os.path.dirname(filepath), exist_ok=True)
        open(filepath, 'a').close()
        open(filepath, "a").close()
        return True
    except BaseException:
        pass
    return False


def get_func_name(offset=0):
def get_func_name(offset: int = 0):
    """
    Returns the name of the caller function name.
    Offset counts the recursion - for example, offset=1 means the caller of the caller.
    """
    return sys._getframe(1+offset).f_code.co_name
    return sys._getframe(1 + offset).f_code.co_name


class DummyLogger():
    '''
class DummyLogger:
    """
    DummyLogger does not do anything.
    '''
    """

    def __init__(self):
        pass

@@ -59,26 +60,26 @@ class DummyLogger():
    def debugging_off(self):
        pass

    def info(self, msg=""):
    def info(self, msg: str = ""):
        pass

    def error(self, msg=""):
    def error(self, msg: str = ""):
        pass

    def warn(self, msg=""):
    def warn(self, msg: str = ""):
        pass

    def debug(self, msg="", verbosity=VERBOSITY_V):
    def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
        pass

    def critical(self, msg=""):
    def critical(self, msg: str = ""):
        pass


class Logger(DummyLogger):
    '''
    """
    Logger with specific format
    '''
    """

    def __init__(
        self,

@@ -86,7 +87,7 @@ class Logger(DummyLogger):
        filepath: str = "",
        verbosity: int = VERBOSITY_V,
        stream_lvl=logging.INFO,
        file_lvl=logging.DEBUG
        file_lvl=logging.DEBUG,
    ):
        self.verbosity = verbosity
        self.logger = logging.getLogger(name)

@@ -129,53 +130,54 @@ class Logger(DummyLogger):
    def debugging_on(self):
        self.__streamHandler.setLevel(logging.DEBUG)
        self.logger.info(
            "debug msg printing is on, verbosity: {}".format(self.verbosity))
            "debug msg printing is on, verbosity: {}".format(self.verbosity)
        )

    def debugging_off(self):
        self.__streamHandler.setLevel(logging.INFO)
        self.logger.info("debug msg printing is off")

    def info(self, msg="", verbosity: int = VERBOSITY_V):
        '''
        """
        Showing info message.
        @param msg: the message
        @param verbosity: the higher the number is, the less important it is.
        '''
        """
        if verbosity > self.verbosity:
            return
        self.logger.info("[{}] {}".format(get_func_name(offset=1), msg))

    def warn(self, msg="", verbosity: int = VERBOSITY_V):
        '''
        """
        Showing warning message.
        @param msg: the message
        @param verbosity: the higher the number is, the less important it is.
        '''
        """
        if verbosity > self.verbosity:
            return
        self.logger.warning("[{}] {}".format(get_func_name(offset=1), msg))

    def debug(self, msg: str = "", verbosity: int = VERBOSITY_V):
        '''
        """
        Showing debug message.
        @param msg: the message
        @param verbosity: the higher the number is, the less important it is.
        '''
        """
        if verbosity > self.verbosity:
            return
        self.logger.debug("[{}] {}".format(get_func_name(offset=1), msg))

    def error(self, msg=""):
        '''
        """
        Showing error message.
        @param msg: the message
        '''
        """
        self.logger.error("[{}] {}".format(get_func_name(offset=1), msg))
        self.logger.error(traceback.format_exc())

    def critical(self, msg=""):
        '''
        """
        Showing critical message.
        @param msg: the message
        '''
        """
        self.logger.critical("[{}] {}".format(get_func_name(offset=1), msg))
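A short usage sketch of the verbosity filtering (not part of the diff; the logger name is illustrative):

from utilities.logger import Logger, VERBOSITY_V, VERBOSITY_VV

logger = Logger(name="demo", verbosity=VERBOSITY_V)
logger.info("shown: the default verbosity VERBOSITY_V passes the filter")
logger.info("dropped: VERBOSITY_VV exceeds the logger's verbosity", verbosity=VERBOSITY_VV)
logger.debug("hidden on the console until debugging_on() lowers the stream handler to DEBUG")
logger.debugging_on()
logger.debug("now printed, tagged with the calling frame's name via get_func_name()")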
utilities/logger_test.py
@@ -5,7 +5,6 @@ from utilities.logger import Logger

class TestLogger(unittest.TestCase):

    @classmethod
    def setUpClass(self):
        self.logger = Logger()

@@ -31,5 +30,5 @@ class TestLogger(unittest.TestCase):
        os.remove(self.log_filepath)


if __name__ == '__main__':
if __name__ == "__main__":
    unittest.main()
utilities/memory.py
@@ -3,15 +3,15 @@ import torch

def empty_memory():
    '''
    """
    Performs garbage collection and empty cache in cuda device.
    '''
    """
    gc.collect()
    torch.cuda.empty_cache()


def tune_for_low_memory():
    '''
    """
    Tunes PyTorch to use float16 to reduce memory footprint.
    '''
    """
    torch.set_default_dtype(torch.float16)
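A sketch of the calling pattern these helpers seem intended for, bracketing pipeline construction; this pattern is an inference, not stated in the commit:

import torch

from utilities.memory import empty_memory, tune_for_low_memory

tune_for_low_memory()      # switch the default dtype to float16 before pipelines are built
# ... construct and run pipelines here ...
if torch.cuda.is_available():
    empty_memory()         # gc.collect() plus torch.cuda.empty_cache()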
utilities/model.py
@@ -3,30 +3,39 @@ from diffusers import StableDiffusionPipeline
from diffusers import StableDiffusionImg2ImgPipeline
from diffusers import StableDiffusionInpaintPipeline

from utilities.logger import Logger
from utilities.logger import DummyLogger
from utilities.memory import empty_memory
from utilities.memory import tune_for_low_memory


class Model():
    '''Model class.'''
    def __init__(self, model_name: str, inpainting_model_name: str, logger: Logger, use_gpu: bool=True):

class Model:
    """Model class."""

    def __init__(
        self,
        model_name: str,
        inpainting_model_name: str,
        logger: DummyLogger = DummyLogger(),
        use_gpu: bool = True,
    ):
        self.model_name = model_name
        self.inpainting_model_name = inpainting_model_name
        self.__use_gpu = False
        if use_gpu and torch.cuda.is_available():
            self.__use_gpu = True
            logger.info("running on {}".format(torch.cuda.get_device_name("cuda:0")))
        self.logger = logger
        self.__logger = logger
        self.__torch_dtype = "auto"

        self.sd_pipeline = None
        self.img2img_pipeline = None
        self.inpaint_pipeline = None
        self.__torch_dtype = "auto"

    def use_gpu(self):
        return self.__use_gpu

    def reduce_memory(self):
        self.logger.info("reduces memory usage by using float16 dtype")
        self.__logger.info("reduces memory usage by using float16 dtype")
        tune_for_low_memory()
        self.__torch_dtype = torch.float16

@@ -41,21 +50,25 @@ class Model():
                model_name,
                revision=revision,
                torch_dtype=self.__torch_dtype,
                safety_checker=None)
                safety_checker=None,
            )
        except:
            try:
                sd_pipeline = StableDiffusionPipeline.from_pretrained(
                    self.model_name,
                    torch_dtype=self.__torch_dtype,
                    safety_checker=None)
                    safety_checker=None,
                )
            except Exception as e:
                self.logger.error("failed to load model %s: %s" % (self.model_name, e))
                self.__logger.error(
                    "failed to load model %s: %s" % (self.model_name, e)
                )
        if sd_pipeline and self.use_gpu():
            sd_pipeline.to("cuda")
        self.sd_pipeline = sd_pipeline
        self.sd_pipeline_scheduler = sd_pipeline.scheduler

        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(**sd_pipeline.components)
        self.img2img_pipeline = StableDiffusionImg2ImgPipeline(
            **sd_pipeline.components
        )

        if self.inpainting_model_name:
            revision = get_revision_from_model_name(self.inpainting_model_name)

@@ -65,15 +78,20 @@ class Model():
                    model_name,
                    revision=revision,
                    torch_dtype=self.__torch_dtype,
                    safety_checker=None)
                    safety_checker=None,
                )
            except:
                try:
                    inpaint_pipeline = StableDiffusionInpaintPipeline.from_pretrained(
                        self.inpainting_model_name,
                        torch_dtype=self.__torch_dtype,
                        safety_checker=None)
                        safety_checker=None,
                    )
                except Exception as e:
                    self.logger.error("failed to load inpaint model %s: %s" % (self.inpainting_model_name, e))
                    self.__logger.error(
                        "failed to load inpaint model %s: %s"
                        % (self.inpainting_model_name, e)
                    )
            if inpaint_pipeline and self.use_gpu():
                inpaint_pipeline.to("cuda")
            self.inpaint_pipeline = inpaint_pipeline

@@ -81,4 +99,8 @@ class Model():


def get_revision_from_model_name(model_name: str):
    return "diffusers-115k" if model_name == "naclbit/trinart_stable_diffusion_v2" else "fp16"
    return (
        "diffusers-115k"
        if model_name == "naclbit/trinart_stable_diffusion_v2"
        else "fp16"
    )
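For context, a hedged sketch of how the saved sd_pipeline_scheduler and the scheduler names listed in Config.set_scheduler() might map onto diffusers scheduler classes; the mapping, the logger name, and the model choice are assumptions, not part of this commit, and loading downloads the weights on first use:

from diffusers import DPMSolverMultistepScheduler

from utilities.logger import Logger
from utilities.model import Model

model = Model(
    "darkstorm2150/Protogen_x3.4_Official_Release",  # same default as main.py
    "",                                              # empty name skips the inpainting pipeline
    Logger(name="model_demo"),
)
model.reduce_memory()  # float16 weights from this point on
model.load()

# One way Config's "DPMSolverMultistepScheduler" choice could be applied:
# rebuild the scheduler from the pipeline's own config, keeping the original
# in sd_pipeline_scheduler as the "Default" fallback.
model.sd_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
    model.sd_pipeline_scheduler.config
)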
utilities/text2img.py (new file)
@@ -0,0 +1,20 @@
from utilities.config import Config
from utilities.model import Model


class Text2Img:
    """
    Text2Img class.
    """

    def __init__(self, model: Model, config: Config):
        self.model = model
        self.config = config

    def update_config(self, config: Config):
        self.config = config

    def update_model(self, model: Model):
        self.model = model