adds txt2img capability with interactive prompt input

This commit is contained in:
HappyZ 2023-04-27 18:55:45 -07:00
parent 41664e2682
commit b0e07eecd4
9 changed files with 371 additions and 24 deletions

17
main.py
View File

@ -25,6 +25,7 @@ def prepare(logger: Logger) -> [Model, Config]:
model.load_all() model.load_all()
config = Config() config = Config()
config.set_output_folder("/tmp/")
return model, config return model, config
@ -34,7 +35,21 @@ def main():
model, config = prepare(logger) model, config = prepare(logger)
text2img = Text2Img(model, config) text2img = Text2Img(model, config)
input("confirm...") text2img.breakfast()
while True:
try:
prompt = input("Write prompt: ")
if not prompt:
prompt = "man riding a horse in space"
negative_prompt = input("Write negative prompt: ")
if not negative_prompt:
negative_prompt = "(deformed iris, deformed pupils, semi-realistic, cgi, 3d, render, sketch, cartoon, drawing, anime:1.4), text, close up, cropped, out of frame, worst quality, low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers, long neck"
text2img.lunch(prompt=prompt, negative_prompt=negative_prompt)
except KeyboardInterrupt:
break
except BaseException:
raise
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -1,5 +1,8 @@
accelerate accelerate
colorlog colorlog
diffusers diffusers
numpy
Pillow
scikit-image
torch torch
transformers transformers

View File

@ -2,20 +2,6 @@ load("@rules_python//python:defs.bzl", "py_library", "py_test")
package(default_visibility=["//visibility:public"]) package(default_visibility=["//visibility:public"])
py_library(
name="memory",
srcs=["memory.py"],
)
py_library(
name="text2img",
srcs=["text2img.py"],
deps=[
":config",
":model",
],
)
py_library( py_library(
name="config", name="config",
srcs=["config.py"], srcs=["config.py"],
@ -31,13 +17,8 @@ py_library(
) )
py_library( py_library(
name="model", name = "images",
srcs=["model.py"], srcs = ["images.py"],
deps=[
":constants",
":memory",
":logger",
],
) )
py_library( py_library(
@ -50,3 +31,42 @@ py_test(
srcs=["logger_test.py"], srcs=["logger_test.py"],
deps=[":logger"], deps=[":logger"],
) )
py_library(
name="memory",
srcs=["memory.py"],
)
py_library(
name="model",
srcs=["model.py"],
deps=[
":constants",
":memory",
":logger",
],
)
py_library(
name="text2img",
srcs=["text2img.py"],
deps=[
":config",
":logger",
":images",
":memory",
":model",
":times",
],
)
py_library(
name="times",
srcs=["times.py"],
)
py_test(
name="times_test",
srcs=["times_test.py"],
deps=[":times"],
)

View File

@ -1,6 +1,8 @@
import random import random
import time import time
from utilities.constants import KEY_OUTPUT_FOLDER
from utilities.constants import VALUE_OUTPUT_FOLDER_DEFAULT
from utilities.constants import KEY_GUIDANCE_SCALE from utilities.constants import KEY_GUIDANCE_SCALE
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
from utilities.constants import KEY_HEIGHT from utilities.constants import KEY_HEIGHT
@ -35,6 +37,13 @@ class Config:
def get_config(self) -> dict: def get_config(self) -> dict:
return self.__config return self.__config
def get_output_folder(self) -> str:
return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)
def set_output_folder(self, folder:str):
self.__logger.info("{} changed from {} to {}".format(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder))
self.__config[KEY_OUTPUT_FOLDER] = folder
def get_guidance_scale(self) -> float: def get_guidance_scale(self) -> float:
return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT) return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)

View File

@ -1,5 +1,8 @@
LOGGER_NAME = "main" LOGGER_NAME = "main"
KEY_OUTPUT_FOLDER = "OUTFOLDER"
VALUE_OUTPUT_FOLDER_DEFAULT = ""
KEY_SEED = "SEED" KEY_SEED = "SEED"
VALUE_SEED_DEFAULT = 0 VALUE_SEED_DEFAULT = 0

73
utilities/images.py Normal file
View File

