revise noise formulation

lllyasviel 2023-10-22 06:02:35 -07:00
parent 3acf1d6494
commit 7d81eeed7e
4 changed files with 26 additions and 66 deletions

View File

@@ -1 +1 @@
-version = '2.1.725'
+version = '2.1.726'

View File

@@ -213,67 +213,24 @@ def get_previewer(model):
     return preview_function
 
 
-@torch.no_grad()
-@torch.inference_mode()
-def prepare_noise(latent_image, generator, noise_inds=None):
-    if noise_inds is None:
-        return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout,
-                           generator=generator, device="cpu")
-
-    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
-    noises = []
-    for i in range(unique_inds[-1] + 1):
-        noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout,
-                            generator=generator, device="cpu")
-        if i in unique_inds:
-            noises.append(noise)
-    noises = [noises[i] for i in inverse]
-    noises = torch.cat(noises, dim=0)
-    return noises
-
-
-@torch.no_grad()
-@torch.inference_mode()
-def prepare_additive_noise(latent_image, generator, noise_inds=None):
-    B, C, H, W = latent_image.shape
-
-    if noise_inds is None:
-        return torch.rand([B, 1, H, W], dtype=latent_image.dtype, layout=latent_image.layout,
-                          generator=generator, device="cpu") * 2.0 - 1.0
-
-    unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
-    noises = []
-    for i in range(unique_inds[-1] + 1):
-        noise = torch.rand([1, 1, H, W], dtype=latent_image.dtype, layout=latent_image.layout,
-                           generator=generator, device="cpu") * 2.0 - 1.0
-        if i in unique_inds:
-            noises.append(noise)
-    noises = [noises[i] for i in inverse]
-    noises = torch.cat(noises, dim=0)
-    return noises
-
-
 @torch.no_grad()
 @torch.inference_mode()
 def ksampler(model, positive, negative, latent, seed=None, steps=30, cfg=7.0, sampler_name='dpmpp_2m_sde_gpu',
              scheduler='karras', denoise=1.0, disable_noise=False, start_step=None, last_step=None,
              force_full_denoise=False, callback_function=None, refiner=None, refiner_switch=-1,
-             previewer_start=None, previewer_end=None, sigmas=None, extra_noise=None):
+             previewer_start=None, previewer_end=None, sigmas=None, noise=None):
 
     if sigmas is not None:
         sigmas = sigmas.clone().to(fcbh.model_management.get_torch_device())
 
     latent_image = latent["samples"]
-    batch_inds = latent["batch_index"] if "batch_index" in latent else None
-
-    rng = torch.manual_seed(seed)
-
-    if disable_noise:
-        noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
-    else:
-        noise = prepare_noise(latent_image, rng, batch_inds)
-
-    if isinstance(extra_noise, float):
-        additive_noise = prepare_additive_noise(latent_image, rng, batch_inds)
-        noise = noise + additive_noise * extra_noise
+
+    if noise is None:
+        if disable_noise:
+            noise = torch.zeros(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, device="cpu")
+        else:
+            batch_inds = latent["batch_index"] if "batch_index" in latent else None
+            noise = fcbh.sample.prepare_noise(latent_image, seed, batch_inds)
 
     noise_mask = None
     if "noise_mask" in latent:

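For context (not part of the diff): the removed helpers built seeded Gaussian base noise and, when extra_noise was a float, mixed in uniform noise in [-1, 1] scaled by that factor, whereas the revised ksampler expects the caller to supply the initial noise tensor itself and only falls back to fcbh.sample.prepare_noise when noise is None. A minimal sketch of the difference; the shapes and seed are hypothetical, and 0.25 mirrors the removed k_noise default:

import torch

# Old formulation (removed above): Gaussian base noise plus uniform noise.
latent_image = torch.zeros(1, 4, 64, 64)   # hypothetical latent batch
rng = torch.manual_seed(12345)             # hypothetical seed
base = torch.randn(latent_image.size(), generator=rng, device="cpu")
uniform = torch.rand(1, 1, 64, 64, generator=rng, device="cpu") * 2.0 - 1.0
old_noise = base + uniform * 0.25          # extra_noise was k_noise = 0.25

# New formulation: the caller passes a complete noise tensor via noise=...;
# with noise=None the sampler falls back to plain seeded Gaussian noise from
# fcbh.sample.prepare_noise, with no uniform component mixed in.
new_default_noise = torch.randn(latent_image.size(), generator=rng, device="cpu")

The pipeline-side replacement for extra_noise is the residual_noise tensor built in the next file.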
View File

@@ -270,13 +270,11 @@ refresh_everything(
 @torch.no_grad()
 @torch.inference_mode()
-def vae_parse(latent, k=1.0):
+def vae_parse(latent):
     if final_refiner_vae is None:
-        result = latent["samples"]
-    else:
-        result = vae_interpose.parse(latent["samples"])
-    if k != 1.0:
-        result = result * k
+        return latent
+
+    result = vae_interpose.parse(latent["samples"])
     return {'samples': result}
@@ -433,6 +431,8 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
         decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
 
     if refiner_swap_method == 'vae':
+        modules.patch.eps_record = 'vae'
+
         if modules.inpaint_worker.current_task is not None:
             modules.inpaint_worker.current_task.unswap()
@@ -458,13 +458,9 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
             target_model = final_unet
             print('Use base model to refine itself - this may because of developer mode.')
 
-        # Fooocus' vae parameters
-        k_data = 1.025
-        k_noise = 0.25
+        sampled_latent = vae_parse(sampled_latent)
+
         k_sigmas = 1.4
-        sampled_latent = vae_parse(sampled_latent, k=k_data)
-
         sigmas = calculate_sigmas(sampler=sampler_name,
                                   scheduler=scheduler_name,
                                   model=target_model.model,
@@ -472,6 +468,9 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
                                   denoise=denoise)[switch:] * k_sigmas
         len_sigmas = len(sigmas) - 1
 
+        assert isinstance(modules.patch.eps_record, torch.Tensor)
+        residual_noise = modules.patch.eps_record / modules.patch.eps_record.std()
+
         if modules.inpaint_worker.current_task is not None:
             modules.inpaint_worker.current_task.swap()
@@ -481,7 +480,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
             negative=clip_separate(negative_cond, target_model=target_model.model, target_clip=final_clip),
             latent=sampled_latent,
             steps=len_sigmas, start_step=0, last_step=len_sigmas, disable_noise=False, force_full_denoise=True,
-            seed=image_seed + 1,  # Avoid artifacts
+            seed=image_seed,
             denoise=denoise,
             callback_function=callback,
             cfg=cfg_scale,
@@ -490,7 +489,7 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
             previewer_start=switch,
             previewer_end=steps,
             sigmas=sigmas,
-            extra_noise=k_noise
+            noise=residual_noise
         )
 
         target_model = final_refiner_vae
@@ -499,4 +498,5 @@ def process_diffusion(positive_cond, negative_cond, steps, switch, width, height
         decoded_latent = core.decode_vae(vae=target_model, latent_image=sampled_latent, tiled=tiled)
 
     images = core.pytorch_to_numpy(decoded_latent)
+    modules.patch.eps_record = None
     return images

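For context (not part of the diff): under the 'vae' refiner swap the uniform extra noise (k_noise) and the k_data latent scaling are gone; instead the eps recorded during the base pass is rescaled to unit standard deviation and reused as the refiner's starting noise, with the unmodified image_seed. A minimal sketch of that rescaling, using a random tensor as a stand-in for modules.patch.eps_record:

import torch

# Stand-in for modules.patch.eps_record after the base pass; the real tensor
# has the shape of the denoiser input, the values here are placeholders.
eps_record = torch.randn(1, 4, 64, 64).cpu()

# Revised formulation: reuse the recorded eps as the refiner's initial noise,
# rescaled so its standard deviation matches fresh unit-variance Gaussian noise.
residual_noise = eps_record / eps_record.std()
print(float(residual_noise.std()))  # ~1.0

The ksampler call above then receives noise=residual_noise with seed=image_seed (the old +1 offset is dropped), and eps_record is cleared to None once decoding finishes.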
View File

@@ -38,6 +38,7 @@ cfg_x0 = 0.0
 cfg_s = 1.0
 cfg_cin = 1.0
 adaptive_cfg = 0.7
+eps_record = None
 
 
 def calculate_weight_patched(self, patches, weight, key):
@@ -192,10 +193,12 @@ def patched_sampler_cfg_function(args):
 
 def patched_discrete_eps_ddpm_denoiser_forward(self, input, sigma, **kwargs):
-    global cfg_x0, cfg_s, cfg_cin
+    global cfg_x0, cfg_s, cfg_cin, eps_record
     c_out, c_in = [utils.append_dims(x, input.ndim) for x in self.get_scalings(sigma)]
     cfg_x0, cfg_s, cfg_cin = input, c_out, c_in
     eps = self.get_eps(input * c_in, self.sigma_to_t(sigma), **kwargs)
+    if eps_record is not None:
+        eps_record = eps.clone().cpu()
     return input + eps * c_out
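For context (not part of the diff): eps_record is an opt-in recorder. The pipeline arms it with the sentinel string 'vae' before the base pass; because it is then non-None, every call to the patched denoiser forward overwrites it with a CPU clone of the latest predicted eps, so after sampling it holds the eps of the final step. A self-contained sketch of the same pattern with a dummy denoiser (get_eps and the sigma scalings are faked):

import torch

eps_record = None  # module-level recorder, mirroring modules/patch.py


def denoiser_forward(x, sigma):
    # Dummy stand-in for patched_discrete_eps_ddpm_denoiser_forward: it
    # "predicts" an eps and records it whenever recording is armed.
    global eps_record
    eps = torch.randn_like(x) * 0.1       # placeholder for self.get_eps(...)
    if eps_record is not None:
        eps_record = eps.clone().cpu()    # keep only the most recent eps
    return x + eps * sigma                # placeholder for input + eps * c_out


x = torch.zeros(1, 4, 8, 8)
eps_record = 'vae'                        # arm the recorder (sentinel value)
for sigma in (1.0, 0.5, 0.1):             # pretend multi-step sampling
    x = denoiser_forward(x, sigma)

assert isinstance(eps_record, torch.Tensor)      # now holds the last step's eps
residual_noise = eps_record / eps_record.std()   # as used by the pipeline
eps_record = None                                # reset, mirroring the pipeline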