diff --git a/backend/headless/comfy/ldm/modules/attention.py b/backend/headless/comfy/ldm/modules/attention.py index fcae6b66..ac0d9c8c 100644 --- a/backend/headless/comfy/ldm/modules/attention.py +++ b/backend/headless/comfy/ldm/modules/attention.py @@ -94,253 +94,220 @@ def zero_module(module): def Normalize(in_channels, dtype=None, device=None): return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True, dtype=dtype, device=device) +def attention_basic(q, k, v, heads, mask=None): + h = heads + scale = (q.shape[-1] // heads) ** -0.5 + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) -class SpatialSelfAttention(nn.Module): - def __init__(self, in_channels): - super().__init__() - self.in_channels = in_channels + # force cast to fp32 to avoid overflowing + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + q, k = q.float(), k.float() + sim = einsum('b i d, b j d -> b i j', q, k) * scale + else: + sim = einsum('b i d, b j d -> b i j', q, k) * scale - self.norm = Normalize(in_channels) - self.q = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.k = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.v = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) - self.proj_out = torch.nn.Conv2d(in_channels, - in_channels, - kernel_size=1, - stride=1, - padding=0) + del q, k - def forward(self, x): - h_ = x - h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) - v = self.v(h_) + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) - # compute attention - b,c,h,w = q.shape - q = rearrange(q, 'b c h w -> b (h w) c') - k = rearrange(k, 'b c h w -> b c (h w)') - w_ = torch.einsum('bij,bjk->bik', q, k) + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) - - # attend to values - v = rearrange(v, 'b c h w -> b c (h w)') - w_ = rearrange(w_, 'b i j -> b j i') - h_ = torch.einsum('bij,bjk->bik', v, w_) - h_ = rearrange(h_, 'b c (h w) -> b c h w', h=h) - h_ = self.proj_out(h_) - - return x+h_ + out = einsum('b i j, b j d -> b i d', sim.to(v.dtype), v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return out -class CrossAttentionBirchSan(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) +def attention_sub_quad(query, key, value, heads, mask=None): + scale = (query.shape[-1] // heads) ** -0.5 + query = query.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1) + key_t = key.transpose(1,2).unflatten(1, (heads, -1)).flatten(end_dim=1) + del key + value = value.unflatten(-1, (heads, -1)).transpose(1,2).flatten(end_dim=1) - self.scale = dim_head ** -0.5 - self.heads = heads + dtype = query.dtype + upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 + if upcast_attention: + bytes_per_token = torch.finfo(torch.float32).bits//8 + else: + bytes_per_token = torch.finfo(query.dtype).bits//8 + batch_x_heads, q_tokens, _ = query.shape + _, _, k_tokens = key_t.shape + qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens - self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) + mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) - self.to_out = nn.Sequential( - operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), - nn.Dropout(dropout) - ) + chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD - def forward(self, x, context=None, value=None, mask=None): - h = self.heads + kv_chunk_size_min = None - query = self.to_q(x) - context = default(context, x) - key = self.to_k(context) - if value is not None: - value = self.to_v(value) - else: - value = self.to_v(context) + #not sure at all about the math here + #TODO: tweak this + if mem_free_total > 8192 * 1024 * 1024 * 1.3: + query_chunk_size_x = 1024 * 4 + elif mem_free_total > 4096 * 1024 * 1024 * 1.3: + query_chunk_size_x = 1024 * 2 + else: + query_chunk_size_x = 1024 + kv_chunk_size_min_x = None + kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024 + if kv_chunk_size_x < 1024: + kv_chunk_size_x = None - del context, x + if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: + # the big matmul fits into our memory limit; do everything in 1 chunk, + # i.e. send it down the unchunked fast-path + query_chunk_size = q_tokens + kv_chunk_size = k_tokens + else: + query_chunk_size = query_chunk_size_x + kv_chunk_size = kv_chunk_size_x + kv_chunk_size_min = kv_chunk_size_min_x - query = query.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) - key_t = key.transpose(1,2).unflatten(1, (self.heads, -1)).flatten(end_dim=1) - del key - value = value.unflatten(-1, (self.heads, -1)).transpose(1,2).flatten(end_dim=1) + hidden_states = efficient_dot_product_attention( + query, + key_t, + value, + query_chunk_size=query_chunk_size, + kv_chunk_size=kv_chunk_size, + kv_chunk_size_min=kv_chunk_size_min, + use_checkpoint=False, + upcast_attention=upcast_attention, + ) - dtype = query.dtype - upcast_attention = _ATTN_PRECISION =="fp32" and query.dtype != torch.float32 - if upcast_attention: - bytes_per_token = torch.finfo(torch.float32).bits//8 - else: - bytes_per_token = torch.finfo(query.dtype).bits//8 - batch_x_heads, q_tokens, _ = query.shape - _, _, k_tokens = key_t.shape - qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens + hidden_states = hidden_states.to(dtype) - mem_free_total, mem_free_torch = model_management.get_free_memory(query.device, True) + hidden_states = hidden_states.unflatten(0, (-1, heads)).transpose(1,2).flatten(start_dim=2) + return hidden_states - chunk_threshold_bytes = mem_free_torch * 0.5 #Using only this seems to work better on AMD +def attention_split(q, k, v, heads, mask=None): + scale = (q.shape[-1] // heads) ** -0.5 + h = heads + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - kv_chunk_size_min = None + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - #not sure at all about the math here - #TODO: tweak this - if mem_free_total > 8192 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 4 - elif mem_free_total > 4096 * 1024 * 1024 * 1.3: - query_chunk_size_x = 1024 * 2 - else: - query_chunk_size_x = 1024 - kv_chunk_size_min_x = None - kv_chunk_size_x = (int((chunk_threshold_bytes // (batch_x_heads * bytes_per_token * query_chunk_size_x)) * 2.0) // 1024) * 1024 - if kv_chunk_size_x < 1024: - kv_chunk_size_x = None + mem_free_total = model_management.get_free_memory(q.device) - if chunk_threshold_bytes is not None and qk_matmul_size_bytes <= chunk_threshold_bytes: - # the big matmul fits into our memory limit; do everything in 1 chunk, - # i.e. send it down the unchunked fast-path - query_chunk_size = q_tokens - kv_chunk_size = k_tokens - else: - query_chunk_size = query_chunk_size_x - kv_chunk_size = kv_chunk_size_x - kv_chunk_size_min = kv_chunk_size_min_x - - hidden_states = efficient_dot_product_attention( - query, - key_t, - value, - query_chunk_size=query_chunk_size, - kv_chunk_size=kv_chunk_size, - kv_chunk_size_min=kv_chunk_size_min, - use_checkpoint=self.training, - upcast_attention=upcast_attention, - ) - - hidden_states = hidden_states.to(dtype) - - hidden_states = hidden_states.unflatten(0, (-1, self.heads)).transpose(1,2).flatten(start_dim=2) - - out_proj, dropout = self.to_out - hidden_states = out_proj(hidden_states) - hidden_states = dropout(hidden_states) - - return hidden_states + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 -class CrossAttentionDoggettx(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - self.scale = dim_head ** -0.5 - self.heads = heads + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') - self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - - self.to_out = nn.Sequential( - operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), - nn.Dropout(dropout) - ) - - def forward(self, x, context=None, value=None, mask=None): - h = self.heads - - q_in = self.to_q(x) - context = default(context, x) - k_in = self.to_k(context) - if value is not None: - v_in = self.to_v(value) - del value - else: - v_in = self.to_v(context) - del context, x - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) - del q_in, k_in, v_in - - r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype) - - mem_free_total = model_management.get_free_memory(q.device) - - gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() - modifier = 3 if q.element_size() == 2 else 2.5 - mem_required = tensor_size * modifier - steps = 1 - - - if mem_required > mem_free_total: - steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") - - if steps > 64: - max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 - raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' - f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') - - # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) - first_op_done = False - cleared_cache = False - while True: - try: - slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] - for i in range(0, q.shape[1], slice_size): - end = i + slice_size - if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * self.scale - else: - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - first_op_done = True - - s2 = s1.softmax(dim=-1).to(v.dtype) - del s1 - - r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) - del s2 - break - except model_management.OOM_EXCEPTION as e: - if first_op_done == False: - model_management.soft_empty_cache(True) - if cleared_cache == False: - cleared_cache = True - print("out of memory error, emptying cache and trying again") - continue - steps *= 2 - if steps > 64: - raise e - print("out of memory error, increasing steps and trying again", steps) + # print("steps", steps, mem_required, mem_free_total, modifier, q.element_size(), tensor_size) + first_op_done = False + cleared_cache = False + while True: + try: + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + if _ATTN_PRECISION =="fp32": + with torch.autocast(enabled=False, device_type = 'cuda'): + s1 = einsum('b i d, b j d -> b i j', q[:, i:end].float(), k.float()) * scale else: + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale + first_op_done = True + + s2 = s1.softmax(dim=-1).to(v.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + break + except model_management.OOM_EXCEPTION as e: + if first_op_done == False: + model_management.soft_empty_cache(True) + if cleared_cache == False: + cleared_cache = True + print("out of memory error, emptying cache and trying again") + continue + steps *= 2 + if steps > 64: raise e + print("out of memory error, increasing steps and trying again", steps) + else: + raise e - del q, k, v + del q, k, v - r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) - del r1 + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + return r2 - return self.to_out(r2) +def attention_xformers(q, k, v, heads, mask=None): + b, _, _ = q.shape + q, k, v = map( + lambda t: t.unsqueeze(3) + .reshape(b, t.shape[1], heads, -1) + .permute(0, 2, 1, 3) + .reshape(b * heads, t.shape[1], -1) + .contiguous(), + (q, k, v), + ) + + # actually compute the attention, what we cannot get enough of + out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None) + + if exists(mask): + raise NotImplementedError + out = ( + out.unsqueeze(0) + .reshape(b, heads, out.shape[1], -1) + .permute(0, 2, 1, 3) + .reshape(b, out.shape[1], -1) + ) + return out + +def attention_pytorch(q, k, v, heads, mask=None): + b, _, dim_head = q.shape + dim_head //= heads + q, k, v = map( + lambda t: t.view(b, -1, heads, dim_head).transpose(1, 2), + (q, k, v), + ) + + out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=False) + + if exists(mask): + raise NotImplementedError + out = ( + out.transpose(1, 2).reshape(b, -1, heads * dim_head) + ) + return out + +optimized_attention = attention_basic + +if model_management.xformers_enabled(): + print("Using xformers cross attention") + optimized_attention = attention_xformers +elif model_management.pytorch_attention_enabled(): + print("Using pytorch cross attention") + optimized_attention = attention_pytorch +else: + if args.use_split_cross_attention: + print("Using split optimization for cross attention") + optimized_attention = attention_split + else: + print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") + optimized_attention = attention_sub_quad class CrossAttention(nn.Module): def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): @@ -348,62 +315,6 @@ class CrossAttention(nn.Module): inner_dim = dim_head * heads context_dim = default(context_dim, query_dim) - self.scale = dim_head ** -0.5 - self.heads = heads - - self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - - self.to_out = nn.Sequential( - operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), - nn.Dropout(dropout) - ) - - def forward(self, x, context=None, value=None, mask=None): - h = self.heads - - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - if value is not None: - v = self.to_v(value) - del value - else: - v = self.to_v(context) - - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - - # force cast to fp32 to avoid overflowing - if _ATTN_PRECISION =="fp32": - with torch.autocast(enabled=False, device_type = 'cuda'): - q, k = q.float(), k.float() - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - else: - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale - - del q, k - - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) - - # attention, what we cannot get enough of - sim = sim.softmax(dim=-1) - - out = einsum('b i j, b j d -> b i d', sim, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) - -class MemoryEfficientCrossAttention(nn.Module): - # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223 - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0, dtype=None, device=None, operations=comfy.ops): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - self.heads = heads self.dim_head = dim_head @@ -412,7 +323,6 @@ class MemoryEfficientCrossAttention(nn.Module): self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - self.attention_op: Optional[Any] = None def forward(self, x, context=None, value=None, mask=None): q = self.to_q(x) @@ -424,85 +334,9 @@ class MemoryEfficientCrossAttention(nn.Module): else: v = self.to_v(context) - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], self.heads, self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b * self.heads, t.shape[1], self.dim_head) - .contiguous(), - (q, k, v), - ) - - # actually compute the attention, what we cannot get enough of - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=self.attention_op) - - if exists(mask): - raise NotImplementedError - out = ( - out.unsqueeze(0) - .reshape(b, self.heads, out.shape[1], self.dim_head) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], self.heads * self.dim_head) - ) + out = optimized_attention(q, k, v, self.heads, mask) return self.to_out(out) -class CrossAttentionPytorch(nn.Module): - def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0., dtype=None, device=None, operations=comfy.ops): - super().__init__() - inner_dim = dim_head * heads - context_dim = default(context_dim, query_dim) - - self.heads = heads - self.dim_head = dim_head - - self.to_q = operations.Linear(query_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_k = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - self.to_v = operations.Linear(context_dim, inner_dim, bias=False, dtype=dtype, device=device) - - self.to_out = nn.Sequential(operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)) - self.attention_op: Optional[Any] = None - - def forward(self, x, context=None, value=None, mask=None): - q = self.to_q(x) - context = default(context, x) - k = self.to_k(context) - if value is not None: - v = self.to_v(value) - del value - else: - v = self.to_v(context) - - b, _, _ = q.shape - q, k, v = map( - lambda t: t.view(b, -1, self.heads, self.dim_head).transpose(1, 2), - (q, k, v), - ) - - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - - if exists(mask): - raise NotImplementedError - out = ( - out.transpose(1, 2).reshape(b, -1, self.heads * self.dim_head) - ) - - return self.to_out(out) - -if model_management.xformers_enabled(): - print("Using xformers cross attention") - CrossAttention = MemoryEfficientCrossAttention -elif model_management.pytorch_attention_enabled(): - print("Using pytorch cross attention") - CrossAttention = CrossAttentionPytorch -else: - if args.use_split_cross_attention: - print("Using split optimization for cross attention") - CrossAttention = CrossAttentionDoggettx - else: - print("Using sub quadratic optimization for cross attention, if you have memory or speed issues try using: --use-split-cross-attention") - CrossAttention = CrossAttentionBirchSan - class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True, diff --git a/backend/headless/comfy/ldm/modules/diffusionmodules/model.py b/backend/headless/comfy/ldm/modules/diffusionmodules/model.py index 5f38640c..6576df4b 100644 --- a/backend/headless/comfy/ldm/modules/diffusionmodules/model.py +++ b/backend/headless/comfy/ldm/modules/diffusionmodules/model.py @@ -6,7 +6,6 @@ import numpy as np from einops import rearrange from typing import Optional, Any -from ..attention import MemoryEfficientCrossAttention from comfy import model_management import comfy.ops @@ -352,20 +351,11 @@ class MemoryEfficientAttnBlockPytorch(nn.Module): out = self.proj_out(out) return x+out -class MemoryEfficientCrossAttentionWrapper(MemoryEfficientCrossAttention): - def forward(self, x, context=None, mask=None): - b, c, h, w = x.shape - x = rearrange(x, 'b c h w -> b (h w) c') - out = super().forward(x, context=context, mask=mask) - out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w, c=c) - return x + out - - def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): assert attn_type in ["vanilla", "vanilla-xformers", "memory-efficient-cross-attn", "linear", "none"], f'attn_type {attn_type} unknown' if model_management.xformers_enabled_vae() and attn_type == "vanilla": attn_type = "vanilla-xformers" - if model_management.pytorch_attention_enabled() and attn_type == "vanilla": + elif model_management.pytorch_attention_enabled() and attn_type == "vanilla": attn_type = "vanilla-pytorch" print(f"making attention of type '{attn_type}' with {in_channels} in_channels") if attn_type == "vanilla": @@ -376,9 +366,6 @@ def make_attn(in_channels, attn_type="vanilla", attn_kwargs=None): return MemoryEfficientAttnBlock(in_channels) elif attn_type == "vanilla-pytorch": return MemoryEfficientAttnBlockPytorch(in_channels) - elif type == "memory-efficient-cross-attn": - attn_kwargs["query_dim"] = in_channels - return MemoryEfficientCrossAttentionWrapper(**attn_kwargs) elif attn_type == "none": return nn.Identity(in_channels) else: diff --git a/backend/headless/comfy/model_management.py b/backend/headless/comfy/model_management.py index 8b896372..3c390d9c 100644 --- a/backend/headless/comfy/model_management.py +++ b/backend/headless/comfy/model_management.py @@ -154,14 +154,18 @@ def is_nvidia(): return True return False -ENABLE_PYTORCH_ATTENTION = args.use_pytorch_cross_attention +ENABLE_PYTORCH_ATTENTION = False +if args.use_pytorch_cross_attention: + ENABLE_PYTORCH_ATTENTION = True + XFORMERS_IS_AVAILABLE = False + VAE_DTYPE = torch.float32 try: if is_nvidia(): torch_version = torch.version.__version__ if int(torch_version[0]) >= 2: - if ENABLE_PYTORCH_ATTENTION == False and XFORMERS_IS_AVAILABLE == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: + if ENABLE_PYTORCH_ATTENTION == False and args.use_split_cross_attention == False and args.use_quad_cross_attention == False: ENABLE_PYTORCH_ATTENTION = True if torch.cuda.is_bf16_supported(): VAE_DTYPE = torch.bfloat16 @@ -186,7 +190,6 @@ if ENABLE_PYTORCH_ATTENTION: torch.backends.cuda.enable_math_sdp(True) torch.backends.cuda.enable_flash_sdp(True) torch.backends.cuda.enable_mem_efficient_sdp(True) - XFORMERS_IS_AVAILABLE = False if args.lowvram: set_vram_to = VRAMState.LOW_VRAM @@ -354,6 +357,8 @@ def load_models_gpu(models, memory_required=0): current_loaded_models.insert(0, current_loaded_models.pop(index)) models_already_loaded.append(loaded_model) else: + if hasattr(x, "model"): + print(f"Requested to load {x.model.__class__.__name__}") models_to_load.append(loaded_model) if len(models_to_load) == 0: @@ -363,7 +368,7 @@ def load_models_gpu(models, memory_required=0): free_memory(extra_mem, d, models_already_loaded) return - print("loading new") + print(f"Loading {len(models_to_load)} new model{'s' if len(models_to_load) > 1 else ''}") total_memory_required = {} for loaded_model in models_to_load: @@ -405,7 +410,6 @@ def load_model_gpu(model): def cleanup_models(): to_delete = [] for i in range(len(current_loaded_models)): - print(sys.getrefcount(current_loaded_models[i].model)) if sys.getrefcount(current_loaded_models[i].model) <= 2: to_delete = [i] + to_delete diff --git a/backend/headless/comfy/utils.py b/backend/headless/comfy/utils.py index 7843b58c..df016ef9 100644 --- a/backend/headless/comfy/utils.py +++ b/backend/headless/comfy/utils.py @@ -408,6 +408,10 @@ def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_am output[b:b+1] = out/out_div return output +PROGRESS_BAR_ENABLED = True +def set_progress_bar_enabled(enabled): + global PROGRESS_BAR_ENABLED + PROGRESS_BAR_ENABLED = enabled PROGRESS_BAR_HOOK = None def set_progress_bar_global_hook(function): diff --git a/backend/headless/comfy_extras/nodes_custom_sampler.py b/backend/headless/comfy_extras/nodes_custom_sampler.py index 9391c714..b52ad8fb 100644 --- a/backend/headless/comfy_extras/nodes_custom_sampler.py +++ b/backend/headless/comfy_extras/nodes_custom_sampler.py @@ -3,6 +3,7 @@ import comfy.sample from comfy.k_diffusion import sampling as k_diffusion_sampling import latent_preview import torch +import comfy.utils class BasicScheduler: @@ -219,7 +220,7 @@ class SamplerCustom: x0_output = {} callback = latent_preview.prepare_callback(model, sigmas.shape[-1] - 1, x0_output) - disable_pbar = False + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = comfy.sample.sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=noise_seed) out = latent.copy() diff --git a/backend/headless/folder_paths.py b/backend/headless/folder_paths.py index 898513b0..5d121b44 100644 --- a/backend/headless/folder_paths.py +++ b/backend/headless/folder_paths.py @@ -29,6 +29,8 @@ folder_names_and_paths["custom_nodes"] = ([os.path.join(base_path, "custom_nodes folder_names_and_paths["hypernetworks"] = ([os.path.join(models_dir, "hypernetworks")], supported_pt_extensions) +folder_names_and_paths["classifiers"] = ([os.path.join(models_dir, "classifiers")], {""}) + output_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "output") temp_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "temp") input_directory = os.path.join(os.path.dirname(os.path.realpath(__file__)), "input") diff --git a/backend/headless/nodes.py b/backend/headless/nodes.py index 16bf07cc..208cbc84 100644 --- a/backend/headless/nodes.py +++ b/backend/headless/nodes.py @@ -1202,7 +1202,7 @@ def common_ksampler(model, seed, steps, cfg, sampler_name, scheduler, positive, noise_mask = latent["noise_mask"] callback = latent_preview.prepare_callback(model, steps) - disable_pbar = False + disable_pbar = not comfy.utils.PROGRESS_BAR_ENABLED samples = comfy.sample.sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative, latent_image, denoise=denoise, disable_noise=disable_noise, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, noise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) diff --git a/fooocus_extras/ip_adapter.py b/fooocus_extras/ip_adapter.py index aa555495..b30a5646 100644 --- a/fooocus_extras/ip_adapter.py +++ b/fooocus_extras/ip_adapter.py @@ -3,47 +3,18 @@ import comfy.clip_vision import safetensors.torch as sf import comfy.model_management as model_management import contextlib +import comfy.ldm.modules.attention as attention from fooocus_extras.resampler import Resampler from comfy.model_patcher import ModelPatcher -if model_management.xformers_enabled(): - import xformers - import xformers.ops - - SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2 SD_XL_CHANNELS = [640] * 8 + [1280] * 40 + [1280] * 60 + [640] * 12 + [1280] * 20 def sdp(q, k, v, extra_options): - if model_management.xformers_enabled(): - b, _, _ = q.shape - q, k, v = map( - lambda t: t.unsqueeze(3) - .reshape(b, t.shape[1], extra_options["n_heads"], extra_options["dim_head"]) - .permute(0, 2, 1, 3) - .reshape(b * extra_options["n_heads"], t.shape[1], extra_options["dim_head"]) - .contiguous(), - (q, k, v), - ) - out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None) - out = ( - out.unsqueeze(0) - .reshape(b, extra_options["n_heads"], out.shape[1], extra_options["dim_head"]) - .permute(0, 2, 1, 3) - .reshape(b, out.shape[1], extra_options["n_heads"] * extra_options["dim_head"]) - ) - else: - b, _, _ = q.shape - q, k, v = map( - lambda t: t.view(b, -1, extra_options["n_heads"], extra_options["dim_head"]).transpose(1, 2), - (q, k, v), - ) - out = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) - out = out.transpose(1, 2).reshape(b, -1, extra_options["n_heads"] * extra_options["dim_head"]) - return out + return attention.optimized_attention(q, k, v, heads=extra_options["n_heads"], mask=None) class ImageProjModel(torch.nn.Module): diff --git a/fooocus_version.py b/fooocus_version.py index 61dec461..eba47c0e 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.50' +version = '2.1.51'