# stable-diffusion-for-fun/utilities/text2img.py
#
# (Viewer metadata from the original listing: 76 lines, 2.3 KiB, Python.)

import torch
from typing import Union
from utilities.constants import BASE64IMAGE
from utilities.constants import KEY_SEED
from utilities.constants import KEY_WIDTH
from utilities.constants import KEY_HEIGHT
from utilities.constants import KEY_STEPS
from utilities.config import Config
from utilities.logger import DummyLogger
from utilities.memory import empty_memory_cache
from utilities.model import Model
from utilities.times import get_epoch_now
from utilities.images import image_to_base64
class Text2Img:
    """
    Text-to-image generator built on top of a ``Model``'s txt2img pipeline.

    The class configures the model's scheduler, seeds a torch generator,
    runs the pipeline and returns the first generated image (base64-encoded)
    together with the generation parameters that were used.
    """

    def __init__(
        self,
        model: Model,
        output_folder: str = "",
        logger: Union[DummyLogger, None] = None,
    ):
        """
        Args:
            model: Model wrapper; must provide ``use_gpu()``,
                ``set_txt2img_scheduler()`` and ``txt2img_pipeline()``.
            output_folder: When non-empty, the first generated image is also
                saved into this folder as a PNG file.
            logger: Object exposing ``info()``. A fresh ``DummyLogger`` is
                created when omitted.
        """
        self.model = model
        self.__device = "cpu" if not self.model.use_gpu() else "cuda"
        self.__output_folder = output_folder
        # Resolve the default here rather than in the signature so that the
        # default instance is not created once at import time and shared by
        # every Text2Img object.
        self.__logger = logger if logger is not None else DummyLogger()

    def brunch(self, prompt: str, negative_prompt: str = "") -> dict:
        """
        Run ``breakfast()`` followed by ``lunch()`` with the default config.

        Args:
            prompt: Text prompt forwarded to ``lunch()``.
            negative_prompt: Negative prompt forwarded to ``lunch()``.

        Returns:
            The result dictionary produced by ``lunch()``.
        """
        self.breakfast()
        # Propagate the generation result; previously it was discarded and
        # callers always received None.
        return self.lunch(prompt, negative_prompt)

    def breakfast(self):
        """Preparation hook executed before generation; currently a no-op."""
        pass

    def lunch(
        self,
        prompt: str,
        negative_prompt: str = "",
        config: Union[Config, None] = None,
    ) -> dict:
        """
        Generate an image from ``prompt`` using the txt2img pipeline.

        Args:
            prompt: Text prompt passed to the pipeline.
            negative_prompt: Negative prompt passed to the pipeline.
            config: Source of scheduler, seed, width, height, guidance scale
                and step count. A fresh ``Config()`` is created per call when
                omitted.

        Returns:
            A dict keyed by ``BASE64IMAGE`` (first image, base64-encoded),
            ``KEY_SEED`` (seed as a string), ``KEY_WIDTH``, ``KEY_HEIGHT``
            and ``KEY_STEPS``.
        """
        # A default built in the signature would be a single Config instance
        # shared across all calls; build it per call instead.
        if config is None:
            config = Config()

        self.model.set_txt2img_scheduler(config.get_scheduler())

        # Captured before generation; used as the output file name below.
        t = get_epoch_now()
        seed = config.get_seed()
        generator = torch.Generator(self.__device).manual_seed(seed)
        self.__logger.info("current seed: {}".format(seed))

        result = self.model.txt2img_pipeline(
            prompt=prompt,
            width=config.get_width(),
            height=config.get_height(),
            negative_prompt=negative_prompt,
            guidance_scale=config.get_guidance_scale(),
            num_inference_steps=config.get_steps(),
            generator=generator,
            callback=None,
            callback_steps=10,
        )

        if self.__output_folder:
            out_filepath = "{}/{}.png".format(self.__output_folder, t)
            result.images[0].save(out_filepath)
            self.__logger.info("output to file: {}".format(out_filepath))

        empty_memory_cache()

        return {
            BASE64IMAGE: image_to_base64(result.images[0]),
            KEY_SEED: str(seed),
            KEY_WIDTH: config.get_width(),
            KEY_HEIGHT: config.get_height(),
            KEY_STEPS: config.get_steps(),
        }