parent 6c03faf568
commit 51cdc5e53a
@@ -115,10 +115,11 @@ ip_negative: torch.Tensor = None
 image_proj_model: ModelPatcher = None
 ip_layers: ModelPatcher = None
 ip_adapter: IPAdapterModel = None
+ip_unconds = None


 def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
-    global clip_vision, image_proj_model, ip_layers, ip_negative, ip_adapter
+    global clip_vision, image_proj_model, ip_layers, ip_negative, ip_adapter, ip_unconds

     if clip_vision_path is None:
         return
@@ -168,14 +169,17 @@ def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
     ip_layers = ModelPatcher(model=ip_adapter.ip_layers, load_device=load_device,
                              offload_device=offload_device)

+    ip_unconds = None
     return


 @torch.no_grad()
 @torch.inference_mode()
 def preprocess(img):
+    global ip_unconds
+
     inputs = clip_vision.processor(images=img, return_tensors="pt")
-    comfy.model_management.load_models_gpu([clip_vision.patcher, image_proj_model])
+    comfy.model_management.load_models_gpu([clip_vision.patcher, image_proj_model, ip_layers])
     pixel_values = inputs['pixel_values'].to(clip_vision.load_device)

     if clip_vision.dtype != torch.float32:
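
Note on the hunk above: resetting ip_unconds inside load_ip_adapter() invalidates the cached negative projections whenever a different adapter checkpoint is loaded, and adding ip_layers to load_models_gpu() keeps the to_kv projection weights resident before preprocess() runs them. A minimal sketch of the invalidate-on-reload pattern (hypothetical simplified signature; the real loader also rebuilds clip_vision and image_proj_model):

    ip_unconds = None  # cached negative (uncond) k/v projections

    def load_ip_adapter(ip_adapter_path):
        global ip_adapter, ip_layers, ip_unconds
        # ... rebuild ip_adapter / ip_layers for the new checkpoint ...
        ip_unconds = None  # cache belongs to the old weights; recompute lazily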
@@ -191,9 +195,13 @@ def preprocess(img):
     else:
         cond = outputs.image_embeds.to(ip_adapter.dtype)

-    outputs = image_proj_model.model(cond)
+    if ip_unconds is None:
+        uncond = ip_negative.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+        ip_unconds = [m(uncond).cpu() for m in ip_layers.model.to_kvs]

-    return outputs
+    cond = image_proj_model.model(cond).to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+    ip_conds = [m(cond).cpu() for m in ip_layers.model.to_kvs]
+    return ip_conds


 @torch.no_grad()
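
With this hunk, preprocess() no longer returns the raw image_proj_model embedding; it runs every to_kv linear layer once and returns the projected k/v list, computing the negative-embedding (uncond) projections only on the first call. A minimal sketch of the compute-once pattern, with a hypothetical helper whose names are taken from this diff (to_kvs is the list of projection modules):

    import torch

    ip_unconds = None  # cleared by load_ip_adapter()

    @torch.no_grad()
    def project(cond_embed, to_kvs, ip_negative, device, dtype):
        global ip_unconds
        if ip_unconds is None:
            # negative projections are image-independent: compute once, park on CPU
            uncond = ip_negative.to(device=device, dtype=dtype)
            ip_unconds = [m(uncond).cpu() for m in to_kvs]
        # cond projections depend on the input image, so they are recomputed
        cond = cond_embed.to(device=device, dtype=dtype)
        return [m(cond).cpu() for m in to_kvs]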
@@ -206,65 +214,44 @@ def patch_model(model, ip_tasks):
         tasks.append((cn_img, cn_stop, cn_weight, {}))

     def make_attn_patcher(ip_index):
-        ip_model_k = ip_layers.model.to_kvs[ip_index * 2]
-        ip_model_v = ip_layers.model.to_kvs[ip_index * 2 + 1]
-
         def patcher(n, context_attn2, value_attn2, extra_options):
             org_dtype = n.dtype
             current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
             cond_or_uncond = extra_options['cond_or_uncond']
-            batch_size = int(context_attn2.shape[0])

             with torch.autocast("cuda", dtype=ip_adapter.dtype):
                 q = n
                 k = [context_attn2]
                 v = [value_attn2]
                 b, _, _ = q.shape
                 batch_prompt = b // len(cond_or_uncond)

-                for cn_img, cn_stop, cn_weight, cache in tasks:
+                for ip_conds, cn_stop, cn_weight, cache in tasks:
                     if current_step < cn_stop:
-                        ip_k, ip_v = None, None
+                        ip_k_c = ip_conds[ip_index * 2].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_v_c = ip_conds[ip_index * 2 + 1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_k_uc = ip_unconds[ip_index * 2].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_v_uc = ip_unconds[ip_index * 2 + 1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)

-                        if ip_index in cache:
-                            ip_k, ip_v = cache[ip_index]
-                            if int(ip_k.shape[0]) != batch_size:
-                                ip_k = None
-                            if int(ip_v.shape[0]) != batch_size:
-                                ip_v = None
+                        ip_k = torch.cat([(ip_k_c, ip_k_uc)[i] for i in cond_or_uncond], dim=0)
+                        ip_v = torch.cat([(ip_v_c, ip_v_uc)[i] for i in cond_or_uncond], dim=0)

-                        if ip_k is None or ip_v is None:
-                            ip_model_k.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
-                            ip_model_v.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
-                            cond = cn_img.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)
-                            uncond = ip_negative.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)
-                            uncond_cond = torch.cat([(cond, uncond)[i] for i in cond_or_uncond], dim=0)
-                            ip_k = ip_model_k(uncond_cond)
-                            ip_v = ip_model_v(uncond_cond)
-
-                            # Midjourney's attention formulation of image prompt (non-official reimplementation)
-                            # Written by Lvmin Zhang at Stanford University, 2023 Dec
-                            # For non-commercial use only - if you use this in commercial project then
-                            # probably it has some intellectual property issues.
-                            # Contact lvminzhang@acm.org if you are not sure.
+                        # Midjourney's attention formulation of image prompt (non-official reimplementation)
+                        # Written by Lvmin Zhang at Stanford University, 2023 Dec
+                        # For non-commercial use only - if you use this in commercial project then
+                        # probably it has some intellectual property issues.
+                        # Contact lvminzhang@acm.org if you are not sure.

-                            # Below is the sensitive part with potential intellectual property issues.
+                        # Below is the sensitive part with potential intellectual property issues.

-                            ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
-                            ip_v_offset = ip_v - ip_v_mean
+                        ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
+                        ip_v_offset = ip_v - ip_v_mean

-                            B, F, C = ip_k.shape
-                            channel_penalty = float(C) / 1280.0
-                            weight = cn_weight * channel_penalty
+                        B, F, C = ip_k.shape
+                        channel_penalty = float(C) / 1280.0
+                        weight = cn_weight * channel_penalty

-                            ip_k = ip_k * weight
-                            ip_v = ip_v_offset + ip_v_mean * weight
-
-                            # The sensitive part ends here.
-
-                            cache[ip_index] = ip_k, ip_v
-                            ip_model_k.to(device=ip_adapter.offload_device, dtype=ip_adapter.dtype)
-                            ip_model_v.to(device=ip_adapter.offload_device, dtype=ip_adapter.dtype)
+                        ip_k = ip_k * weight
+                        ip_v = ip_v_offset + ip_v_mean * weight

                         k.append(ip_k)
                         v.append(ip_v)
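
The torch.cat selection in this hunk relies on extra_options['cond_or_uncond'] describing the batch layout as a list of 0 (conditional chunk) and 1 (unconditional chunk); indexing the (cond, uncond) pair with each entry assembles the cached tensors in the same order the sampler batched the prompts. A toy illustration (shapes invented for the example):

    import torch

    ip_k_c = torch.ones(2, 4, 8)    # cached "cond" keys for a batch_prompt of 2
    ip_k_uc = torch.zeros(2, 4, 8)  # cached "uncond" keys
    cond_or_uncond = [0, 1]         # sampler batched cond first, then uncond

    ip_k = torch.cat([(ip_k_c, ip_k_uc)[i] for i in cond_or_uncond], dim=0)
    assert ip_k.shape == (4, 4, 8)
    assert ip_k[:2].eq(1).all().item() and ip_k[2:].eq(0).all().item()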
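
The reweighting math itself is unchanged by this commit, only re-indented: weight scales the keys wholesale, while for the values only the per-token mean is scaled and the offsets from that mean are preserved, so lowering cn_weight fades the image prompt's pull without flattening the value distribution; channel_penalty normalizes weight so that C = 1280 is the neutral width. A standalone restatement with toy shapes:

    import torch

    ip_k = torch.randn(1, 4, 1280)
    ip_v = torch.randn(1, 4, 1280)
    cn_weight = 0.5

    B, F, C = ip_k.shape
    weight = cn_weight * (float(C) / 1280.0)  # channel_penalty == 1.0 at C == 1280

    ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
    ip_k = ip_k * weight                            # keys scaled wholesale
    ip_v = (ip_v - ip_v_mean) + ip_v_mean * weight  # offsets kept, mean scaled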
@@ -1 +1 @@
-version = '2.1.2'
+version = '2.1.3'