From ef1999c52c8b0ae7fb26ee4563dfca9cc3b5c5c6 Mon Sep 17 00:00:00 2001 From: dooglewoogle <46539436+dooglewoogle@users.noreply.github.com> Date: Mon, 26 Feb 2024 00:47:14 +1300 Subject: [PATCH] feat: add ability to load checkpoints and loras from multiple locations (#1256) * Add ability to load checkpoints and loras from multiple locations * Found another location a default path is required * feat: use array as default --------- Co-authored-by: Manuel Schmid --- launch.py | 9 ++++----- modules/config.py | 35 +++++++++++++++++++++++++---------- modules/core.py | 3 ++- modules/default_pipeline.py | 5 +++-- modules/util.py | 9 +++++++++ 5 files changed, 43 insertions(+), 18 deletions(-) diff --git a/launch.py b/launch.py index db174f54..4269f1fc 100644 --- a/launch.py +++ b/launch.py @@ -68,7 +68,6 @@ vae_approx_filenames = [ 'https://huggingface.co/lllyasviel/misc/resolve/main/xl-to-v1_interposer-v3.1.safetensors') ] - def ini_args(): from args_manager import args return args @@ -101,9 +100,9 @@ def download_models(): return if not args.always_download_new_model: - if not os.path.exists(os.path.join(config.path_checkpoints, config.default_base_model_name)): + if not os.path.exists(os.path.join(config.paths_checkpoints[0], config.default_base_model_name)): for alternative_model_name in config.previous_default_models: - if os.path.exists(os.path.join(config.path_checkpoints, alternative_model_name)): + if os.path.exists(os.path.join(config.paths_checkpoints[0], alternative_model_name)): print(f'You do not have [{config.default_base_model_name}] but you have [{alternative_model_name}].') print(f'Fooocus will use [{alternative_model_name}] to avoid downloading new models, ' f'but you are not using latest models.') @@ -113,11 +112,11 @@ def download_models(): break for file_name, url in config.checkpoint_downloads.items(): - load_file_from_url(url=url, model_dir=config.path_checkpoints, file_name=file_name) + load_file_from_url(url=url, model_dir=config.paths_checkpoints[0], file_name=file_name) for file_name, url in config.embeddings_downloads.items(): load_file_from_url(url=url, model_dir=config.path_embeddings, file_name=file_name) for file_name, url in config.lora_downloads.items(): - load_file_from_url(url=url, model_dir=config.path_loras, file_name=file_name) + load_file_from_url(url=url, model_dir=config.paths_loras[0], file_name=file_name) return diff --git a/modules/config.py b/modules/config.py index 1f4e82eb..d3be1f21 100644 --- a/modules/config.py +++ b/modules/config.py @@ -114,7 +114,7 @@ def get_path_output() -> str: return path_output -def get_dir_or_set_default(key, default_value): +def get_dir_or_set_default(key, default_value, as_array=False): global config_dict, visited_keys, always_save_keys if key not in visited_keys: @@ -125,18 +125,29 @@ def get_dir_or_set_default(key, default_value): v = config_dict.get(key, None) if isinstance(v, str) and os.path.exists(v) and os.path.isdir(v): + return v if not as_array else [v] + elif isinstance(v, list) and all([os.path.exists(d) and os.path.isdir(d) for d in v]): return v else: if v is not None: print(f'Failed to load config key: {json.dumps({key:v})} is invalid or does not exist; will use {json.dumps({key:default_value})} instead.') - dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value)) - os.makedirs(dp, exist_ok=True) + if isinstance(default_value, list): + dp = [] + for path in default_value: + abs_path = os.path.abspath(os.path.join(os.path.dirname(__file__), path)) + dp.append(abs_path) + os.makedirs(abs_path, exist_ok=True) + else: + dp = os.path.abspath(os.path.join(os.path.dirname(__file__), default_value)) + os.makedirs(dp, exist_ok=True) + if as_array: + dp = [dp] config_dict[key] = dp return dp -path_checkpoints = get_dir_or_set_default('path_checkpoints', '../models/checkpoints/') -path_loras = get_dir_or_set_default('path_loras', '../models/loras/') +paths_checkpoints = get_dir_or_set_default('path_checkpoints', ['../models/checkpoints/'], True) +paths_loras = get_dir_or_set_default('path_loras', ['../models/loras/'], True) path_embeddings = get_dir_or_set_default('path_embeddings', '../models/embeddings/') path_vae_approx = get_dir_or_set_default('path_vae_approx', '../models/vae_approx/') path_upscale_models = get_dir_or_set_default('path_upscale_models', '../models/upscale_models/') @@ -404,14 +415,18 @@ model_filenames = [] lora_filenames = [] -def get_model_filenames(folder_path, name_filter=None): - return get_files_from_folder(folder_path, ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'], name_filter) +def get_model_filenames(folder_paths, name_filter=None): + extensions = ['.pth', '.ckpt', '.bin', '.safetensors', '.fooocus.patch'] + files = [] + for folder in folder_paths: + files += get_files_from_folder(folder, extensions, name_filter) + return files def update_all_model_names(): global model_filenames, lora_filenames - model_filenames = get_model_filenames(path_checkpoints) - lora_filenames = get_model_filenames(path_loras) + model_filenames = get_model_filenames(paths_checkpoints) + lora_filenames = get_model_filenames(paths_loras) return @@ -456,7 +471,7 @@ def downloading_inpaint_models(v): def downloading_sdxl_lcm_lora(): load_file_from_url( url='https://huggingface.co/lllyasviel/misc/resolve/main/sdxl_lcm_lora.safetensors', - model_dir=path_loras, + model_dir=paths_loras[0], file_name='sdxl_lcm_lora.safetensors' ) return 'sdxl_lcm_lora.safetensors' diff --git a/modules/core.py b/modules/core.py index 7a29d988..bfc44966 100644 --- a/modules/core.py +++ b/modules/core.py @@ -18,6 +18,7 @@ from ldm_patched.contrib.external import VAEDecode, EmptyLatentImage, VAEEncode, from ldm_patched.contrib.external_freelunch import FreeU_V2 from ldm_patched.modules.sample import prepare_mask from modules.lora import match_lora +from modules.util import get_file_from_folder_list from ldm_patched.modules.lora import model_lora_keys_unet, model_lora_keys_clip from modules.config import path_embeddings from ldm_patched.contrib.external_model_advanced import ModelSamplingDiscrete @@ -79,7 +80,7 @@ class StableDiffusionModel: if os.path.exists(name): lora_filename = name else: - lora_filename = os.path.join(modules.config.path_loras, name) + lora_filename = get_file_from_folder_list(name, modules.config.paths_loras) if not os.path.exists(lora_filename): print(f'Lora file not found: {lora_filename}') diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index 2f45667c..f8edfae1 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -11,6 +11,7 @@ from extras.expansion import FooocusExpansion from ldm_patched.modules.model_base import SDXL, SDXLRefiner from modules.sample_hijack import clip_separate +from modules.util import get_file_from_folder_list model_base = core.StableDiffusionModel() @@ -60,7 +61,7 @@ def assert_model_integrity(): def refresh_base_model(name): global model_base - filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name))) + filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) if model_base.filename == filename: return @@ -76,7 +77,7 @@ def refresh_base_model(name): def refresh_refiner_model(name): global model_refiner - filename = os.path.abspath(os.path.realpath(os.path.join(modules.config.path_checkpoints, name))) + filename = get_file_from_folder_list(name, modules.config.paths_checkpoints) if model_refiner.filename == filename: return diff --git a/modules/util.py b/modules/util.py index 9d4d0996..3c23a992 100644 --- a/modules/util.py +++ b/modules/util.py @@ -177,5 +177,14 @@ def get_files_from_folder(folder_path, exensions=None, name_filter=None): return filenames +def get_file_from_folder_list(name, folders): + for folder in folders: + filename = os.path.abspath(os.path.realpath(os.path.join(folder, name))) + if os.path.isfile(filename): + return filename + + return os.path.abspath(os.path.realpath(os.path.join(folders[0], name))) + + def ordinal_suffix(number: int) -> str: return 'th' if 10 <= number % 100 <= 20 else {1: 'st', 2: 'nd', 3: 'rd'}.get(number % 10, 'th')