diff --git a/fooocus_version.py b/fooocus_version.py index d5e7485f..ed14fcbe 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.694' +version = '2.1.695' diff --git a/modules/core.py b/modules/core.py index e83be00f..d7c2d650 100644 --- a/modules/core.py +++ b/modules/core.py @@ -256,6 +256,7 @@ def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sa finally: modules.sample_hijack.current_refiner = None + modules.sample_hijack.force_unload_all_control(positive, negative) return out diff --git a/modules/sample_hijack.py b/modules/sample_hijack.py index eb6cfdd1..f3853eb2 100644 --- a/modules/sample_hijack.py +++ b/modules/sample_hijack.py @@ -3,7 +3,7 @@ import fcbh.samplers import fcbh.model_management from fcbh.model_base import SDXLRefiner, SDXL -from fcbh.sample import get_additional_models +from fcbh.sample import get_additional_models, get_models_from_cond from fcbh.samplers import resolve_areas_and_cond_masks, wrap_model, calculate_start_end_timesteps, \ create_cond_with_same_area_if_none, pre_run_control, apply_empty_x_to_equal_area, encode_adm, \ blank_inpaint_image_like @@ -49,6 +49,22 @@ def clip_separate(cond, target_model=None, target_clip=None): return [[c, p]] +@torch.no_grad() +@torch.inference_mode() +def force_unload_all_control(positive, negative): + control_nets = set(get_models_from_cond(positive, "control") + get_models_from_cond(negative, "control")) + + cleaned_any_model = False + for m in control_nets: + if hasattr(m, 'cleanup'): + m.cleanup() + cleaned_any_model = True + + if cleaned_any_model: + fcbh.model_management.soft_empty_cache() + return + + @torch.no_grad() @torch.inference_mode() def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas, model_options={}, latent_image=None, denoise_mask=None, callback=None, disable_pbar=False, seed=None): @@ -113,6 +129,8 @@ def sample_hacked(model, noise, positive, negative, cfg, device, sampler, sigmas extra_args["cond_concat"] = cond_concat def refiner_switch(): + force_unload_all_control(positive, negative) + extra_args["cond"] = positive_refiner extra_args["uncond"] = negative_refiner