diff --git a/fooocus_extras/ip_adapter.py b/fooocus_extras/ip_adapter.py
index 28d04e45..453afa32 100644
--- a/fooocus_extras/ip_adapter.py
+++ b/fooocus_extras/ip_adapter.py
@@ -115,10 +115,11 @@ ip_negative: torch.Tensor = None
 image_proj_model: ModelPatcher = None
 ip_layers: ModelPatcher = None
 ip_adapter: IPAdapterModel = None
+ip_unconds = None


 def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
-    global clip_vision, image_proj_model, ip_layers, ip_negative, ip_adapter
+    global clip_vision, image_proj_model, ip_layers, ip_negative, ip_adapter, ip_unconds

     if clip_vision_path is None:
         return
@@ -168,14 +169,17 @@ def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
     ip_layers = ModelPatcher(model=ip_adapter.ip_layers, load_device=load_device,
                              offload_device=offload_device)

+    ip_unconds = None
     return


 @torch.no_grad()
 @torch.inference_mode()
 def preprocess(img):
+    global ip_unconds
+
     inputs = clip_vision.processor(images=img, return_tensors="pt")
-    comfy.model_management.load_models_gpu([clip_vision.patcher, image_proj_model])
+    comfy.model_management.load_models_gpu([clip_vision.patcher, image_proj_model, ip_layers])
     pixel_values = inputs['pixel_values'].to(clip_vision.load_device)

     if clip_vision.dtype != torch.float32:
@@ -191,9 +195,13 @@ def preprocess(img):
     else:
         cond = outputs.image_embeds.to(ip_adapter.dtype)

-    outputs = image_proj_model.model(cond)
+    if ip_unconds is None:
+        uncond = ip_negative.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+        ip_unconds = [m(uncond).cpu() for m in ip_layers.model.to_kvs]

-    return outputs
+    cond = image_proj_model.model(cond).to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+    ip_conds = [m(cond).cpu() for m in ip_layers.model.to_kvs]
+    return ip_conds


 @torch.no_grad()
@@ -206,65 +214,44 @@ def patch_model(model, ip_tasks):
         tasks.append((cn_img, cn_stop, cn_weight, {}))

     def make_attn_patcher(ip_index):
-        ip_model_k = ip_layers.model.to_kvs[ip_index * 2]
-        ip_model_v = ip_layers.model.to_kvs[ip_index * 2 + 1]
-
         def patcher(n, context_attn2, value_attn2, extra_options):
             org_dtype = n.dtype
             current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
             cond_or_uncond = extra_options['cond_or_uncond']
-            batch_size = int(context_attn2.shape[0])

             with torch.autocast("cuda", dtype=ip_adapter.dtype):
                 q = n
                 k = [context_attn2]
                 v = [value_attn2]
                 b, _, _ = q.shape
-                batch_prompt = b // len(cond_or_uncond)

-                for cn_img, cn_stop, cn_weight, cache in tasks:
+                for ip_conds, cn_stop, cn_weight, cache in tasks:
                     if current_step < cn_stop:
-                        ip_k, ip_v = None, None
+                        ip_k_c = ip_conds[ip_index * 2].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_v_c = ip_conds[ip_index * 2 + 1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_k_uc = ip_unconds[ip_index * 2].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
+                        ip_v_uc = ip_unconds[ip_index * 2 + 1].to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)

-                        if ip_index in cache:
-                            ip_k, ip_v = cache[ip_index]
-                            if int(ip_k.shape[0]) != batch_size:
-                                ip_k = None
-                            if int(ip_v.shape[0]) != batch_size:
-                                ip_v = None
+                        ip_k = torch.cat([(ip_k_c, ip_k_uc)[i] for i in cond_or_uncond], dim=0)
+                        ip_v = torch.cat([(ip_v_c, ip_v_uc)[i] for i in cond_or_uncond], dim=0)

-                        if ip_k is None or ip_v is None:
-                            ip_model_k.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
-                            ip_model_v.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
-                            cond = cn_img.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)
-                            uncond = ip_negative.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)
-                            uncond_cond = torch.cat([(cond, uncond)[i] for i in cond_or_uncond], dim=0)
-                            ip_k = ip_model_k(uncond_cond)
-                            ip_v = ip_model_v(uncond_cond)
+                        # Midjourney's attention formulation of image prompt (non-official reimplementation)
+                        # Written by Lvmin Zhang at Stanford University, 2023 Dec
+                        # For non-commercial use only - if you use this in commercial project then
+                        # probably it has some intellectual property issues.
+                        # Contact lvminzhang@acm.org if you are not sure.

-                            # Midjourney's attention formulation of image prompt (non-official reimplementation)
-                            # Written by Lvmin Zhang at Stanford University, 2023 Dec
-                            # For non-commercial use only - if you use this in commercial project then
-                            # probably it has some intellectual property issues.
-                            # Contact lvminzhang@acm.org if you are not sure.
+                        # Below is the sensitive part with potential intellectual property issues.

-                            # Below is the sensitive part with potential intellectual property issues.
+                        ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
+                        ip_v_offset = ip_v - ip_v_mean

-                            ip_v_mean = torch.mean(ip_v, dim=1, keepdim=True)
-                            ip_v_offset = ip_v - ip_v_mean
+                        B, F, C = ip_k.shape
+                        channel_penalty = float(C) / 1280.0
+                        weight = cn_weight * channel_penalty

-                            B, F, C = ip_k.shape
-                            channel_penalty = float(C) / 1280.0
-                            weight = cn_weight * channel_penalty
-
-                            ip_k = ip_k * weight
-                            ip_v = ip_v_offset + ip_v_mean * weight
-
-                            # The sensitive part ends here.
-
-                            cache[ip_index] = ip_k, ip_v
-                            ip_model_k.to(device=ip_adapter.offload_device, dtype=ip_adapter.dtype)
-                            ip_model_v.to(device=ip_adapter.offload_device, dtype=ip_adapter.dtype)
+                        ip_k = ip_k * weight
+                        ip_v = ip_v_offset + ip_v_mean * weight

                         k.append(ip_k)
                         v.append(ip_v)
diff --git a/fooocus_version.py b/fooocus_version.py
index c07bdebc..7b8301ff 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.2'
+version = '2.1.3'
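In short, the ip_adapter.py change moves the to_kv projections out of the per-step attention patcher: preprocess() now pushes the projected image embedding through every to_kv layer once (caching the negative-embedding counterparts in the module-level ip_unconds), and patcher() only concatenates the cached cond/uncond tensors in the order given by cond_or_uncond. A minimal standalone sketch of that pattern, with hypothetical helper names rather than the module's actual API:

import torch

def precompute_ip_kv(image_embed, negative_embed, to_kvs, device, dtype):
    # Run the positive and negative embeddings through every to_k/to_v
    # projection once, then park the results on CPU until sampling.
    cond = image_embed.to(device=device, dtype=dtype)
    uncond = negative_embed.to(device=device, dtype=dtype)
    ip_conds = [m(cond).cpu() for m in to_kvs]
    ip_unconds = [m(uncond).cpu() for m in to_kvs]
    return ip_conds, ip_unconds

def gather_ip_kv(ip_conds, ip_unconds, ip_index, cond_or_uncond, device, dtype):
    # At each sampling step, move only this layer's cached tensors back to the
    # compute device and interleave cond/uncond in the order the sampler asks for.
    ip_k_c = ip_conds[ip_index * 2].to(device=device, dtype=dtype)
    ip_v_c = ip_conds[ip_index * 2 + 1].to(device=device, dtype=dtype)
    ip_k_uc = ip_unconds[ip_index * 2].to(device=device, dtype=dtype)
    ip_v_uc = ip_unconds[ip_index * 2 + 1].to(device=device, dtype=dtype)
    ip_k = torch.cat([(ip_k_c, ip_k_uc)[i] for i in cond_or_uncond], dim=0)
    ip_v = torch.cat([(ip_v_c, ip_v_uc)[i] for i in cond_or_uncond], dim=0)
    return ip_k, ip_v

The trade-off mirrors the diff: the projections run once per image instead of once per step, at the cost of keeping the projected tensors resident on CPU between steps.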