From c4faf2ae6ca983aa02aeb26d1dd8b032443ea92f Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Thu, 6 Jun 2024 18:05:55 +0200 Subject: [PATCH 1/6] fix: add try_parse_bool for env var strings to enable config overrides of boolean values --- modules/config.py | 5 +++-- modules/extra_utils.py | 9 +++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/modules/config.py b/modules/config.py index 29a16d6d..0733d716 100644 --- a/modules/config.py +++ b/modules/config.py @@ -2,13 +2,14 @@ import os import json import math import numbers + import args_manager import tempfile import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.extra_utils import makedirs_with_log, get_files_from_folder +from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_parse_bool from modules.flags import OutputFormat, Performance, MetadataScheme @@ -209,7 +210,7 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ v = os.getenv(key) if v is not None: print(f"Environment: {key} = {v}") - config_dict[key] = v + config_dict[key] = try_parse_bool(v) if key not in config_dict: config_dict[key] = default_value diff --git a/modules/extra_utils.py b/modules/extra_utils.py index 9906c820..b4c83061 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -1,4 +1,6 @@ import os +from ast import literal_eval + def makedirs_with_log(path): try: @@ -24,3 +26,10 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): filenames.append(path) return filenames + + +def try_parse_bool(value: str) -> str | bool: + value_eval = literal_eval(value.strip().title()) + if type(value_eval) is bool: + return value_eval + return value From beab2b9d48cee844ed2564d2b345d4936d5ffe28 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Thu, 6 Jun 2024 18:20:14 +0200 Subject: [PATCH 2/6] fix: fallback to given value if not parseable --- modules/config.py | 3 ++- modules/extra_utils.py | 11 +++++++---- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/modules/config.py b/modules/config.py index 0733d716..14a6f52b 100644 --- a/modules/config.py +++ b/modules/config.py @@ -209,8 +209,9 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ v = os.getenv(key) if v is not None: + v = try_parse_bool(v) print(f"Environment: {key} = {v}") - config_dict[key] = try_parse_bool(v) + config_dict[key] = v if key not in config_dict: config_dict[key] = default_value diff --git a/modules/extra_utils.py b/modules/extra_utils.py index b4c83061..c4056020 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -29,7 +29,10 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): def try_parse_bool(value: str) -> str | bool: - value_eval = literal_eval(value.strip().title()) - if type(value_eval) is bool: - return value_eval - return value + try: + value_eval = literal_eval(value.strip().title()) + if type(value_eval) is bool: + return value_eval + return value + except ValueError | TypeError: + return value From bef79e3cb40ab2d8d8536a4a5e74ab7d1ef30b71 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Thu, 6 Jun 2024 18:37:00 +0200 Subject: [PATCH 3/6] feat: extend eval to all valid types --- modules/config.py | 4 ++-- modules/extra_utils.py | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/modules/config.py b/modules/config.py index 14a6f52b..6bd25d55 100644 --- a/modules/config.py +++ b/modules/config.py @@ -9,7 +9,7 @@ import modules.flags import modules.sdxl_styles from modules.model_loader import load_file_from_url -from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_parse_bool +from modules.extra_utils import makedirs_with_log, get_files_from_folder, try_eval_env_var from modules.flags import OutputFormat, Performance, MetadataScheme @@ -209,7 +209,7 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ v = os.getenv(key) if v is not None: - v = try_parse_bool(v) + v = try_eval_env_var(v) print(f"Environment: {key} = {v}") config_dict[key] = v diff --git a/modules/extra_utils.py b/modules/extra_utils.py index c4056020..00c11091 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -28,11 +28,11 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): return filenames -def try_parse_bool(value: str) -> str | bool: +def try_eval_env_var(value: str, expected_type=None) -> str | bool: try: - value_eval = literal_eval(value.strip().title()) - if type(value_eval) is bool: - return value_eval - return value - except ValueError | TypeError: + value_eval = literal_eval(value.title()) + if expected_type is not None and type(value_eval) is not expected_type: + return value + return value_eval + except: return value From 2186d3e15db6eea3e1e617d7c2572e63f0e99973 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Thu, 6 Jun 2024 18:38:49 +0200 Subject: [PATCH 4/6] fix: remove return type --- modules/extra_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/extra_utils.py b/modules/extra_utils.py index 00c11091..72b8a280 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -28,7 +28,7 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): return filenames -def try_eval_env_var(value: str, expected_type=None) -> str | bool: +def try_eval_env_var(value: str, expected_type=None): try: value_eval = literal_eval(value.title()) if expected_type is not None and type(value_eval) is not expected_type: From 09be2c972cb8e42ddd5e8b5d29f427b76db14148 Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Thu, 6 Jun 2024 19:13:17 +0200 Subject: [PATCH 5/6] fix: prevent strange type conversions by providing expected type --- modules/config.py | 116 +++++++++++++++++++++++++++-------------- modules/extra_utils.py | 7 ++- 2 files changed, 82 insertions(+), 41 deletions(-) diff --git a/modules/config.py b/modules/config.py index 6bd25d55..e3c427d2 100644 --- a/modules/config.py +++ b/modules/config.py @@ -201,7 +201,7 @@ path_safety_checker = get_dir_or_set_default('path_safety_checker', '../models/s path_outputs = get_path_output() -def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False): +def get_config_item_or_set_default(key, default_value, validator, disable_empty_as_none=False, expected_type=None): global config_dict, visited_keys if key not in visited_keys: @@ -209,7 +209,7 @@ def get_config_item_or_set_default(key, default_value, validator, disable_empty_ v = os.getenv(key) if v is not None: - v = try_eval_env_var(v) + v = try_eval_env_var(v, expected_type) print(f"Environment: {key} = {v}") config_dict[key] = v @@ -254,41 +254,49 @@ temp_path = init_temp_path(get_config_item_or_set_default( key='temp_path', default_value=default_temp_path, validator=lambda x: isinstance(x, str), + expected_type=str ), default_temp_path) temp_path_cleanup_on_launch = get_config_item_or_set_default( key='temp_path_cleanup_on_launch', default_value=True, - validator=lambda x: isinstance(x, bool) + validator=lambda x: isinstance(x, bool), + expected_type=bool ) default_base_model_name = default_model = get_config_item_or_set_default( key='default_model', default_value='model.safetensors', - validator=lambda x: isinstance(x, str) + validator=lambda x: isinstance(x, str), + expected_type=str ) previous_default_models = get_config_item_or_set_default( key='previous_default_models', default_value=[], - validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x) + validator=lambda x: isinstance(x, list) and all(isinstance(k, str) for k in x), + expected_type=list ) default_refiner_model_name = default_refiner = get_config_item_or_set_default( key='default_refiner', default_value='None', - validator=lambda x: isinstance(x, str) + validator=lambda x: isinstance(x, str), + expected_type=str ) default_refiner_switch = get_config_item_or_set_default( key='default_refiner_switch', default_value=0.8, - validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1 + validator=lambda x: isinstance(x, numbers.Number) and 0 <= x <= 1, + expected_type=numbers.Number ) default_loras_min_weight = get_config_item_or_set_default( key='default_loras_min_weight', default_value=-2, - validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10 + validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10, + expected_type=numbers.Number ) default_loras_max_weight = get_config_item_or_set_default( key='default_loras_max_weight', default_value=2, - validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10 + validator=lambda x: isinstance(x, numbers.Number) and -10 <= x <= 10, + expected_type=numbers.Number ) default_loras = get_config_item_or_set_default( key='default_loras', @@ -322,38 +330,45 @@ default_loras = get_config_item_or_set_default( validator=lambda x: isinstance(x, list) and all( len(y) == 3 and isinstance(y[0], bool) and isinstance(y[1], str) and isinstance(y[2], numbers.Number) or len(y) == 2 and isinstance(y[0], str) and isinstance(y[1], numbers.Number) - for y in x) + for y in x), + expected_type=list ) default_loras = [(y[0], y[1], y[2]) if len(y) == 3 else (True, y[0], y[1]) for y in default_loras] default_max_lora_number = get_config_item_or_set_default( key='default_max_lora_number', default_value=len(default_loras) if isinstance(default_loras, list) and len(default_loras) > 0 else 5, - validator=lambda x: isinstance(x, int) and x >= 1 + validator=lambda x: isinstance(x, int) and x >= 1, + expected_type=int ) default_cfg_scale = get_config_item_or_set_default( key='default_cfg_scale', default_value=7.0, - validator=lambda x: isinstance(x, numbers.Number) + validator=lambda x: isinstance(x, numbers.Number), + expected_type=numbers.Number ) default_sample_sharpness = get_config_item_or_set_default( key='default_sample_sharpness', default_value=2.0, - validator=lambda x: isinstance(x, numbers.Number) + validator=lambda x: isinstance(x, numbers.Number), + expected_type=numbers.Number ) default_sampler = get_config_item_or_set_default( key='default_sampler', default_value='dpmpp_2m_sde_gpu', - validator=lambda x: x in modules.flags.sampler_list + validator=lambda x: x in modules.flags.sampler_list, + expected_type=str ) default_scheduler = get_config_item_or_set_default( key='default_scheduler', default_value='karras', - validator=lambda x: x in modules.flags.scheduler_list + validator=lambda x: x in modules.flags.scheduler_list, + expected_type=str ) default_vae = get_config_item_or_set_default( key='default_vae', default_value=modules.flags.default_vae, - validator=lambda x: isinstance(x, str) + validator=lambda x: isinstance(x, str), + expected_type=str ) default_styles = get_config_item_or_set_default( key='default_styles', @@ -362,121 +377,144 @@ default_styles = get_config_item_or_set_default( "Fooocus Enhance", "Fooocus Sharp" ], - validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x) + validator=lambda x: isinstance(x, list) and all(y in modules.sdxl_styles.legal_style_names for y in x), + expected_type=list ) default_prompt_negative = get_config_item_or_set_default( key='default_prompt_negative', default_value='', validator=lambda x: isinstance(x, str), - disable_empty_as_none=True + disable_empty_as_none=True, + expected_type=str ) default_prompt = get_config_item_or_set_default( key='default_prompt', default_value='', validator=lambda x: isinstance(x, str), - disable_empty_as_none=True + disable_empty_as_none=True, + expected_type=str ) default_performance = get_config_item_or_set_default( key='default_performance', default_value=Performance.SPEED.value, - validator=lambda x: x in Performance.list() + validator=lambda x: x in Performance.list(), + expected_type=str ) default_advanced_checkbox = get_config_item_or_set_default( key='default_advanced_checkbox', default_value=False, - validator=lambda x: isinstance(x, bool) + validator=lambda x: isinstance(x, bool), + expected_type=bool ) default_max_image_number = get_config_item_or_set_default( key='default_max_image_number', default_value=32, - validator=lambda x: isinstance(x, int) and x >= 1 + validator=lambda x: isinstance(x, int) and x >= 1, + expected_type=int ) default_output_format = get_config_item_or_set_default( key='default_output_format', default_value='png', - validator=lambda x: x in OutputFormat.list() + validator=lambda x: x in OutputFormat.list(), + expected_type=str ) default_image_number = get_config_item_or_set_default( key='default_image_number', default_value=2, - validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number + validator=lambda x: isinstance(x, int) and 1 <= x <= default_max_image_number, + expected_type=int ) checkpoint_downloads = get_config_item_or_set_default( key='checkpoint_downloads', default_value={}, - validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()), + expected_type=dict ) lora_downloads = get_config_item_or_set_default( key='lora_downloads', default_value={}, - validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()), + expected_type=dict ) embeddings_downloads = get_config_item_or_set_default( key='embeddings_downloads', default_value={}, - validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()) + validator=lambda x: isinstance(x, dict) and all(isinstance(k, str) and isinstance(v, str) for k, v in x.items()), + expected_type=dict ) available_aspect_ratios = get_config_item_or_set_default( key='available_aspect_ratios', default_value=modules.flags.sdxl_aspect_ratios, - validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1 + validator=lambda x: isinstance(x, list) and all('*' in v for v in x) and len(x) > 1, + expected_type=list ) default_aspect_ratio = get_config_item_or_set_default( key='default_aspect_ratio', default_value='1152*896' if '1152*896' in available_aspect_ratios else available_aspect_ratios[0], - validator=lambda x: x in available_aspect_ratios + validator=lambda x: x in available_aspect_ratios, + expected_type=str ) default_inpaint_engine_version = get_config_item_or_set_default( key='default_inpaint_engine_version', default_value='v2.6', - validator=lambda x: x in modules.flags.inpaint_engine_versions + validator=lambda x: x in modules.flags.inpaint_engine_versions, + expected_type=str ) default_cfg_tsnr = get_config_item_or_set_default( key='default_cfg_tsnr', default_value=7.0, - validator=lambda x: isinstance(x, numbers.Number) + validator=lambda x: isinstance(x, numbers.Number), + expected_type=numbers.Number ) default_clip_skip = get_config_item_or_set_default( key='default_clip_skip', default_value=2, - validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max + validator=lambda x: isinstance(x, int) and 1 <= x <= modules.flags.clip_skip_max, + expected_type=int ) default_overwrite_step = get_config_item_or_set_default( key='default_overwrite_step', default_value=-1, - validator=lambda x: isinstance(x, int) + validator=lambda x: isinstance(x, int), + expected_type=int ) default_overwrite_switch = get_config_item_or_set_default( key='default_overwrite_switch', default_value=-1, - validator=lambda x: isinstance(x, int) + validator=lambda x: isinstance(x, int), + expected_type=int ) example_inpaint_prompts = get_config_item_or_set_default( key='example_inpaint_prompts', default_value=[ 'highly detailed face', 'detailed girl face', 'detailed man face', 'detailed hand', 'beautiful eyes' ], - validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x) + validator=lambda x: isinstance(x, list) and all(isinstance(v, str) for v in x), + expected_type=list ) default_black_out_nsfw = get_config_item_or_set_default( key='default_black_out_nsfw', default_value=False, - validator=lambda x: isinstance(x, bool) + validator=lambda x: isinstance(x, bool), + expected_type=bool ) default_save_metadata_to_images = get_config_item_or_set_default( key='default_save_metadata_to_images', default_value=False, - validator=lambda x: isinstance(x, bool) + validator=lambda x: isinstance(x, bool), + expected_type=bool ) default_metadata_scheme = get_config_item_or_set_default( key='default_metadata_scheme', default_value=MetadataScheme.FOOOCUS.value, - validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x] + validator=lambda x: x in [y[1] for y in modules.flags.metadata_scheme if y[1] == x], + expected_type=str ) metadata_created_by = get_config_item_or_set_default( key='metadata_created_by', default_value='', - validator=lambda x: isinstance(x, str) + validator=lambda x: isinstance(x, str), + expected_type=str ) example_inpaint_prompts = [[x] for x in example_inpaint_prompts] diff --git a/modules/extra_utils.py b/modules/extra_utils.py index 72b8a280..c2dfa810 100644 --- a/modules/extra_utils.py +++ b/modules/extra_utils.py @@ -30,8 +30,11 @@ def get_files_from_folder(folder_path, extensions=None, name_filter=None): def try_eval_env_var(value: str, expected_type=None): try: - value_eval = literal_eval(value.title()) - if expected_type is not None and type(value_eval) is not expected_type: + value_eval = value + if expected_type is bool: + value_eval = value.title() + value_eval = literal_eval(value_eval) + if expected_type is not None and not isinstance(value_eval, expected_type): return value return value_eval except: From d56b0929686e0b4449e7b4b78bf590186eee1cdc Mon Sep 17 00:00:00 2001 From: Manuel Schmid Date: Thu, 6 Jun 2024 19:27:03 +0200 Subject: [PATCH 6/6] feat: add tests --- tests/test_extra_utils.py | 74 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) create mode 100644 tests/test_extra_utils.py diff --git a/tests/test_extra_utils.py b/tests/test_extra_utils.py new file mode 100644 index 00000000..a849aa16 --- /dev/null +++ b/tests/test_extra_utils.py @@ -0,0 +1,74 @@ +import numbers +import os +import unittest + +import modules.flags +from modules import extra_utils + + +class TestUtils(unittest.TestCase): + def test_try_eval_env_var(self): + test_cases = [ + { + "input": ("foo", str), + "output": "foo" + }, + { + "input": ("1", int), + "output": 1 + }, + { + "input": ("1.0", float), + "output": 1.0 + }, + { + "input": ("1", numbers.Number), + "output": 1 + }, + { + "input": ("1.0", numbers.Number), + "output": 1.0 + }, + { + "input": ("true", bool), + "output": True + }, + { + "input": ("True", bool), + "output": True + }, + { + "input": ("false", bool), + "output": False + }, + { + "input": ("False", bool), + "output": False + }, + { + "input": ("True", str), + "output": "True" + }, + { + "input": ("False", str), + "output": "False" + }, + { + "input": ("['a', 'b', 'c']", list), + "output": ['a', 'b', 'c'] + }, + { + "input": ("{'a':1}", dict), + "output": {'a': 1} + }, + { + "input": ("('foo', 1)", tuple), + "output": ('foo', 1) + } + ] + + for test in test_cases: + value, expected_type = test["input"] + expected = test["output"] + actual = extra_utils.try_eval_env_var(value, expected_type) + self.assertEqual(expected, actual)