Merge branch 'main_upstream'
commit f52a356cfb

@@ -2,12 +2,13 @@ import torch
 import ldm_patched.modules.clip_vision
 import safetensors.torch as sf
 import ldm_patched.modules.model_management as model_management
 import contextlib
 import ldm_patched.ldm.modules.attention as attention

 from extras.resampler import Resampler
 from ldm_patched.modules.model_patcher import ModelPatcher
 from modules.core import numpy_to_pytorch
+from modules.ops import use_patched_ops
+from ldm_patched.modules.ops import manual_cast


 SD_V12_CHANNELS = [320] * 4 + [640] * 4 + [1280] * 4 + [1280] * 6 + [640] * 6 + [320] * 6 + [1280] * 2
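
The two imports added above feed the change in load_ip_adapter below: use_patched_ops temporarily swaps the torch layer classes a module is constructed with, and manual_cast supplies replacement layers that cast their weights to the input at call time. A minimal sketch of that context-manager pattern, using a hypothetical patched_ops helper rather than the real modules.ops code:

import contextlib
import torch

@contextlib.contextmanager
def patched_ops(operations):
    # Hypothetical stand-in for use_patched_ops: temporarily replace a few
    # torch.nn layer classes with the ones provided by `operations`,
    # then restore the originals on exit.
    names = ['Linear', 'Conv2d', 'GroupNorm', 'LayerNorm']
    backup = {n: getattr(torch.nn, n) for n in names}
    try:
        for n in names:
            if hasattr(operations, n):
                setattr(torch.nn, n, getattr(operations, n))
        yield
    finally:
        for n in names:
            setattr(torch.nn, n, backup[n])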

@@ -116,14 +117,16 @@ def load_ip_adapter(clip_vision_path, ip_negative_path, ip_adapter_path):
     clip_extra_context_tokens = ip_state_dict["image_proj"]["proj.weight"].shape[0] // cross_attention_dim
     clip_embeddings_dim = None

-    ip_adapter = IPAdapterModel(
-        ip_state_dict,
-        plus=plus,
-        cross_attention_dim=cross_attention_dim,
-        clip_embeddings_dim=clip_embeddings_dim,
-        clip_extra_context_tokens=clip_extra_context_tokens,
-        sdxl_plus=sdxl_plus
-    )
+    with use_patched_ops(manual_cast):
+        ip_adapter = IPAdapterModel(
+            ip_state_dict,
+            plus=plus,
+            cross_attention_dim=cross_attention_dim,
+            clip_embeddings_dim=clip_embeddings_dim,
+            clip_extra_context_tokens=clip_extra_context_tokens,
+            sdxl_plus=sdxl_plus
+        )

     ip_adapter.sdxl = sdxl
     ip_adapter.load_device = load_device
     ip_adapter.offload_device = offload_device
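
Wrapping the IPAdapterModel construction in `with use_patched_ops(manual_cast):` means every Linear/LayerNorm the model creates inside that block is instantiated from the manual_cast op set instead of the stock torch classes. As far as this diff shows, the practical effect is that parameters can keep their storage dtype and device and are only cast to match the activations when a layer actually runs. A rough sketch of such a cast-on-forward layer (CastLinear is a made-up name, not the ldm_patched implementation):

import torch

class CastLinear(torch.nn.Linear):
    # Illustrative cast-on-forward Linear: parameters keep the dtype/device
    # they were created or loaded with, and are cast to match the input
    # tensor each time forward() is called.
    def forward(self, x):
        weight = self.weight.to(device=x.device, dtype=x.dtype)
        bias = None if self.bias is None else self.bias.to(device=x.device, dtype=x.dtype)
        return torch.nn.functional.linear(x, weight, bias)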

@@ -108,8 +108,7 @@ class Resampler(nn.Module):
             )

     def forward(self, x):
-
-        latents = self.latents.repeat(x.size(0), 1, 1)
+        latents = self.latents.repeat(x.size(0), 1, 1).to(x)

         x = self.proj_in(x)

@@ -118,4 +117,4 @@ class Resampler(nn.Module):
             latents = ff(latents) + latents

         latents = self.proj_out(latents)
-        return self.norm_out(latents)
+        return self.norm_out(latents)
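
In Resampler.forward the only functional change is the trailing .to(x): Tensor.to(other) returns a copy on other's device and with other's dtype, so the batch-expanded latents always match the incoming features even if the module's parameters were created elsewhere (for example in fp32 on CPU while x arrives as fp16 on GPU). A small illustration with made-up shapes:

import torch

latents = torch.zeros(1, 16, 1280)                    # e.g. a learned query buffer created in fp32
x = torch.randn(4, 257, 1280, dtype=torch.float16)    # incoming image features
latents = latents.repeat(x.size(0), 1, 1).to(x)       # batch-expand, then match x's device/dtype
assert latents.shape == (4, 16, 1280)
assert latents.dtype == x.dtype and latents.device == x.device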

@@ -1 +1 @@
-version = '2.1.857'
+version = '2.1.859'