try fix lora vram (#357)

* try fix lora vram

* try fix lora vram
This commit is contained in:
lllyasviel 2023-09-13 02:29:43 -07:00 committed by GitHub
parent eccf32b78c
commit a9b7219604
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 12 additions and 8 deletions

View File

@ -1 +1 @@
version = '2.0.0'
version = '2.0.1'

View File

@ -8,14 +8,13 @@ import comfy.model_management
import comfy.utils
from comfy.sd import load_checkpoint_guess_config
from nodes import VAEDecode, EmptyLatentImage, CLIPTextEncode
from nodes import VAEDecode, EmptyLatentImage
from comfy.sample import prepare_mask, broadcast_cond, load_additional_models, cleanup_additional_models
from modules.samplers_advanced import KSampler, KSamplerWithRefiner
from modules.patch import patch_all
patch_all()
opCLIPTextEncode = CLIPTextEncode()
opEmptyLatentImage = EmptyLatentImage()
opVAEDecode = VAEDecode()
@ -52,11 +51,6 @@ def load_lora(model, lora_filename, strength_model=1.0, strength_clip=1.0):
return StableDiffusionModel(unet=unet, clip=clip, vae=model.vae, clip_vision=model.clip_vision)
@torch.no_grad()
def encode_prompt_condition(clip, prompt):
return opCLIPTextEncode.encode(clip=clip, text=prompt)[0]
@torch.no_grad()
def generate_empty_latent(width=1024, height=1024, batch_size=1):
return opEmptyLatentImage.generate(width=width, height=height, batch_size=batch_size)[0]

View File

@ -106,6 +106,7 @@ refresh_loras([(modules.path.default_lora_name, 0.5), ('None', 0.5), ('None', 0.
expansion = FooocusExpansion()
@torch.no_grad()
def clip_encode_single(clip, text, verbose=False):
cached = clip.fcs_cond_cache.get(text, None)
if cached is not None:
@ -120,6 +121,7 @@ def clip_encode_single(clip, text, verbose=False):
return result
@torch.no_grad()
def clip_encode(sd, texts, pool_top_k=1):
if sd is None:
return None
@ -143,6 +145,7 @@ def clip_encode(sd, texts, pool_top_k=1):
return [[torch.cat(cond_list, dim=1), {"pooled_output": pooled_acc}]]
@torch.no_grad()
def clear_sd_cond_cache(sd):
if sd is None:
return None
@ -152,6 +155,7 @@ def clear_sd_cond_cache(sd):
return
@torch.no_grad()
def clear_all_caches():
clear_sd_cond_cache(xl_base_patched)
clear_sd_cond_cache(xl_refiner)

View File

@ -3,6 +3,7 @@ import comfy.model_base
import comfy.ldm.modules.diffusionmodules.openaimodel
import comfy.samplers
import comfy.k_diffusion.external
import comfy.model_management
import modules.anisotropic as anisotropic
from comfy.k_diffusion import utils
@ -68,6 +69,11 @@ def sdxl_encode_adm_patched(self, **kwargs):
return torch.cat((clip_pooled.to(flat.device), flat), dim=1)
def text_encoder_device_patched():
return torch.device("cpu")
def patch_all():
# comfy.model_management.text_encoder_device = text_encoder_device_patched
comfy.k_diffusion.external.DiscreteEpsDDPMDenoiser.forward = patched_discrete_eps_ddpm_denoiser_forward
comfy.model_base.SDXL.encode_adm = sdxl_encode_adm_patched