fix lowvram (#562)

This commit is contained in:
lllyasviel 2023-10-08 00:50:25 -07:00 committed by GitHub
parent 9506714985
commit 6c03faf568
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 10 additions and 2 deletions

View File

@ -213,6 +213,7 @@ def patch_model(model, ip_tasks):
org_dtype = n.dtype
current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
cond_or_uncond = extra_options['cond_or_uncond']
batch_size = int(context_attn2.shape[0])
with torch.autocast("cuda", dtype=ip_adapter.dtype):
q = n
@ -223,9 +224,16 @@ def patch_model(model, ip_tasks):
for cn_img, cn_stop, cn_weight, cache in tasks:
if current_step < cn_stop:
ip_k, ip_v = None, None
if ip_index in cache:
ip_k, ip_v = cache[ip_index]
else:
if int(ip_k.shape[0]) != batch_size:
ip_k = None
if int(ip_v.shape[0]) != batch_size:
ip_v = None
if ip_k is None or ip_v is None:
ip_model_k.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
ip_model_v.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
cond = cn_img.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)

View File

@ -1 +1 @@
version = '2.1.1'
version = '2.1.2'