@ -0,0 +1,73 @@
import base64
import os
import io
from typing import Union
import numpy as np
from PIL import Image
def load_image(image: Union[str, bytes]) -> Union[Image.Image, None]:
if isinstance(image, bytes):
return Image.open(io.BytesIO(image))
elif os.path.isfile(image):
with Image.open(image) as im:
return Image.fromarray(np.asarray(im))
return None
def save_image(image: Union[bytes, Image.Image], filepath: str, override: bool = False) -> bool:
if os.path.isfile(filepath) and not override:
return False
try:
if isinstance(image, Image.Image):
# this is an Image
image.save(filepath)
else:
with open(filepath, "wb") as f:
f.write(image)
except OSError:
return False
return True
def crop_image(image: Image.Image, boundary: tuple) -> Image.Image:
'''
Crop an image based on boundary defined in boundary tuple.
'''
return image.crop(boundary)
def image_to_base64(image: Union[bytes, str, Image.Image], image_format: str = "png") -> str:
if isinstance(image, str):
# this is a filepath
if not os.path.isfile(image):
return ""
with open(image, "rb") as f:
image = f.read()
elif isinstance(image, Image.Image):
# this is an image
rawbytes = io.BytesIO()
image.save(rawbytes, format=image_format)
image = rawbytes.getvalue()
return "data:image/{};base64,".format(image_format) + base64.b64encode(image).decode()
from skimage import io as skimageio
from skimage import transform
from skimage import img_as_ubyte
def load_and_transform_image_for_torch(img_filepath: str, dimension: tuple = (), force_rgb: bool = True, transpose: bool = True, use_ubyte: bool = False) -> np.ndarray:
img = skimageio.imread(img_filepath)
if force_rgb:
img = img[:, :, :3]
if dimension:
img = transform.resize(img, dimension)
if transpose:
# swap color axis because
# numpy image: H x W x C
# torch image: C x H x W
img = img.transpose((2, 0, 1))
if use_ubyte:
img = img_as_ubyte(img)
return np.array(img)

View File

@ -1,5 +1,12 @@
import torch
from typing import Union
from utilities.config import Config from utilities.config import Config
from utilities.images import save_image
from utilities.logger import DummyLogger
from utilities.memory import empty_memory_cache
from utilities.model import Model from utilities.model import Model
from utilities.times import get_epoch_now
class Text2Img: class Text2Img:
@ -7,9 +14,48 @@ class Text2Img:
Text2Img class. Text2Img class.
""" """
def __init__(self, model: Model, config: Config): def __init__(
self,
model: Model,
config: Union[Config, None],
logger: DummyLogger = DummyLogger(),
):
self.model = model self.model = model
self.config = config self.config = config
self.__logger = logger
def update_config(self, config: Config):
self.config = config
def brunch(self, prompt: str, negative_prompt: str = ""):
self.breakfast()
self.lunch(prompt, negative_prompt)
def breakfast(self): def breakfast(self):
self.model.set_txt2img_scheduler(config.get_scheduler()) self.model.set_txt2img_scheduler(self.config.get_scheduler())
def lunch(self, prompt: str, negative_prompt: str = ""):
t = get_epoch_now()
seed = self.config.get_seed()
self.__logger.info("current seed: {}".format(seed))
generator = torch.Generator("cuda").manual_seed(seed)
result = self.model.txt2img_pipeline(
prompt=prompt,
width=self.config.get_width(),
height=self.config.get_height(),
negative_prompt=negative_prompt,
guidance_scale=self.config.get_guidance_scale(),
num_inference_steps=self.config.get_steps(),
generator=generator,
callback=None,
callback_steps=10,
)
out_filepath = "{}/{}.png".format(self.config.get_output_folder(), t)
result.images[0].save(out_filepath)
self.__logger.info("output to file: {}".format(out_filepath))
empty_memory_cache()

117
utilities/times.py Normal file
View File

