From 6c03faf568c17e214c074e6478084daeac745bc2 Mon Sep 17 00:00:00 2001
From: lllyasviel
Date: Sun, 8 Oct 2023 00:50:25 -0700
Subject: [PATCH] fix lowvram (#562)

---
 fooocus_extras/ip_adapter.py | 10 +++++++++-
 fooocus_version.py           |  2 +-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/fooocus_extras/ip_adapter.py b/fooocus_extras/ip_adapter.py
index 3fdc56b3..28d04e45 100644
--- a/fooocus_extras/ip_adapter.py
+++ b/fooocus_extras/ip_adapter.py
@@ -213,6 +213,7 @@ def patch_model(model, ip_tasks):
             org_dtype = n.dtype
             current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
             cond_or_uncond = extra_options['cond_or_uncond']
+            batch_size = int(context_attn2.shape[0])
 
             with torch.autocast("cuda", dtype=ip_adapter.dtype):
                 q = n
@@ -223,9 +224,16 @@ def patch_model(model, ip_tasks):
 
                 for cn_img, cn_stop, cn_weight, cache in tasks:
                     if current_step < cn_stop:
+                        ip_k, ip_v = None, None
+
                         if ip_index in cache:
                             ip_k, ip_v = cache[ip_index]
-                        else:
+                            if int(ip_k.shape[0]) != batch_size:
+                                ip_k = None
+                            if int(ip_v.shape[0]) != batch_size:
+                                ip_v = None
+
+                        if ip_k is None or ip_v is None:
                             ip_model_k.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
                             ip_model_v.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
                             cond = cn_img.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)
diff --git a/fooocus_version.py b/fooocus_version.py
index c9e52937..c07bdebc 100644
--- a/fooocus_version.py
+++ b/fooocus_version.py
@@ -1 +1 @@
-version = '2.1.1'
+version = '2.1.2'
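
For context, the substance of the fix is a cache-invalidation check: the cached IP-Adapter key/value tensors are reused only when their batch dimension still matches the current attention batch, presumably because lowvram execution can run cond and uncond as separate, smaller batches. Below is a minimal sketch of that pattern under those assumptions; get_cached_kv, compute_kv, and make_kv are hypothetical names for illustration, not identifiers from the patch.

import torch

def get_cached_kv(cache, ip_index, batch_size, compute_kv):
    # Reuse cached (ip_k, ip_v) only if their batch dimension still matches;
    # otherwise recompute and refresh the cache. compute_kv is a hypothetical
    # stand-in for the projection through ip_model_k / ip_model_v above.
    ip_k, ip_v = cache.get(ip_index, (None, None))

    # Drop stale entries: a tensor cached for one batch size must not be
    # reused when the batch shrinks or grows between calls.
    if ip_k is not None and int(ip_k.shape[0]) != batch_size:
        ip_k = None
    if ip_v is not None and int(ip_v.shape[0]) != batch_size:
        ip_v = None

    if ip_k is None or ip_v is None:
        ip_k, ip_v = compute_kv(batch_size)
        cache[ip_index] = (ip_k, ip_v)

    return ip_k, ip_v

# Example: an entry built for batch 2 is rebuilt when batch 1 is requested.
cache = {}
make_kv = lambda b: (torch.zeros(b, 77, 64), torch.zeros(b, 77, 64))
get_cached_kv(cache, 0, 2, make_kv)               # fills the cache for batch 2
ip_k, ip_v = get_cached_kv(cache, 0, 1, make_kv)  # mismatch -> recompute
assert int(ip_k.shape[0]) == 1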