diff --git a/backend/headless/fcbh/controlnet.py b/backend/headless/fcbh/controlnet.py index a0858399..dcdd0c1f 100644 --- a/backend/headless/fcbh/controlnet.py +++ b/backend/headless/fcbh/controlnet.py @@ -416,7 +416,7 @@ class T2IAdapter(ControlBase): if control_prev is not None: return control_prev else: - return {} + return None if self.cond_hint is None or x_noisy.shape[2] * 8 != self.cond_hint.shape[2] or x_noisy.shape[3] * 8 != self.cond_hint.shape[3]: if self.cond_hint is not None: diff --git a/backend/headless/fcbh/ldm/modules/attention.py b/backend/headless/fcbh/ldm/modules/attention.py index 1a3b9f02..a0af3851 100644 --- a/backend/headless/fcbh/ldm/modules/attention.py +++ b/backend/headless/fcbh/ldm/modules/attention.py @@ -95,9 +95,19 @@ def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) def attention_basic(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + h = heads - scale = (q.shape[-1] // heads) ** -0.5 - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) # force cast to fp32 to avoid overflowing if _ATTN_PRECISION =="fp32": @@ -119,16 +129,24 @@ def attention_basic(q, k, v, heads, mask=None): sim = sim.softmax(dim=-1) out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = ( + out.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) return out def attention_sub_quad(query, key, value, heads, mask=None): - scale = (query.shape[-1] // heads) ** -0.5 - query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1) - key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1) - del key - value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1) + b, _, dim_head = query.shape + dim_head //= heads + + scale = dim_head ** -0.5 + query = query.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) + value = value.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 1, 3).reshape(b * heads, -1, dim_head) + + key = key.unsqueeze(3).reshape(b, -1, heads, dim_head).permute(0, 2, 3, 1).reshape(b * heads, dim_head, -1) dtype = query.dtype upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 @@ -137,7 +155,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): else: bytes_per_token = torch.finfo(query.dtype).bits//8 batch_x_heads, q_tokens, _ = query.shape - _, _, k_tokens = key_t.shape + _, _, k_tokens = key.shape qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) @@ -171,7 +189,7 @@ def attention_sub_quad(query, key, value, heads, mask=None): hidden_states = efficient_dot_product_attention( query, - key_t, + key, value, query_chunk_size=query_chunk_size, kv_chunk_size=kv_chunk_size, @@ -186,9 +204,19 @@ def attention_sub_quad(query, key, value, heads, mask=None): return hidden_states def attention_split(q, k, v, heads, mask=None): - scale = (q.shape[-1] // heads) ** -0.5 + b, _, dim_head = q.shape + dim_head //= heads + scale = dim_head ** -0.5 + h = heads - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, -1, heads, dim_head) + .permute(0, 2, 1, 3) + .reshape(b * heads, -1, dim_head) + .contiguous(), + (q, k, v), + ) r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) @@ -248,9 +276,13 @@ def attention_split(q, k, v, heads, mask=None): del q, k, v - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 - return r2 + r1 = ( + r1.unsqueeze(0) + .reshape(b, heads, -1, dim_head) + .permute(0, 2, 1, 3) + .reshape(b, -1, heads * dim_head) + ) + return r1 def attention_xformers(q, k, v, heads, mask=None): b, _, dim_head = q.shape diff --git a/fooocus_version.py b/fooocus_version.py index e8adf2ec..47dcbe78 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.724' +version = '2.1.725'