maybe make like this

nabilaba 2025-06-09 13:43:44 +07:00
parent 475527f0ee
commit 31fcbbb76a
1 changed file with 17 additions and 35 deletions

@@ -46,12 +46,13 @@ def refresh_controlnets(model_paths):
 @torch.no_grad()
 @torch.inference_mode()
 def assert_model_integrity():
-    if model_base.unet_with_lora is None or not hasattr(model_base.unet_with_lora, "model"):
-        print('[Info] Skipping model integrity check: base model is not loaded.')
-        return True
+    error_message = None
     if not isinstance(model_base.unet_with_lora.model, SDXL):
-        raise NotImplementedError('You have selected base model other than SDXL. This is not supported yet.')
+        error_message = 'You have selected base model other than SDXL. This is not supported yet.'
+    if error_message is not None:
+        raise NotImplementedError(error_message)
     return True
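
Read together, the function after this hunk goes back to collecting an error message and raising once at the end. The block below is just the kept and added lines reassembled (blank lines assumed); note that with the is-None guard removed, calling assert_model_integrity() before any base model is loaded will now fail on model_base.unet_with_lora.model rather than quietly returning True, which is what the startup guard in the last hunk compensates for.

# assert_model_integrity() as it reads after this hunk (reassembled from the diff).
@torch.no_grad()
@torch.inference_mode()
def assert_model_integrity():
    error_message = None

    # Only SDXL checkpoints are supported as the base model.
    if not isinstance(model_base.unet_with_lora.model, SDXL):
        error_message = 'You have selected base model other than SDXL. This is not supported yet.'

    if error_message is not None:
        raise NotImplementedError(error_message)

    return True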
@@ -61,11 +62,6 @@ def assert_model_integrity():
 def refresh_base_model(name, vae_name=None):
     global model_base
-    if name is None or name == 'None':
-        print('[Info] No base model loaded.')
-        model_base = core.StableDiffusionModel()
-        return
     filename = get_file_from_folder_list(name, modules.config.paths_checkpoints)
     vae_filename = None
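
With its early return gone, refresh_base_model no longer treats name == 'None' as "load nothing" and goes straight to resolving a checkpoint file, so a caller that might be configured without a base model has to skip the call itself. A minimal sketch of that caller-side pattern, reusing only names that already appear in this diff:

# Hypothetical caller-side guard; refresh_base_model itself no longer
# special-cases 'None', so the check must happen before the call.
name = modules.config.default_base_model_name
if name is not None and name != 'None':
    refresh_base_model(name, vae_name=modules.config.default_vae)
else:
    print('[Startup] No base model configured, nothing to load.')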
@@ -219,34 +215,17 @@ def set_clip_skip(clip_skip: int):
 @torch.no_grad()
 @torch.inference_mode()
 def clear_all_caches():
-    global final_clip
-    if final_clip is not None and hasattr(final_clip, "fcs_cond_cache"):
-        final_clip.fcs_cond_cache = {}
-    else:
-        print("[Info] Skipping cache clear: final_clip is None.")
+    final_clip.fcs_cond_cache = {}

 @torch.no_grad()
 @torch.inference_mode()
 def prepare_text_encoder(async_call=True):
     if async_call:
         # TODO: make sure that this is always called in an async way so that users cannot feel it.
         pass
     assert_model_integrity()
-    patchers = []
-    if final_clip is not None and hasattr(final_clip, "patcher"):
-        patchers.append(final_clip.patcher)
-    if final_expansion is not None and hasattr(final_expansion, "patcher"):
-        patchers.append(final_expansion.patcher)
-    if len(patchers) > 0:
-        ldm_patched.modules.model_management.load_models_gpu(patchers)
-    else:
-        print("[Info] No models to load into GPU (no base model).")
+    ldm_patched.modules.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher])
     return
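
After this hunk both functions assume a model is already loaded: clear_all_caches() writes the conditioning cache unconditionally, and prepare_text_encoder() always hands both patchers to load_models_gpu. Reassembled from the kept and added lines (decorators from the surrounding context, blank lines assumed):

# clear_all_caches() and prepare_text_encoder() as they read after this hunk.
@torch.no_grad()
@torch.inference_mode()
def clear_all_caches():
    final_clip.fcs_cond_cache = {}


@torch.no_grad()
@torch.inference_mode()
def prepare_text_encoder(async_call=True):
    if async_call:
        # TODO: make sure that this is always called in an async way so that users cannot feel it.
        pass
    assert_model_integrity()
    ldm_patched.modules.model_management.load_models_gpu([final_clip.patcher, final_expansion.patcher])
    return

If final_clip or final_expansion is still None here, the calls above will raise an AttributeError, so any code path that can run before a checkpoint is selected needs to check for that before calling in.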
@@ -288,12 +267,15 @@ def refresh_everything(refiner_model_name, base_model_name, loras,
     return

-refresh_everything(
-    refiner_model_name=modules.config.default_refiner_model_name,
-    base_model_name=modules.config.default_base_model_name,
-    loras=get_enabled_loras(modules.config.default_loras),
-    vae_name=modules.config.default_vae,
-)
+if modules.config.default_base_model_name != 'None':
+    refresh_everything(
+        refiner_model_name=modules.config.default_refiner_model_name,
+        base_model_name=modules.config.default_base_model_name,
+        loras=get_enabled_loras(modules.config.default_loras),
+        vae_name=modules.config.default_vae,
+    )
+else:
+    print('[Startup] Skipping model load (default_base_model_name is "None").')

 @torch.no_grad()
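
Net effect of the commit: instead of scattering is-None checks through assert_model_integrity, refresh_base_model, clear_all_caches and prepare_text_encoder, the "no default model" case is handled once at this startup call. When default_base_model_name is the literal string 'None', nothing is loaded at import time, so a later explicit refresh is needed before generation. A hedged sketch of what that later call could look like, using only parameters already shown above; the checkpoint filename is a placeholder, not something from this diff:

# Hypothetical later call, e.g. once a user has picked a checkpoint in the UI.
selected_checkpoint = 'my_sdxl_checkpoint.safetensors'  # placeholder filename
refresh_everything(
    refiner_model_name=modules.config.default_refiner_model_name,
    base_model_name=selected_checkpoint,
    loras=get_enabled_loras(modules.config.default_loras),
    vae_name=modules.config.default_vae,
)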