adds txt2img capability with interactive prompt input
This commit is contained in:
parent
41664e2682
commit
b0e07eecd4
17
main.py
17
main.py
|
|
@ -25,6 +25,7 @@ def prepare(logger: Logger) -> [Model, Config]:
|
||||||
model.load_all()
|
model.load_all()
|
||||||
|
|
||||||
config = Config()
|
config = Config()
|
||||||
|
config.set_output_folder("/tmp/")
|
||||||
return model, config
|
return model, config
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -34,7 +35,21 @@ def main():
|
||||||
model, config = prepare(logger)
|
model, config = prepare(logger)
|
||||||
text2img = Text2Img(model, config)
|
text2img = Text2Img(model, config)
|
||||||
|
|
||||||
input("confirm...")
|
text2img.breakfast()
|
||||||
|
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
prompt = input("Write prompt: ")
|
||||||
|
if not prompt:
|
||||||
|
prompt = "man riding a horse in space"
|
||||||
|
negative_prompt = input("Write negative prompt: ")
|
||||||
|
if not negative_prompt:
|
||||||
|
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
|
||||||
|
text2img.lunch(prompt=prompt, negative_prompt=negative_prompt)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
break
|
||||||
|
except BaseException:
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
accelerate
|
accelerate
|
||||||
colorlog
|
colorlog
|
||||||
diffusers
|
diffusers
|
||||||
|
numpy
|
||||||
|
Pillow
|
||||||
|
scikit-image
|
||||||
torch
|
torch
|
||||||
transformers
|
transformers
|
||||||
|
|
|
||||||
|
|
@ -2,20 +2,6 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test")
|
||||||
|
|
||||||
package(default_visibility=["//visibility:public"])
|
package(default_visibility=["//visibility:public"])
|
||||||
|
|
||||||
py_library(
|
|
||||||
name="memory",
|
|
||||||
srcs=["memory.py"],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
|
||||||
name="text2img",
|
|
||||||
srcs=["text2img.py"],
|
|
||||||
deps=[
|
|
||||||
":config",
|
|
||||||
":model",
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name="config",
|
name="config",
|
||||||
srcs=["config.py"],
|
srcs=["config.py"],
|
||||||
|
|
@ -31,13 +17,8 @@ py_library(
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
name="model",
|
name = "images",
|
||||||
srcs=["model.py"],
|
srcs = ["images.py"],
|
||||||
deps=[
|
|
||||||
":constants",
|
|
||||||
":memory",
|
|
||||||
":logger",
|
|
||||||
],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
py_library(
|
py_library(
|
||||||
|
|
@ -50,3 +31,42 @@ py_test(
|
||||||
srcs=["logger_test.py"],
|
srcs=["logger_test.py"],
|
||||||
deps=[":logger"],
|
deps=[":logger"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="memory",
|
||||||
|
srcs=["memory.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="model",
|
||||||
|
srcs=["model.py"],
|
||||||
|
deps=[
|
||||||
|
":constants",
|
||||||
|
":memory",
|
||||||
|
":logger",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="text2img",
|
||||||
|
srcs=["text2img.py"],
|
||||||
|
deps=[
|
||||||
|
":config",
|
||||||
|
":logger",
|
||||||
|
":images",
|
||||||
|
":memory",
|
||||||
|
":model",
|
||||||
|
":times",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_library(
|
||||||
|
name="times",
|
||||||
|
srcs=["times.py"],
|
||||||
|
)
|
||||||
|
|
||||||
|
py_test(
|
||||||
|
name="times_test",
|
||||||
|
srcs=["times_test.py"],
|
||||||
|
deps=[":times"],
|
||||||
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
import random
|
import random
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
from utilities.constants import KEY_OUTPUT_FOLDER
|
||||||
|
from utilities.constants import VALUE_OUTPUT_FOLDER_DEFAULT
|
||||||
from utilities.constants import KEY_GUIDANCE_SCALE
|
from utilities.constants import KEY_GUIDANCE_SCALE
|
||||||
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
|
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
|
||||||
from utilities.constants import KEY_HEIGHT
|
from utilities.constants import KEY_HEIGHT
|
||||||
|
|
@ -35,6 +37,13 @@ class Config:
|
||||||
def get_config(self) -> dict:
|
def get_config(self) -> dict:
|
||||||
return self.__config
|
return self.__config
|
||||||
|
|
||||||
|
def get_output_folder(self) -> str:
|
||||||
|
return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
|
||||||
|
|
||||||
|
def set_output_folder(self, folder:str):
|
||||||
|
self.__logger.info("{} changed from {} to {}".format(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder))
|
||||||
|
self.__config[KEY_OUTPUT_FOLDER] = folder
|
||||||
|
|
||||||
def get_guidance_scale(self) -> float:
|
def get_guidance_scale(self) -> float:
|
||||||
return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
|
return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,8 @@
|
||||||
LOGGER_NAME = "main"
|
LOGGER_NAME = "main"
|
||||||
|
|
||||||
|
KEY_OUTPUT_FOLDER = "OUTFOLDER"
|
||||||
|
VALUE_OUTPUT_FOLDER_DEFAULT = ""
|
||||||
|
|
||||||
KEY_SEED = "SEED"
|
KEY_SEED = "SEED"
|
||||||
VALUE_SEED_DEFAULT = 0
|
VALUE_SEED_DEFAULT = 0
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
import base64
|
||||||
|
import os
|
||||||
|
import io
|
||||||
|
from typing import Union
|
||||||
|
import numpy as np
|
||||||
|
from PIL import Image
|
||||||
|
|
||||||
|
|
||||||
|
def load_image(image: Union[str, bytes]) -> Union[Image.Image, None]:
|
||||||
|
if isinstance(image, bytes):
|
||||||
|
return Image.open(io.BytesIO(image))
|
||||||
|
elif os.path.isfile(image):
|
||||||
|
with Image.open(image) as im:
|
||||||
|
return Image.fromarray(np.asarray(im))
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def save_image(image: Union[bytes, Image.Image], filepath: str, override: bool = False) -> bool:
|
||||||
|
if os.path.isfile(filepath) and not override:
|
||||||
|
return False
|
||||||
|
try:
|
||||||
|
if isinstance(image, Image.Image):
|
||||||
|
# this is an Image
|
||||||
|
image.save(filepath)
|
||||||
|
else:
|
||||||
|
with open(filepath, "wb") as f:
|
||||||
|
f.write(image)
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def crop_image(image: Image.Image, boundary: tuple) -> Image.Image:
|
||||||
|
'''
|
||||||
|
Crop an image based on boundary defined in boundary tuple.
|
||||||
|
'''
|
||||||
|
return image.crop(boundary)
|
||||||
|
|
||||||
|
|
||||||
|
def image_to_base64(image: Union[bytes, str, Image.Image], image_format: str = "png") -> str:
|
||||||
|
if isinstance(image, str):
|
||||||
|
# this is a filepath
|
||||||
|
if not os.path.isfile(image):
|
||||||
|
return ""
|
||||||
|
with open(image, "rb") as f:
|
||||||
|
image = f.read()
|
||||||
|
elif isinstance(image, Image.Image):
|
||||||
|
# this is an image
|
||||||
|
rawbytes = io.BytesIO()
|
||||||
|
image.save(rawbytes, format=image_format)
|
||||||
|
image = rawbytes.getvalue()
|
||||||
|
return "data:image/{};base64,".format(image_format) + base64.b64encode(image).decode()
|
||||||
|
|
||||||
|
|
||||||
|
from skimage import io as skimageio
|
||||||
|
from skimage import transform
|
||||||
|
from skimage import img_as_ubyte
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_transform_image_for_torch(img_filepath: str, dimension: tuple = (), force_rgb: bool = True, transpose: bool = True, use_ubyte: bool = False) -> np.ndarray:
|
||||||
|
img = skimageio.imread(img_filepath)
|
||||||
|
if force_rgb:
|
||||||
|
img = img[:, :, :3]
|
||||||
|
if dimension:
|
||||||
|
img = transform.resize(img, dimension)
|
||||||
|
if transpose:
|
||||||
|
# swap color axis because
|
||||||
|
# numpy image: H x W x C
|
||||||
|
# torch image: C x H x W
|
||||||
|
img = img.transpose((2, 0, 1))
|
||||||
|
if use_ubyte:
|
||||||
|
img = img_as_ubyte(img)
|
||||||
|
return np.array(img)
|
||||||
|
|
@ -1,5 +1,12 @@
|
||||||
|
import torch
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
from utilities.config import Config
|
from utilities.config import Config
|
||||||
|
from utilities.images import save_image
|
||||||
|
from utilities.logger import DummyLogger
|
||||||
|
from utilities.memory import empty_memory_cache
|
||||||
from utilities.model import Model
|
from utilities.model import Model
|
||||||
|
from utilities.times import get_epoch_now
|
||||||
|
|
||||||
|
|
||||||
class Text2Img:
|
class Text2Img:
|
||||||
|
|
@ -7,9 +14,48 @@ class Text2Img:
|
||||||
Text2Img class.
|
Text2Img class.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model: Model, config: Config):
|
def __init__(
|
||||||
|
self,
|
||||||
|
model: Model,
|
||||||
|
config: Union[Config, None],
|
||||||
|
logger: DummyLogger = DummyLogger(),
|
||||||
|
):
|
||||||
self.model = model
|
self.model = model
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.__logger = logger
|
||||||
|
|
||||||
|
def update_config(self, config: Config):
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
def brunch(self, prompt: str, negative_prompt: str = ""):
|
||||||
|
self.breakfast()
|
||||||
|
self.lunch(prompt, negative_prompt)
|
||||||
|
|
||||||
def breakfast(self):
|
def breakfast(self):
|
||||||
self.model.set_txt2img_scheduler(config.get_scheduler())
|
self.model.set_txt2img_scheduler(self.config.get_scheduler())
|
||||||
|
|
||||||
|
def lunch(self, prompt: str, negative_prompt: str = ""):
|
||||||
|
t = get_epoch_now()
|
||||||
|
seed = self.config.get_seed()
|
||||||
|
self.__logger.info("current seed: {}".format(seed))
|
||||||
|
|
||||||
|
generator = torch.Generator("cuda").manual_seed(seed)
|
||||||
|
|
||||||
|
result = self.model.txt2img_pipeline(
|
||||||
|
prompt=prompt,
|
||||||
|
width=self.config.get_width(),
|
||||||
|
height=self.config.get_height(),
|
||||||
|
negative_prompt=negative_prompt,
|
||||||
|
guidance_scale=self.config.get_guidance_scale(),
|
||||||
|
num_inference_steps=self.config.get_steps(),
|
||||||
|
generator=generator,
|
||||||
|
callback=None,
|
||||||
|
callback_steps=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
out_filepath = "{}/{}.png".format(self.config.get_output_folder(), t)
|
||||||
|
result.images[0].save(out_filepath)
|
||||||
|
|
||||||
|
self.__logger.info("output to file: {}".format(out_filepath))
|
||||||
|
|
||||||
|
empty_memory_cache()
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
import calendar
|
||||||
|
import time
|
||||||
|
|
||||||
|
|
||||||
|
def get_epoch_now() -> int:
|
||||||
|
'''
|
||||||
|
Gets current elapsed epoch since 1970-1-1, 00:00 UTC.
|
||||||
|
The returned epoch is timezone-independent.
|
||||||
|
'''
|
||||||
|
return int(time.time())
|
||||||
|
|
||||||
|
|
||||||
|
def epoch_to_time(epoch: int, localTime: bool = False) -> time.struct_time:
|
||||||
|
'''
|
||||||
|
Converts epoch to the time struct.
|
||||||
|
@param localTime sets to return local time zone considered time.
|
||||||
|
'''
|
||||||
|
return time.localtime(epoch) if localTime else time.gmtime(epoch)
|
||||||
|
|
||||||
|
|
||||||
|
def epoch_to_string(epoch: int, localTime: bool = False, customFormat: str = "%Y-%m-%dT%H:%M:%S") -> str:
|
||||||
|
'''
|
||||||
|
Converts epoch to the string in the form of something like "2019-01-11T10:30:00".
|
||||||
|
@param localTime sets to return local time zone considered string.
|
||||||
|
'''
|
||||||
|
return time.strftime(customFormat, epoch_to_time(epoch, localTime=localTime))
|
||||||
|
|
||||||
|
|
||||||
|
def epoch_to_date(epoch: int, localTime: bool = False) -> str:
|
||||||
|
'''
|
||||||
|
Converts epoch to the date of the epoch.
|
||||||
|
@param localTime sets to return local time zone considered date.
|
||||||
|
'''
|
||||||
|
return time.strftime('%Y-%m-%d', epoch_to_time(epoch, localTime=localTime))
|
||||||
|
|
||||||
|
|
||||||
|
def epoch_to_yearmonth(epoch: int, localTime: bool = False) -> str:
|
||||||
|
'''
|
||||||
|
Converts epoch to the year-month of the epoch.
|
||||||
|
@param localTime sets to return local time zone considered date.
|
||||||
|
'''
|
||||||
|
return time.strftime('%Y-%m', epoch_to_time(epoch, localTime=localTime))
|
||||||
|
|
||||||
|
|
||||||
|
def string_to_epoch(timeString: str, localTime: bool = False, dashOnly: bool = False) -> int:
|
||||||
|
'''
|
||||||
|
Converts time string from something like "2019-01-11T10:30:00" to the epoch.
|
||||||
|
@param localTime sets to return local time zone considered epoch.
|
||||||
|
'''
|
||||||
|
if 'Z' == timeString[-1]:
|
||||||
|
timeString = timeString[:-1]
|
||||||
|
localTime = False
|
||||||
|
if dashOnly:
|
||||||
|
t = time.strptime(timeString, '%Y-%m-%dT%H-%M-%S')
|
||||||
|
else:
|
||||||
|
t = time.strptime(timeString, '%Y-%m-%dT%H:%M:%S')
|
||||||
|
offset = 0
|
||||||
|
if localTime:
|
||||||
|
currentT = get_epoch_now()
|
||||||
|
offset = int(calendar.timegm(epoch_to_time(currentT)) -
|
||||||
|
calendar.timegm(epoch_to_time(currentT, localTime=True)))
|
||||||
|
return int(calendar.timegm(t)) + offset
|
||||||
|
|
||||||
|
|
||||||
|
def date_to_epoch(timeString: str, localTime: bool = False) -> int:
|
||||||
|
'''
|
||||||
|
Converts date string from something like "2019-01-11" to the epoch.
|
||||||
|
@param localTime sets to return local time zone considered epoch.
|
||||||
|
'''
|
||||||
|
return string_to_epoch('{}T0:0:0'.format(timeString), localTime=localTime)
|
||||||
|
|
||||||
|
|
||||||
|
def time_to_epoch(t, localTime: bool = False) -> int:
|
||||||
|
'''
|
||||||
|
Converts epoch to the time struct.
|
||||||
|
@param localTime sets to return local time zone considered time.
|
||||||
|
'''
|
||||||
|
offset = 0
|
||||||
|
if localTime:
|
||||||
|
currentT = get_epoch_now()
|
||||||
|
offset = int(calendar.timegm(epoch_to_time(currentT)) -
|
||||||
|
calendar.timegm(epoch_to_time(currentT, localTime=True)))
|
||||||
|
return int(calendar.timegm(t)) + offset
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_seconds(wait_sec: int):
|
||||||
|
time.sleep(wait_sec)
|
||||||
|
|
||||||
|
|
||||||
|
def is_epoch_weekend(epoch: int, localTime: bool = False) -> bool:
|
||||||
|
'''
|
||||||
|
Checks if given epoch is during a weekend (Sat/Sun).
|
||||||
|
@param localTime sets to return local time zone considered time.
|
||||||
|
'''
|
||||||
|
return epoch_to_time(epoch, localTime=localTime).tm_wday >= 5
|
||||||
|
|
||||||
|
|
||||||
|
class Timer():
|
||||||
|
def __init__(self):
|
||||||
|
self.__start_time = None
|
||||||
|
self.__stop_time = None
|
||||||
|
|
||||||
|
def start(self):
|
||||||
|
self.__start_time = get_epoch_now()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
if self.__start_time is None:
|
||||||
|
raise ValueError("start() must be called first")
|
||||||
|
self.__stop_time = get_epoch_now()
|
||||||
|
|
||||||
|
def elapsed_seconds(self) -> int:
|
||||||
|
if self.__stop_time is None:
|
||||||
|
return get_epoch_now() - self.__start_time
|
||||||
|
return self.__stop_time - self.__start_time
|
||||||
|
|
||||||
|
def remaining_seconds_estimation(self, current_progress: float) -> int:
|
||||||
|
return int(self.elapsed_seconds() / current_progress)
|
||||||
|
|
@ -0,0 +1,61 @@
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
from utilities.times import date_to_epoch
|
||||||
|
from utilities.times import epoch_to_time
|
||||||
|
from utilities.times import epoch_to_string
|
||||||
|
from utilities.times import epoch_to_date
|
||||||
|
from utilities.times import get_epoch_now
|
||||||
|
from utilities.times import string_to_epoch
|
||||||
|
from utilities.times import time_to_epoch
|
||||||
|
from utilities.times import Timer
|
||||||
|
from utilities.times import wait_for_seconds
|
||||||
|
|
||||||
|
|
||||||
|
class TestTimes(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_epoch_translation(self):
|
||||||
|
current_epoch = get_epoch_now()
|
||||||
|
current_time = epoch_to_time(current_epoch, localTime=False)
|
||||||
|
epoch_from_time = time_to_epoch(current_time, localTime=False)
|
||||||
|
self.assertEqual(epoch_from_time, current_epoch)
|
||||||
|
|
||||||
|
current_local_time = epoch_to_time(current_epoch, localTime=True)
|
||||||
|
epoch_from_local_time = time_to_epoch(
|
||||||
|
current_local_time, localTime=True)
|
||||||
|
self.assertEqual(epoch_from_local_time, current_epoch)
|
||||||
|
|
||||||
|
current_date = epoch_to_date(current_epoch, localTime=False)
|
||||||
|
epoch_from_date = date_to_epoch(current_date, localTime=False)
|
||||||
|
self.assertEqual(epoch_from_date // 86400, current_epoch // 86400)
|
||||||
|
|
||||||
|
current_local_date = epoch_to_date(current_epoch, localTime=True)
|
||||||
|
epoch_from_date_local = date_to_epoch(
|
||||||
|
current_local_date, localTime=True)
|
||||||
|
self.assertEqual(epoch_from_date_local //
|
||||||
|
86400, current_epoch // 86400)
|
||||||
|
|
||||||
|
current_time_string = epoch_to_string(current_epoch, localTime=False)
|
||||||
|
epoch_from_string = string_to_epoch(
|
||||||
|
current_time_string, localTime=False)
|
||||||
|
self.assertEqual(epoch_from_string, current_epoch)
|
||||||
|
|
||||||
|
current_local_time_string = epoch_to_string(
|
||||||
|
current_epoch, localTime=True)
|
||||||
|
epoch_from_string_local = string_to_epoch(
|
||||||
|
current_local_time_string, localTime=True)
|
||||||
|
self.assertEqual(epoch_from_string_local, current_epoch)
|
||||||
|
|
||||||
|
def test_timer(self):
|
||||||
|
t = Timer()
|
||||||
|
self.assertRaises(ValueError, t.stop)
|
||||||
|
t.start()
|
||||||
|
wait_for_seconds(2)
|
||||||
|
self.assertEqual(t.elapsed_seconds(), 2)
|
||||||
|
wait_for_seconds(3)
|
||||||
|
t.stop()
|
||||||
|
self.assertEqual(t.elapsed_seconds(), 5)
|
||||||
|
self.assertEqual(t.remaining_seconds_estimation(0.5), 10)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
unittest.main()
|
||||||
Loading…
Reference in New Issue