diff --git a/fooocus_version.py b/fooocus_version.py index 82ccd38c..a63055f8 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.0.14' +version = '2.0.16' diff --git a/modules/async_worker.py b/modules/async_worker.py index 5f1e0797..43dfcc90 100644 --- a/modules/async_worker.py +++ b/modules/async_worker.py @@ -16,6 +16,7 @@ def worker(): import modules.default_pipeline as pipeline import modules.path import modules.patch + import modules.virtual_memory as virtual_memory from modules.sdxl_styles import apply_style, aspect_ratios, fooocus_expansion from modules.private_logger import log @@ -80,10 +81,10 @@ def worker(): progressbar(3, 'Loading models ...') - pipeline.refresh_base_model(base_model_name) - pipeline.refresh_refiner_model(refiner_model_name) - pipeline.refresh_loras(loras) - pipeline.clear_all_caches() + pipeline.refresh_everything( + refiner_model_name=refiner_model_name, + base_model_name=base_model_name, + loras=loras) progressbar(3, 'Processing prompts ...') @@ -137,6 +138,8 @@ def worker(): pool_top_k=negative_top_k) if pipeline.xl_refiner is not None: + virtual_memory.load_from_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model) + for i, t in enumerate(tasks): progressbar(11, f'Encoding refiner positive #{i + 1} ...') t['c'][1] = pipeline.clip_encode(sd=pipeline.xl_refiner, texts=t['positive'], @@ -147,6 +150,8 @@ def worker(): t['uc'][1] = pipeline.clip_encode(sd=pipeline.xl_refiner, texts=t['negative'], pool_top_k=negative_top_k) + virtual_memory.try_move_to_virtual_memory(pipeline.xl_refiner.clip.cond_stage_model) + if performance_selction == 'Speed': steps = 30 switch = 20 diff --git a/modules/core.py b/modules/core.py index 16fde865..d43454aa 100644 --- a/modules/core.py +++ b/modules/core.py @@ -10,6 +10,7 @@ import comfy.utils from comfy.sd import load_checkpoint_guess_config from nodes import VAEDecode, EmptyLatentImage from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models +from comfy.model_base import SDXLRefiner from modules.samplers_advanced import KSampler, KSamplerWithRefiner from modules.patch import patch_all @@ -20,7 +21,15 @@ opVAEDecode = VAEDecode() class StableDiffusionModel: - def __init__(self, unet, vae, clip, clip_vision): + def __init__(self, unet, vae, clip, clip_vision, model_filename=None): + if isinstance(model_filename, str): + is_refiner = isinstance(unet.model, SDXLRefiner) + if unet is not None: + unet.model.model_file = dict(filename=model_filename, prefix='model') + if clip is not None: + clip.cond_stage_model.model_file = dict(filename=model_filename, prefix='refiner_clip' if is_refiner else 'base_clip') + if vae is not None: + vae.first_stage_model.model_file = dict(filename=model_filename, prefix='first_stage_model') self.unet = unet self.vae = vae self.clip = clip @@ -38,7 +47,7 @@ class StableDiffusionModel: @torch.no_grad() def load_model(ckpt_filename): unet, clip, vae, clip_vision = load_checkpoint_guess_config(ckpt_filename) - return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision) + return StableDiffusionModel(unet=unet, clip=clip, vae=vae, clip_vision=clip_vision, model_filename=ckpt_filename) @torch.no_grad() diff --git a/modules/default_pipeline.py b/modules/default_pipeline.py index bf2cc077..a374c79a 100644 --- a/modules/default_pipeline.py +++ b/modules/default_pipeline.py @@ -2,6 +2,7 @@ import modules.core as core import os import torch import modules.path +import modules.virtual_memory as virtual_memory import comfy.model_management as model_management from comfy.model_base import SDXL, SDXLRefiner @@ -21,10 +22,12 @@ xl_base_patched_hash = '' def refresh_base_model(name): global xl_base, xl_base_hash, xl_base_patched, xl_base_patched_hash - if xl_base_hash == str(name): - return - filename = os.path.join(modules.path.modelfile_path, name) + filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name))) + model_hash = filename + + if xl_base_hash == model_hash: + return if xl_base is not None: xl_base.to_meta() @@ -36,21 +39,25 @@ def refresh_base_model(name): xl_base = None xl_base_hash = '' refresh_base_model(modules.path.default_base_model_name) - xl_base_hash = name + xl_base_hash = model_hash xl_base_patched = xl_base xl_base_patched_hash = '' return - xl_base_hash = name + xl_base_hash = model_hash xl_base_patched = xl_base xl_base_patched_hash = '' - print(f'Base model loaded: {xl_base_hash}') + print(f'Base model loaded: {model_hash}') return def refresh_refiner_model(name): global xl_refiner, xl_refiner_hash - if xl_refiner_hash == str(name): + + filename = os.path.abspath(os.path.realpath(os.path.join(modules.path.modelfile_path, name))) + model_hash = filename + + if xl_refiner_hash == model_hash: return if name == 'None': @@ -59,8 +66,6 @@ def refresh_refiner_model(name): print(f'Refiner unloaded.') return - filename = os.path.join(modules.path.modelfile_path, name) - if xl_refiner is not None: xl_refiner.to_meta() xl_refiner = None @@ -73,8 +78,8 @@ def refresh_refiner_model(name): print(f'Refiner unloaded.') return - xl_refiner_hash = name - print(f'Refiner model loaded: {xl_refiner_hash}') + xl_refiner_hash = model_hash + print(f'Refiner model loaded: {model_hash}') xl_refiner.vae.first_stage_model.to('meta') xl_refiner.vae = None @@ -100,13 +105,6 @@ def refresh_loras(loras): return -refresh_base_model(modules.path.default_base_model_name) -refresh_refiner_model(modules.path.default_refiner_model_name) -refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)]) - -expansion = FooocusExpansion() - - @torch.no_grad() def clip_encode_single(clip, text, verbose=False): cached = clip.fcs_cond_cache.get(text, None) @@ -133,8 +131,6 @@ def clip_encode(sd, texts, pool_top_k=1): if len(texts) == 0: return None - model_management.soft_empty_cache() - clip = sd.clip cond_list = [] pooled_acc = 0 @@ -164,6 +160,29 @@ def clear_all_caches(): clear_sd_cond_cache(xl_refiner) +def refresh_everything(refiner_model_name, base_model_name, loras): + refresh_refiner_model(refiner_model_name) + if xl_refiner is not None: + virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model) + virtual_memory.try_move_to_virtual_memory(xl_refiner.clip.cond_stage_model) + + refresh_base_model(base_model_name) + virtual_memory.load_from_virtual_memory(xl_base.unet.model) + + refresh_loras(loras) + clear_all_caches() + return + + +refresh_everything( + refiner_model_name=modules.path.default_refiner_model_name, + base_model_name=modules.path.default_base_model_name, + loras=[(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5), ('None', 0.5)] +) + +expansion = FooocusExpansion() + + @torch.no_grad() def patch_all_models(): assert xl_base is not None @@ -181,7 +200,10 @@ def patch_all_models(): @torch.no_grad() def process_diffusion(positive_cond, negative_cond, steps, switch, width, height, image_seed, callback): patch_all_models() - model_management.soft_empty_cache() + + if xl_refiner is not None: + virtual_memory.try_move_to_virtual_memory(xl_refiner.unet.model) + virtual_memory.load_from_virtual_memory(xl_base.unet.model) empty_latent = core.generate_empty_latent(width=width, height=height, batch_size=1) diff --git a/modules/samplers_advanced.py b/modules/samplers_advanced.py index 6ad77343..2d382f64 100644 --- a/modules/samplers_advanced.py +++ b/modules/samplers_advanced.py @@ -1,6 +1,7 @@ from comfy.samplers import * import comfy.model_management +import modules.virtual_memory class KSamplerWithRefiner: @@ -152,6 +153,8 @@ class KSamplerWithRefiner: noise.shape[3], noise.shape[2], self.device, "negative") def refiner_switch(): + modules.virtual_memory.try_move_to_virtual_memory(self.model_denoise.inner_model) + modules.virtual_memory.load_from_virtual_memory(self.refiner_model_denoise.inner_model) comfy.model_management.load_model_gpu(self.refiner_model_patcher) self.model_denoise.inner_model = self.refiner_model_denoise.inner_model for i in range(len(positive)): diff --git a/modules/virtual_memory.py b/modules/virtual_memory.py new file mode 100644 index 00000000..eea282ac --- /dev/null +++ b/modules/virtual_memory.py @@ -0,0 +1,175 @@ +import torch +import gc + +from safetensors import safe_open +from comfy import model_management +from comfy.diffusers_convert import textenc_conversion_lst + + +ALWAYS_USE_VM = None + +if ALWAYS_USE_VM is not None: + print(f'[Virtual Memory System] Forced = {ALWAYS_USE_VM}') + +if 'cpu' in model_management.unet_offload_device().type.lower(): + logic_memory = model_management.total_ram + global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 30000 + print(f'[Virtual Memory System] Logic target is CPU, memory = {logic_memory}') +else: + logic_memory = model_management.total_vram + global_virtual_memory_activated = ALWAYS_USE_VM if ALWAYS_USE_VM is not None else logic_memory < 22000 + print(f'[Virtual Memory System] Logic target is GPU, memory = {logic_memory}') + + +print(f'[Virtual Memory System] Activated = {global_virtual_memory_activated}') + + +@torch.no_grad() +def recursive_set(obj, key, value): + if obj is None: + return + if '.' in key: + k1, k2 = key.split('.', 1) + recursive_set(getattr(obj, k1, None), k2, value) + else: + setattr(obj, key, value) + + +@torch.no_grad() +def recursive_del(obj, key): + if obj is None: + return + if '.' in key: + k1, k2 = key.split('.', 1) + recursive_del(getattr(obj, k1, None), k2) + else: + delattr(obj, key) + + +@torch.no_grad() +def force_load_state_dict(model, state_dict): + for k in list(state_dict.keys()): + p = torch.nn.Parameter(state_dict[k], requires_grad=False) + recursive_set(model, k, p) + del state_dict[k] + return + + +@torch.no_grad() +def only_load_safetensors_keys(filename): + try: + with safe_open(filename, framework="pt", device='cpu') as f: + result = list(f.keys()) + assert len(result) > 0 + return result + except: + return None + + +@torch.no_grad() +def move_to_virtual_memory(model, comfy_unload=True): + if comfy_unload: + model_management.unload_model() + + virtual_memory_dict = getattr(model, 'virtual_memory_dict', None) + if isinstance(virtual_memory_dict, dict): + # Already in virtual memory. + return + + model_file = getattr(model, 'model_file', None) + assert isinstance(model_file, dict) + + filename = model_file['filename'] + prefix = model_file['prefix'] + + safetensors_keys = only_load_safetensors_keys(filename) + + if safetensors_keys is None: + print(f'[Virtual Memory System] Error: The Virtual Memory System currently only support safetensors models!') + return + + sd = model.state_dict() + original_device = list(sd.values())[0].device.type + model_file['original_device'] = original_device + + virtual_memory_dict = {} + + for k, v in sd.items(): + current_key = k + current_flag = None + if prefix == 'refiner_clip': + current_key_in_safetensors = k + + for a, b in textenc_conversion_lst: + current_key_in_safetensors = current_key_in_safetensors.replace(b, a) + + current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.transformer.text_model.encoder.layers.', 'conditioner.embedders.0.model.transformer.resblocks.') + current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.text_projection', 'conditioner.embedders.0.model.text_projection') + current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.logit_scale', 'conditioner.embedders.0.model.logit_scale') + current_key_in_safetensors = current_key_in_safetensors.replace('clip_g.', 'conditioner.embedders.0.model.') + + for e in ["weight", "bias"]: + for i, k in enumerate(['q', 'k', 'v']): + e_flag = f'.{k}_proj.{e}' + if current_key_in_safetensors.endswith(e_flag): + current_key_in_safetensors = current_key_in_safetensors[:-len(e_flag)] + f'.in_proj_{e}' + current_flag = (1280 * i, 1280 * (i + 1)) + else: + current_key_in_safetensors = prefix + '.' + k + current_device = torch.device(index=v.device.index, type=v.device.type) + if current_key_in_safetensors in safetensors_keys: + virtual_memory_dict[current_key] = (current_key_in_safetensors, current_device, current_flag) + recursive_del(model, current_key) + else: + # print(f'[Virtual Memory System] Missed key: {current_key}') + pass + + del sd + gc.collect() + model_management.soft_empty_cache() + + model.virtual_memory_dict = virtual_memory_dict + print(f'[Virtual Memory System] {prefix} released from {original_device}: {filename}') + return + + +@torch.no_grad() +def load_from_virtual_memory(model): + virtual_memory_dict = getattr(model, 'virtual_memory_dict', None) + if not isinstance(virtual_memory_dict, dict): + # Not in virtual memory. + return + + model_file = getattr(model, 'model_file', None) + assert isinstance(model_file, dict) + + filename = model_file['filename'] + prefix = model_file['prefix'] + original_device = model_file['original_device'] + + with safe_open(filename, framework="pt", device=original_device) as f: + for current_key, (current_key_in_safetensors, current_device, current_flag) in virtual_memory_dict.items(): + tensor = f.get_tensor(current_key_in_safetensors).to(current_device) + if isinstance(current_flag, tuple) and len(current_flag) == 2: + a, b = current_flag + tensor = tensor[a:b] + parameter = torch.nn.Parameter(tensor, requires_grad=False) + recursive_set(model, current_key, parameter) + + print(f'[Virtual Memory System] {prefix} loaded to {original_device}: {filename}') + del model.virtual_memory_dict + return + + +@torch.no_grad() +def try_move_to_virtual_memory(model, comfy_unload=True): + if not global_virtual_memory_activated: + return + + import modules.default_pipeline as pipeline + + if pipeline.xl_refiner is None: + # If users do not use refiner, no need to use this. + return + + move_to_virtual_memory(model, comfy_unload) diff --git a/update_log.md b/update_log.md index 1c96796f..25b14efe 100644 --- a/update_log.md +++ b/update_log.md @@ -1,3 +1,9 @@ +### 2.0.16 + +* Virtual memory system implemented. Now Colab can run both base model and refiner model with 7.8GB RAM + 5.3GB VRAM, and it never crashes. +* If you are lucky enough to read this line, keep in mind that ComfyUI cannot do this. This is very reasonable that Fooocus is more optimized because it only need to handle a fixed pipeline, but ComfyUI need to consider arbitrary pipelines. +* But if we just consider the optimization of this fixed workload, after 2.0.16, Fooocus has become the most optimized SDXL app, outperforming ComfyUI. + ### 2.0.0 * V2 released.