117 lines
4.3 KiB
Python
117 lines
4.3 KiB
Python
import random
|
|
import time
|
|
|
|
from utilities.constants import KEY_OUTPUT_FOLDER
|
|
from utilities.constants import VALUE_OUTPUT_FOLDER_DEFAULT
|
|
from utilities.constants import KEY_GUIDANCE_SCALE
|
|
from utilities.constants import VALUE_GUIDANCE_SCALE_DEFAULT
|
|
from utilities.constants import KEY_HEIGHT
|
|
from utilities.constants import VALUE_HEIGHT_DEFAULT
|
|
from utilities.constants import KEY_PREVIEW
|
|
from utilities.constants import VALUE_PREVIEW_DEFAULT
|
|
from utilities.constants import KEY_SCHEDULER
|
|
from utilities.constants import VALUE_SCHEDULER_DEFAULT
|
|
from utilities.constants import VALUE_SCHEDULER_DDIM
|
|
from utilities.constants import VALUE_SCHEDULER_DPM_SOLVER_MULTISTEP
|
|
from utilities.constants import VALUE_SCHEDULER_EULER_DISCRETE
|
|
from utilities.constants import VALUE_SCHEDULER_LMS_DISCRETE
|
|
from utilities.constants import VALUE_SCHEDULER_PNDM
|
|
from utilities.constants import KEY_SEED
|
|
from utilities.constants import VALUE_SEED_DEFAULT
|
|
from utilities.constants import KEY_STEPS
|
|
from utilities.constants import VALUE_STEPS_DEFAULT
|
|
from utilities.constants import KEY_WIDTH
|
|
from utilities.constants import VALUE_WIDTH_DEFAULT
|
|
from utilities.constants import OPTIONAL_KEYS
|
|
from utilities.logger import DummyLogger
|
|
|
|
|
|
class Config:
    """
    Configuration for a generation run.

    User overrides are kept in an internal dict keyed by the ``KEY_*``
    constants from ``utilities.constants``. Every getter falls back to the
    matching ``VALUE_*_DEFAULT`` when no override is stored. Individual
    setters log the old and new value through the injected logger and
    return ``self`` so calls can be chained.
    """

    def __init__(self, logger: DummyLogger = None):
        """
        :param logger: Logger used to report setting changes. When omitted
            (or ``None``) a fresh ``DummyLogger`` is created for this
            instance. A ``DummyLogger()`` default argument would be evaluated
            once and shared by every ``Config``.
        """
        self.__logger = logger if logger is not None else DummyLogger()
        self.__config = {}

    def _set_value(self, key: str, current, value) -> "Config":
        """Log ``key`` changing from ``current`` to ``value``, store ``value``, return ``self``."""
        self.__logger.info("{} changed from {} to {}".format(key, current, value))
        self.__config[key] = value
        return self

    def get_config(self) -> dict:
        """
        Return the dict of stored overrides.

        This is the internal dict itself, not a copy, so mutating it changes
        this configuration. Defaults that were never overridden are not
        included.
        """
        return self.__config

    def set_config(self, config: dict) -> "Config":
        """
        Merge overrides from ``config`` into this configuration without logging.

        A key is accepted when it, or its upper-cased form, is listed in
        ``OPTIONAL_KEYS``. Accepted keys are always stored upper-cased.
        Other keys, including non-string ones, are silently ignored.

        :param config: Mapping of setting names to values. ``None`` or empty
            is a no-op.
        :return: ``self``, for chaining.
        """
        for key, value in (config or {}).items():
            if not isinstance(key, str):
                # A non-string key cannot name a setting; skip instead of failing on .upper().
                continue
            normalized = key.upper()
            # Keys are stored upper-cased, so also accept a spelling whose upper-cased
            # form is known. Checking only the raw key rejected those spellings even
            # though they would be stored under the same name.
            if key not in OPTIONAL_KEYS and normalized not in OPTIONAL_KEYS:
                continue
            self.__config[normalized] = value
        return self

    def get_output_folder(self) -> str:
        """Return the output folder override, or ``VALUE_OUTPUT_FOLDER_DEFAULT``."""
        return self.__config.get(KEY_OUTPUT_FOLDER, VALUE_OUTPUT_FOLDER_DEFAULT)

    def set_output_folder(self, folder: str) -> "Config":
        """Set the output folder, logging the previous value."""
        return self._set_value(KEY_OUTPUT_FOLDER, self.get_output_folder(), folder)

    def get_guidance_scale(self) -> float:
        """Return the guidance scale override, or ``VALUE_GUIDANCE_SCALE_DEFAULT``."""
        return self.__config.get(KEY_GUIDANCE_SCALE, VALUE_GUIDANCE_SCALE_DEFAULT)

    def set_guidance_scale(self, scale: float) -> "Config":
        """Set the guidance scale, logging the previous value."""
        return self._set_value(KEY_GUIDANCE_SCALE, self.get_guidance_scale(), scale)

    def get_height(self) -> int:
        """Return the height override, or ``VALUE_HEIGHT_DEFAULT``."""
        return self.__config.get(KEY_HEIGHT, VALUE_HEIGHT_DEFAULT)

    def set_height(self, value: int) -> "Config":
        """Set the height, logging the previous value."""
        return self._set_value(KEY_HEIGHT, self.get_height(), value)

    def get_preview(self) -> bool:
        """Return the preview flag override, or ``VALUE_PREVIEW_DEFAULT``."""
        return self.__config.get(KEY_PREVIEW, VALUE_PREVIEW_DEFAULT)

    def set_preview(self, boolean: bool) -> "Config":
        """Set the preview flag, logging the previous value."""
        return self._set_value(KEY_PREVIEW, self.get_preview(), boolean)

    def get_scheduler(self) -> str:
        """Return the scheduler override, or ``VALUE_SCHEDULER_DEFAULT``."""
        return self.__config.get(KEY_SCHEDULER, VALUE_SCHEDULER_DEFAULT)

    def set_scheduler(self, scheduler: str) -> "Config":
        """
        Set the scheduler, logging the previous value.

        A falsy ``scheduler`` (``None``, ``""``) is replaced by
        ``VALUE_SCHEDULER_DEFAULT``.
        """
        # NOTE(review): the name is not validated against the VALUE_SCHEDULER_*
        # constants this module imports; unknown names are stored as-is.
        if not scheduler:
            scheduler = VALUE_SCHEDULER_DEFAULT
        return self._set_value(KEY_SCHEDULER, self.get_scheduler(), scheduler)

    def get_seed(self) -> int:
        """
        Return the configured seed.

        A stored seed of ``0`` means "random": a new 64-bit value is drawn on
        every call. Callers that need the same seed twice should call this
        once and keep the result.
        """
        seed = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
        if seed == 0:
            # Draw from a private generator. Re-seeding the module-level ``random``
            # state would reset the RNG for every other user of ``random`` in the process.
            seed = random.Random().getrandbits(64)
        return seed

    def set_seed(self, seed: int) -> "Config":
        """Set the seed, logging the previously stored value."""
        # Log the stored value rather than get_seed(). For a stored 0, get_seed()
        # draws a fresh random number, which made the "from" part of the log meaningless.
        current = self.__config.get(KEY_SEED, VALUE_SEED_DEFAULT)
        return self._set_value(KEY_SEED, current, seed)

    def get_steps(self) -> int:
        """Return the step-count override, or ``VALUE_STEPS_DEFAULT``."""
        return self.__config.get(KEY_STEPS, VALUE_STEPS_DEFAULT)

    def set_steps(self, steps: int) -> "Config":
        """Set the step count, logging the previous value."""
        return self._set_value(KEY_STEPS, self.get_steps(), steps)

    def get_width(self) -> int:
        """Return the width override, or ``VALUE_WIDTH_DEFAULT``."""
        return self.__config.get(KEY_WIDTH, VALUE_WIDTH_DEFAULT)

    def set_width(self, value: int) -> "Config":
        """Set the width, logging the previous value."""
        return self._set_value(KEY_WIDTH, self.get_width(), value)
|