@ -0,0 +1,117 @@
import calendar
import time
def get_epoch_now() -> int:
'''
Gets current elapsed epoch since 1970-1-1, 00:00 UTC.
The returned epoch is timezone-independent.
'''
return int(time.time())
def epoch_to_time(epoch: int, localTime: bool = False) -> time.struct_time:
'''
Converts epoch to the time struct.
@param localTime sets to return local time zone considered time.
'''
return time.localtime(epoch) if localTime else time.gmtime(epoch)
def epoch_to_string(epoch: int, localTime: bool = False, customFormat: str = "%Y-%m-%dT%H:%M:%S") -> str:
'''
Converts epoch to the string in the form of something like "2019-01-11T10:30:00".
@param localTime sets to return local time zone considered string.
'''
return time.strftime(customFormat, epoch_to_time(epoch, localTime=localTime))
def epoch_to_date(epoch: int, localTime: bool = False) -> str:
'''
Converts epoch to the date of the epoch.
@param localTime sets to return local time zone considered date.
'''
return time.strftime('%Y-%m-%d', epoch_to_time(epoch, localTime=localTime))
def epoch_to_yearmonth(epoch: int, localTime: bool = False) -> str:
'''
Converts epoch to the year-month of the epoch.
@param localTime sets to return local time zone considered date.
'''
return time.strftime('%Y-%m', epoch_to_time(epoch, localTime=localTime))
def string_to_epoch(timeString: str, localTime: bool = False, dashOnly: bool = False) -> int:
'''
Converts time string from something like "2019-01-11T10:30:00" to the epoch.
@param localTime sets to return local time zone considered epoch.
'''
if 'Z' == timeString[-1]:
timeString = timeString[:-1]
localTime = False
if dashOnly:
t = time.strptime(timeString, '%Y-%m-%dT%H-%M-%S')
else:
t = time.strptime(timeString, '%Y-%m-%dT%H:%M:%S')
offset = 0
if localTime:
currentT = get_epoch_now()
offset = int(calendar.timegm(epoch_to_time(currentT)) -
calendar.timegm(epoch_to_time(currentT, localTime=True)))
return int(calendar.timegm(t)) + offset
def date_to_epoch(timeString: str, localTime: bool = False) -> int:
'''
Converts date string from something like "2019-01-11" to the epoch.
@param localTime sets to return local time zone considered epoch.
'''
return string_to_epoch('{}T0:0:0'.format(timeString), localTime=localTime)
def time_to_epoch(t, localTime: bool = False) -> int:
'''
Converts epoch to the time struct.
@param localTime sets to return local time zone considered time.
'''
offset = 0
if localTime:
currentT = get_epoch_now()
offset = int(calendar.timegm(epoch_to_time(currentT)) -
calendar.timegm(epoch_to_time(currentT, localTime=True)))
return int(calendar.timegm(t)) + offset
def wait_for_seconds(wait_sec: int):
time.sleep(wait_sec)
def is_epoch_weekend(epoch: int, localTime: bool = False) -> bool:
'''
Checks if given epoch is during a weekend (Sat/Sun).
@param localTime sets to return local time zone considered time.
'''
return epoch_to_time(epoch, localTime=localTime).tm_wday >= 5
class Timer():
def __init__(self):
self.__start_time = None
self.__stop_time = None
def start(self):
self.__start_time = get_epoch_now()
def stop(self):
if self.__start_time is None:
raise ValueError("start() must be called first")
self.__stop_time = get_epoch_now()
def elapsed_seconds(self) -> int:
if self.__stop_time is None:
return get_epoch_now() - self.__start_time
return self.__stop_time - self.__start_time
def remaining_seconds_estimation(self, current_progress: float) -> int:
return int(self.elapsed_seconds() / current_progress)

61
utilities/times_test.py Normal file
View File

@ -0,0 +1,61 @@
import unittest
from utilities.times import date_to_epoch
from utilities.times import epoch_to_time
from utilities.times import epoch_to_string
from utilities.times import epoch_to_date
from utilities.times import get_epoch_now
from utilities.times import string_to_epoch
from utilities.times import time_to_epoch
from utilities.times import Timer
from utilities.times import wait_for_seconds
class TestTimes(unittest.TestCase):
def test_epoch_translation(self):
current_epoch = get_epoch_now()
current_time = epoch_to_time(current_epoch, localTime=False)
epoch_from_time = time_to_epoch(current_time, localTime=False)
self.assertEqual(epoch_from_time, current_epoch)
current_local_time = epoch_to_time(current_epoch, localTime=True)
epoch_from_local_time = time_to_epoch(
current_local_time, localTime=True)
self.assertEqual(epoch_from_local_time, current_epoch)
current_date = epoch_to_date(current_epoch, localTime=False)
epoch_from_date = date_to_epoch(current_date, localTime=False)
self.assertEqual(epoch_from_date // 86400, current_epoch // 86400)
current_local_date = epoch_to_date(current_epoch, localTime=True)
epoch_from_date_local = date_to_epoch(
current_local_date, localTime=True)
self.assertEqual(epoch_from_date_local //
86400, current_epoch // 86400)
current_time_string = epoch_to_string(current_epoch, localTime=False)
epoch_from_string = string_to_epoch(
current_time_string, localTime=False)
self.assertEqual(epoch_from_string, current_epoch)
current_local_time_string = epoch_to_string(
current_epoch, localTime=True)
epoch_from_string_local = string_to_epoch(
current_local_time_string, localTime=True)
self.assertEqual(epoch_from_string_local, current_epoch)
def test_timer(self):
t = Timer()
self.assertRaises(ValueError, t.stop)
t.start()
wait_for_seconds(2)
self.assertEqual(t.elapsed_seconds(), 2)
wait_for_seconds(3)
t.stop()
self.assertEqual(t.elapsed_seconds(), 5)
self.assertEqual(t.remaining_seconds_estimation(0.5), 10)
if __name__ == '__main__':
unittest.main()