fix lowvram (#562)
parent 9506714985
commit 6c03faf568
@@ -213,6 +213,7 @@ def patch_model(model, ip_tasks):
             org_dtype = n.dtype
             current_step = float(model.model.diffusion_model.current_step.detach().cpu().numpy()[0])
             cond_or_uncond = extra_options['cond_or_uncond']
+            batch_size = int(context_attn2.shape[0])

             with torch.autocast("cuda", dtype=ip_adapter.dtype):
                 q = n
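Note on the hunk above: the one added line, batch_size, records the leading dimension of the incoming cross-attention context; the next hunk uses it to detect stale cache entries. The surrounding context lines also show the precision pattern the patch runs under: the input dtype is saved as org_dtype (presumably so the result can be cast back after the autocast block), and the IP-adapter math executes inside torch.autocast at the adapter's precision. A minimal, self-contained sketch of that dtype round-trip (CPU and bfloat16 are chosen here only so it runs anywhere; the variable names are illustrative, not the project's):

import torch

x = torch.randn(2, 4, 8)                # stands in for the attention input n
org_dtype = x.dtype                     # remember the caller's precision
with torch.autocast('cpu', dtype=torch.bfloat16):
    y = x @ x.transpose(-1, -2)         # autocast runs this matmul in bfloat16
assert y.dtype == torch.bfloat16
out = y.to(org_dtype)                   # cast back to the caller's precision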
@@ -223,9 +224,16 @@
                 for cn_img, cn_stop, cn_weight, cache in tasks:
                     if current_step < cn_stop:
+                        ip_k, ip_v = None, None
+
                         if ip_index in cache:
                             ip_k, ip_v = cache[ip_index]
-                        else:
+                            if int(ip_k.shape[0]) != batch_size:
+                                ip_k = None
+                            if int(ip_v.shape[0]) != batch_size:
+                                ip_v = None
+
+                        if ip_k is None or ip_v is None:
                             ip_model_k.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
                             ip_model_v.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype)
                             cond = cn_img.to(device=ip_adapter.load_device, dtype=ip_adapter.dtype).repeat(batch_prompt, 1, 1)
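This hunk is the substance of the lowvram fix: a cached per-layer (ip_k, ip_v) pair is no longer reused unconditionally. Both start as None, a cache hit is validated against the current batch size, and anything stale falls through to the recompute path, which also moves the projection layers and the image embedding back onto the load device. A self-contained sketch of that invalidate-or-recompute pattern, with hypothetical names (get_ip_kv, compute_kv) standing in for the project's internals:

import torch

def get_ip_kv(cache, ip_index, batch_size, compute_kv):
    ip_k, ip_v = None, None
    if ip_index in cache:
        ip_k, ip_v = cache[ip_index]
        # In lowvram mode the cond/uncond batch layout can change between
        # calls, so a cached tensor whose leading dimension no longer
        # matches must be discarded instead of reused.
        if int(ip_k.shape[0]) != batch_size:
            ip_k = None
        if int(ip_v.shape[0]) != batch_size:
            ip_v = None
    if ip_k is None or ip_v is None:
        ip_k, ip_v = compute_kv(batch_size)   # recompute at the current size
        cache[ip_index] = (ip_k, ip_v)
    return ip_k, ip_v

compute_kv = lambda b: (torch.zeros(b, 16, 64), torch.zeros(b, 16, 64))
cache = {}
get_ip_kv(cache, 0, 2, compute_kv)   # miss: computed and cached
get_ip_kv(cache, 0, 2, compute_kv)   # hit: batch still matches, reused
get_ip_kv(cache, 0, 4, compute_kv)   # stale: batch changed, recomputed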
@@ -1 +1 @@
-version = '2.1.1'
+version = '2.1.2'