fix lowvram (#563)

* fix lowvram

* fix lowvram
Author: lllyasviel (committed by GitHub)
Date:   2023-10-08 01:09:10 -07:00
Commit: 51cdc5e53a
Parent: 6c03faf568
2 changed files with 33 additions and 46 deletions


@@ -115,10 +115,11 @@ ip_negative: torch.Tensor = None
 image_proj_model: ModelPatcher = None
 ip_layers: ModelPatcher = None
 ip_adapter: IPAdapterModel = None
+ip_unconds = None


 def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
-    global clip_vision, image_proj_model, ip_layers, ip_negative, ip_adapter
+    global clip_vision, image_proj_model, ip_layers, ip_negative, ip_adapter, ip_unconds

     if clip_vision_path is None:
         return
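The hunk above introduces a module-level ip_unconds cache that later hunks fill lazily and reset on reload. A minimal sketch of that pattern, for context (the function names and the compute callback are hypothetical stand-ins, not the real Fooocus API):

ip_unconds = None                      # filled on first use

def load_adapter_sketch():             # hypothetical stand-in for load_ip_adapter
    global ip_unconds
    ip_unconds = None                  # new weights -> cached projections are stale

def get_unconds_sketch(compute):       # hypothetical; 'compute' builds the tensors
    global ip_unconds
    if ip_unconds is None:
        ip_unconds = compute()         # computed once, reused every sampling step
    return ip_unconds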
@@ -168,14 +169,17 @@ def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
     ip_layers = ModelPatcher(model=ip_adapter.ip_layers, load_device=load_device,
                              offload_device=offload_device)
+
+    ip_unconds = None
     return


 @torch.no_grad()
 @torch.inference_mode()
 def preprocess(img):
+    global ip_unconds
     inputs = clip_vision.processor(images=img, return_tensors="pt")
-    comfy.model_management.load_models_gpu([clip_vision.patcher, image_proj_model])
+    comfy.model_management.load_models_gpu([clip_vision.patcher, image_proj_model, ip_layers])
     pixel_values = inputs['pixel_values'].to(clip_vision.load_device)

     if clip_vision.dtype != torch.float32:
@@ -191,9 +195,13 @@ def preprocess(img):
     else:
         cond = outputs.image_embeds.to(ip_adapter.dtype)

-    outputs = image_proj_model.model(cond)
-    return outputs
+    cond = image_proj_model.model(cond).to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+    ip_conds = [m(cond).cpu() for m in ip_layers.model.to_kvs]
+    if ip_unconds is None:
+        uncond = ip_negative.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+        ip_unconds = [m(uncond).cpu() for m in ip_layers.model.to_kvs]
+    return ip_conds


 @torch.no_grad()
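This hunk is the core of the low-VRAM fix: instead of keeping the IP-Adapter's to_k/to_v projection layers resident on the GPU for the whole sampling loop, preprocess now runs every projection once and parks the resulting activations on the CPU. A minimal, self-contained sketch of the pattern, with hypothetical layer sizes standing in for ip_layers.model.to_kvs:

import torch

# Hypothetical stand-ins for the IP-Adapter's to_k/to_v projections.
to_kvs = [torch.nn.Linear(768, 1280, bias=False) for _ in range(4)]
image_embed = torch.randn(1, 4, 768)   # hypothetical projected image embedding

# Project once, then offload: only small activation tensors are kept around,
# not the projection weights themselves.
ip_conds = [m(image_embed).cpu() for m in to_kvs]

# Per sampling step, move just the needed pair back to the compute device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
ip_k = ip_conds[0].to(device)          # .to(ip_adapter.load_device) in the diff
ip_v = ip_conds[1].to(device)

The trade-off is memory for a small per-step host-to-device copy of the cached tensors, which is cheap compared to holding the projection layers on the GPU.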
@@ -206,65 +214,44 @@ def patch_model(model, ip_tasks):
         tasks.append((cn_img, cn_stop, cn_weight, {}))

     def make_attn_patcher(ip_index):
-        ip_model_k = ip_layers.model.to_kvs[ip_index * 2]
-        ip_model_v = ip_layers.model.to_kvs[ip_index * 2 + 1]
-
         def patcher(n, context_attn2, value_attn2, extra_options):
             org_dtype = n.dtype
             current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
             cond_or_uncond = extra_options['cond_or_uncond']
-            batch_size = int(context_attn2.shape[0])

             with torch.autocast("cuda", dtype=ip_adapter.dtype):
                 q = n
                 k = [context_attn2]
                 v = [value_attn2]
                 b, _, _ = q.shape
-                batch_prompt = b // len(cond_or_uncond)

-                for cn_img, cn_stop, cn_weight, cache in tasks:
+                for ip_conds, cn_stop, cn_weight, cache in tasks:
                     if current_step < cn_stop:
-                        ip_k, ip_v = None, None
+                        ip_k_c = ip_conds[ip_index * 2].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_v_c = ip_conds[ip_index * 2 + 1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_k_uc = ip_unconds[ip_index * 2].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_v_uc = ip_unconds[ip_index * 2 + 1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)

-                        if ip_index in cache:
-                            ip_k, ip_v = cache[ip_index]
-                            if int(ip_k.shape[0]) != batch_size:
-                                ip_k = None
-                            if int(ip_v.shape[0]) != batch_size:
-                                ip_v = None
+                        ip_k = torch.cat([(ip_k_c, ip_k_uc)[i] for i in cond_or_uncond], dim=0)
+                        ip_v = torch.cat([(ip_v_c, ip_v_uc)[i] for i in cond_or_uncond], dim=0)

-                        if ip_k is None or ip_v is None:
-                            ip_model_k.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
-                            ip_model_v.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
-                            cond = cn_img.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)
-                            uncond = ip_negative.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)
-                            uncond_cond = torch.cat([(cond, uncond)[i] for i in cond_or_uncond], dim=0)
-                            ip_k = ip_model_k(uncond_cond)
-                            ip_v = ip_model_v(uncond_cond)

-                            # Midjourney's attention formulation of image prompt (non-official reimplementation)
-                            # Written by Lvmin Zhang at Stanford University, 2023 Dec
-                            # For non-commercial use only - if you use this in commercial project then
-                            # probably it has some intellectual property issues.
-                            # Contact lvminzhang@acm.org if you are not sure.
+                        # Midjourney's attention formulation of image prompt (non-official reimplementation)
+                        # Written by Lvmin Zhang at Stanford University, 2023 Dec
+                        # For non-commercial use only - if you use this in commercial project then
+                        # probably it has some intellectual property issues.
+                        # Contact lvminzhang@acm.org if you are not sure.

-                            # Below is the sensitive part with potential intellectual property issues.
+                        # Below is the sensitive part with potential intellectual property issues.

-                            ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
-                            ip_v_offset = ip_v - ip_v_mean
+                        ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
+                        ip_v_offset = ip_v - ip_v_mean

-                            B, F, C = ip_k.shape
-                            channel_penalty = float(C) / 1280.0
-                            weight = cn_weight * channel_penalty
+                        B, F, C = ip_k.shape
+                        channel_penalty = float(C) / 1280.0
+                        weight = cn_weight * channel_penalty

-                            ip_k = ip_k * weight
-                            ip_v = ip_v_offset + ip_v_mean * weight
-
-                            # The sensitive part ends here.
-
-                            cache[ip_index] = ip_k, ip_v
-                            ip_model_k.to(device=ip_adapter.offload_device, dtype=ip_adapter.dtype)
-                            ip_model_v.to(device=ip_adapter.offload_device, dtype=ip_adapter.dtype)
+                        ip_k = ip_k * weight
+                        ip_v = ip_v_offset + ip_v_mean * weight

                         k.append(ip_k)
                         v.append(ip_v)
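Two details of the rewritten patcher are worth unpacking. First, cond_or_uncond (from ComfyUI's extra_options) lists, per batch chunk, 0 for conditional and 1 for unconditional (the order depends on the sampler's batching), so torch.cat([(c, uc)[i] for i in cond_or_uncond], dim=0) assembles the classifier-free-guidance batch directly from the cached tensors. Second, the weighting scales the keys and only the mean component of the values, normalized by a channel penalty relative to a 1280-channel baseline. A self-contained sketch with hypothetical shapes and strength:

import torch

# Hypothetical cached projections: 2 rows per chunk, 4 tokens, 640 channels.
ip_k_c, ip_k_uc = torch.randn(2, 4, 640), torch.randn(2, 4, 640)
ip_v_c, ip_v_uc = torch.randn(2, 4, 640), torch.randn(2, 4, 640)
cond_or_uncond = [0, 1]   # example chunk layout: cond chunk, then uncond chunk
cn_weight = 0.6           # hypothetical image-prompt strength

# Assemble the CFG batch: pick cond/uncond per chunk, stack along batch dim.
ip_k = torch.cat([(ip_k_c, ip_k_uc)[i] for i in cond_or_uncond], dim=0)
ip_v = torch.cat([(ip_v_c, ip_v_uc)[i] for i in cond_or_uncond], dim=0)
assert ip_k.shape == (4, 4, 640)

# Weighting as in the diff: values are split into mean + offset so the
# strength scales only the mean component; channel_penalty normalizes
# attention layers of different width against the 1280-channel baseline.
B, F, C = ip_k.shape
weight = cn_weight * (float(C) / 1280.0)
ip_v_mean = ip_v.mean(dim=1, keepdim=True)
ip_k = ip_k * weight
ip_v = (ip_v - ip_v_mean) + ip_v_mean * weight

Because both the conditional and unconditional projections are cached, the old per-batch-size cache dictionary and the on-demand ip_model_k/ip_model_v device shuffling could be deleted outright, which is where the deletion count in this hunk comes from.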


@@ -1 +1 @@
-version = '2.1.2'
+version = '2.1.3'