diff --git a/backend/headless/fcbh/k_diffusion/sampling.py b/backend/headless/fcbh/k_diffusion/sampling.py
index dd6f7bbe..761c2e0e 100644
--- a/backend/headless/fcbh/k_diffusion/sampling.py
+++ b/backend/headless/fcbh/k_diffusion/sampling.py
@@ -750,3 +750,61 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
         if sigmas[i + 1] > 0:
             x += sigmas[i + 1] * noise_sampler(sigmas[i], sigmas[i + 1])
     return x
+
+
+
+@torch.no_grad()
+def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
+    # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
+    extra_args = {} if extra_args is None else extra_args
+    s_in = x.new_ones([x.shape[0]])
+    s_end = sigmas[-1]
+    for i in trange(len(sigmas) - 1, disable=disable):
+        gamma = min(s_churn / (len(sigmas) - 1), 2 ** 0.5 - 1) if s_tmin <= sigmas[i] <= s_tmax else 0.
+        eps = torch.randn_like(x) * s_noise
+        sigma_hat = sigmas[i] * (gamma + 1)
+        if gamma > 0:
+            x = x + eps * (sigma_hat ** 2 - sigmas[i] ** 2) ** 0.5
+        denoised = model(x, sigma_hat * s_in, **extra_args)
+        d = to_d(x, sigma_hat, denoised)
+        if callback is not None:
+            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigma_hat, 'denoised': denoised})
+        dt = sigmas[i + 1] - sigma_hat
+        if sigmas[i + 1] == s_end:
+            # Euler method
+            x = x + d * dt
+        elif sigmas[i + 2] == s_end:
+
+            # Heun's method
+            x_2 = x + d * dt
+            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
+            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+
+            w = 2 * sigmas[0]
+            w2 = sigmas[i + 1] / w
+            w1 = 1 - w2
+
+            d_prime = d * w1 + d_2 * w2
+
+
+            x = x + d_prime * dt
+
+        else:
+            # Heun++
+            x_2 = x + d * dt
+            denoised_2 = model(x_2, sigmas[i + 1] * s_in, **extra_args)
+            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
+            dt_2 = sigmas[i + 2] - sigmas[i + 1]
+
+            x_3 = x_2 + d_2 * dt_2
+            denoised_3 = model(x_3, sigmas[i + 2] * s_in, **extra_args)
+            d_3 = to_d(x_3, sigmas[i + 2], denoised_3)
+
+            w = 3 * sigmas[0]
+            w2 = sigmas[i + 1] / w
+            w3 = sigmas[i + 2] / w
+            w1 = 1 - w2 - w3
+
+            d_prime = w1 * d + w2 * d_2 + w3 * d_3
+            x = x + d_prime * dt
+    return x
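Note on the new sampler: heunpp2 degrades gracefully near the end of the schedule. It takes a plain Euler step on the final sigma (one model call), a Heun step on the second-to-last sigma (two calls), and a three-stage "Heun++" update everywhere else (three calls). The extra slope estimates d_2 and d_3 are blended with weights proportional to their sigmas relative to sigmas[0], with w1 chosen so the weights sum to 1. A minimal usage sketch follows; the Karras schedule bounds and the zero-returning toy denoiser are illustrative stand-ins, not part of the patch:

import torch
from fcbh.k_diffusion import sampling as k_sampling

def toy_denoiser(x, sigma, **extra_args):
    # Stands in for a wrapped UNet with the k-diffusion denoiser interface:
    # called as model(x, sigma_hat * s_in, **extra_args) and returning x0.
    return torch.zeros_like(x)

sigmas = k_sampling.get_sigmas_karras(n=20, sigma_min=0.0292, sigma_max=14.6146)
x = torch.randn(1, 4, 64, 64) * sigmas[0]  # start from pure noise at sigma_max
samples = k_sampling.sample_heunpp2(toy_denoiser, x, sigmas)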
diff --git a/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py b/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py
index a2540e7d..1dcb70dd 100644
--- a/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py
+++ b/backend/headless/fcbh/ldm/modules/diffusionmodules/openaimodel.py
@@ -624,6 +624,11 @@ class UNetModel(nn.Module):
                 transformer_options["block"] = ("input", id)
                 h = forward_timestep_embed(module, h, emb, context, transformer_options)
                 h = apply_control(h, control, 'input')
+                if "input_block_patch" in transformer_patches:
+                    patch = transformer_patches["input_block_patch"]
+                    for p in patch:
+                        h = p(h, transformer_options)
+
                 hs.append(h)
 
             transformer_options["block"] = ("middle", 0)
diff --git a/backend/headless/fcbh/model_patcher.py b/backend/headless/fcbh/model_patcher.py
index 96456afc..a01d1753 100644
--- a/backend/headless/fcbh/model_patcher.py
+++ b/backend/headless/fcbh/model_patcher.py
@@ -96,6 +96,9 @@ class ModelPatcher:
     def set_model_attn2_output_patch(self, patch):
         self.set_model_patch(patch, "attn2_output_patch")
 
+    def set_model_input_block_patch(self, patch):
+        self.set_model_patch(patch, "input_block_patch")
+
     def set_model_output_block_patch(self, patch):
         self.set_model_patch(patch, "output_block_patch")
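Together, the openaimodel.py and model_patcher.py changes add a new patch point: after every UNet input block (and its ControlNet application), each callable registered under "input_block_patch" receives the current hidden states and transformer_options, and must return the (possibly modified) states; transformer_options["block"] identifies the block as ("input", id). A hedged sketch of registering one such patch; the block index and scale factor are illustrative choices, and `model` is assumed to be an fcbh.model_patcher.ModelPatcher (e.g. a loaded checkpoint's patcher):

def scale_input_block(h, transformer_options):
    block_type, block_index = transformer_options["block"]  # ("input", id) here
    if block_index == 4:   # illustrative block choice
        h = h * 1.1        # illustrative tweak of the hidden states
    return h

patched_model = model.clone()
patched_model.set_model_input_block_patch(scale_input_block)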
diff --git a/backend/headless/fcbh/samplers.py b/backend/headless/fcbh/samplers.py
index 0f78a4ac..2530b2f7 100644
--- a/backend/headless/fcbh/samplers.py
+++ b/backend/headless/fcbh/samplers.py
@@ -518,46 +518,63 @@ class UNIPCBH2(Sampler):
     def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
         return uni_pc.sample_unipc(model_wrap, noise, latent_image, sigmas, max_denoise=self.max_denoise(model_wrap, sigmas), extra_args=extra_args, noise_mask=denoise_mask, callback=callback, variant='bh2', disable=disable_pbar)
 
-KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral",
+KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2", "dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
                   "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
 
+class KSAMPLER(Sampler):
+    def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
+        self.sampler_function = sampler_function
+        self.extra_options = extra_options
+        self.inpaint_options = inpaint_options
+
+    def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
+        extra_args["denoise_mask"] = denoise_mask
+        model_k = KSamplerX0Inpaint(model_wrap)
+        model_k.latent_image = latent_image
+        if self.inpaint_options.get("random", False): #TODO: Should this be the default?
+            generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
+            model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
+        else:
+            model_k.noise = noise
+
+        if self.max_denoise(model_wrap, sigmas):
+            noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
+        else:
+            noise = noise * sigmas[0]
+
+        k_callback = None
+        total_steps = len(sigmas) - 1
+        if callback is not None:
+            k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
+
+        if latent_image is not None:
+            noise += latent_image
+
+        samples = self.sampler_function(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **self.extra_options)
+        return samples
+
+
 def ksampler(sampler_name, extra_options={}, inpaint_options={}):
-    class KSAMPLER(Sampler):
-        def sample(self, model_wrap, sigmas, extra_args, callback, noise, latent_image=None, denoise_mask=None, disable_pbar=False):
-            extra_args["denoise_mask"] = denoise_mask
-            model_k = KSamplerX0Inpaint(model_wrap)
-            model_k.latent_image = latent_image
-            if inpaint_options.get("random", False): #TODO: Should this be the default?
-                generator = torch.manual_seed(extra_args.get("seed", 41) + 1)
-                model_k.noise = torch.randn(noise.shape, generator=generator, device="cpu").to(noise.dtype).to(noise.device)
-            else:
-                model_k.noise = noise
-
-            if self.max_denoise(model_wrap, sigmas):
-                noise = noise * torch.sqrt(1.0 + sigmas[0] ** 2.0)
-            else:
-                noise = noise * sigmas[0]
-
-            k_callback = None
-            total_steps = len(sigmas) - 1
-            if callback is not None:
-                k_callback = lambda x: callback(x["i"], x["denoised"], x["x"], total_steps)
-
+    if sampler_name == "dpm_fast":
+        def dpm_fast_function(model, noise, sigmas, extra_args, callback, disable):
             sigma_min = sigmas[-1]
             if sigma_min == 0:
                 sigma_min = sigmas[-2]
+            total_steps = len(sigmas) - 1
+            return k_diffusion_sampling.sample_dpm_fast(model, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=callback, disable=disable)
+        sampler_function = dpm_fast_function
+    elif sampler_name == "dpm_adaptive":
+        def dpm_adaptive_function(model, noise, sigmas, extra_args, callback, disable):
+            sigma_min = sigmas[-1]
+            if sigma_min == 0:
+                sigma_min = sigmas[-2]
+            return k_diffusion_sampling.sample_dpm_adaptive(model, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=callback, disable=disable)
+        sampler_function = dpm_adaptive_function
+    else:
+        sampler_function = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))
 
-            if latent_image is not None:
-                noise += latent_image
-            if sampler_name == "dpm_fast":
-                samples = k_diffusion_sampling.sample_dpm_fast(model_k, noise, sigma_min, sigmas[0], total_steps, extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-            elif sampler_name == "dpm_adaptive":
-                samples = k_diffusion_sampling.sample_dpm_adaptive(model_k, noise, sigma_min, sigmas[0], extra_args=extra_args, callback=k_callback, disable=disable_pbar)
-            else:
-                samples = getattr(k_diffusion_sampling, "sample_{}".format(sampler_name))(model_k, noise, sigmas, extra_args=extra_args, callback=k_callback, disable=disable_pbar, **extra_options)
-            return samples
-    return KSAMPLER
+    return KSAMPLER(sampler_function, extra_options, inpaint_options)
 
 def wrap_model(model):
     model_denoise = CFGNoisePredictor(model)
@@ -618,11 +635,11 @@ def calculate_sigmas_scheduler(model, scheduler_name, steps):
         print("error invalid scheduler", self.scheduler)
     return sigmas
 
-def sampler_class(name):
+def sampler_object(name):
     if name == "uni_pc":
-        sampler = UNIPC
+        sampler = UNIPC()
     elif name == "uni_pc_bh2":
-        sampler = UNIPCBH2
+        sampler = UNIPCBH2()
     elif name == "ddim":
         sampler = ksampler("euler", inpaint_options={"random": True})
     else:
@@ -687,6 +704,6 @@ class KSampler:
             else:
                 return torch.zeros_like(noise)
 
-        sampler = sampler_class(self.sampler)
+        sampler = sampler_object(self.sampler)
 
-        return sample(self.model, noise, positive, negative, cfg, self.device, sampler(), sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
+        return sample(self.model, noise, positive, negative, cfg, self.device, sampler, sigmas, self.model_options, latent_image=latent_image, denoise_mask=denoise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
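The refactor above turns KSAMPLER from a per-name generated class into a plain class that receives its sampling function at construction time: ksampler() now wraps dpm_fast and dpm_adaptive (whose k-diffusion entry points take sigma_min/sigma_max rather than a sigma tensor) in small adapter closures, and sampler_object() returns ready-to-use instances, which is why the call site in KSampler.sample() drops the trailing (). One practical consequence, sketched below under the assumption that backend/headless is importable as the fcbh package: a custom callable with the (model, noise, sigmas, extra_args, callback, disable) signature can now be wrapped directly, without subclassing.

import fcbh.samplers
import fcbh.k_diffusion.sampling as k_diffusion_sampling

def my_sampler(model, noise, sigmas, extra_args, callback, disable):
    # Delegates to the stock Euler sampler; a hand-written sampling loop
    # with the same signature would work equally well.
    return k_diffusion_sampling.sample_euler(model, noise, sigmas, extra_args=extra_args, callback=callback, disable=disable)

sampler = fcbh.samplers.KSAMPLER(my_sampler)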
diff --git a/backend/headless/fcbh/sd1_clip.py b/backend/headless/fcbh/sd1_clip.py
index a5c710ac..83e03557 100644
--- a/backend/headless/fcbh/sd1_clip.py
+++ b/backend/headless/fcbh/sd1_clip.py
@@ -173,9 +173,9 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
         if getattr(self.transformer, self.inner_name).final_layer_norm.weight.dtype != torch.float32:
             precision_scope = torch.autocast
         else:
-            precision_scope = lambda a, b: contextlib.nullcontext(a)
+            precision_scope = lambda a, dtype: contextlib.nullcontext(a)
 
-        with precision_scope(model_management.get_autocast_device(device), torch.float32):
+        with precision_scope(model_management.get_autocast_device(device), dtype=torch.float32):
             attention_mask = None
             if self.enable_attention_masks:
                 attention_mask = torch.zeros_like(tokens)
diff --git a/backend/headless/fcbh/utils.py b/backend/headless/fcbh/utils.py
index e1babd74..3632c4c3 100644
--- a/backend/headless/fcbh/utils.py
+++ b/backend/headless/fcbh/utils.py
@@ -307,13 +307,13 @@ def bislerp(samples, width, height):
         res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
         return res
 
-    def generate_bilinear_data(length_old, length_new):
-        coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
+    def generate_bilinear_data(length_old, length_new, device):
+        coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
         coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
         ratios = coords_1 - coords_1.floor()
         coords_1 = coords_1.to(torch.int64)
 
-        coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
+        coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
         coords_2[:,:,:,-1] -= 1
         coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
         coords_2 = coords_2.to(torch.int64)
@@ -323,7 +323,7 @@ def bislerp(samples, width, height):
     h_new, w_new = (height, width)
 
     #linear w
-    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
+    ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
     coords_1 = coords_1.expand((n, c, h, -1))
     coords_2 = coords_2.expand((n, c, h, -1))
     ratios = ratios.expand((n, 1, h, -1))
@@ -336,7 +336,7 @@ def bislerp(samples, width, height):
     result = result.reshape(n, h, w_new, c).movedim(-1, 1)
 
     #linear h
-    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
+    ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
     coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
     coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
     ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
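Two small fixes above: sd1_clip.py renames the nullcontext lambda's second parameter to dtype so the call site can pass dtype=torch.float32 as a keyword, matching torch.autocast(device_type, dtype=...); and utils.py makes generate_bilinear_data allocate its coordinate tensors directly on the caller's device instead of the CPU, so bislerp no longer indexes GPU tensors with CPU coordinates. A minimal sketch of the now device-safe call; the latent shape and target size are illustrative:

import torch
import fcbh.utils

latent = torch.randn(1, 4, 64, 64)
if torch.cuda.is_available():
    latent = latent.cuda()  # coordinate tensors now follow this device
upscaled = fcbh.utils.bislerp(latent, 128, 128)  # (samples, width, height)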
"sampling/custom_sampling/schedulers" FUNCTION = "get_sigmas" @@ -109,7 +109,7 @@ class SplitSigmas: } } RETURN_TYPES = ("SIGMAS","SIGMAS") - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/sigmas" FUNCTION = "get_sigmas" @@ -118,6 +118,24 @@ class SplitSigmas: sigmas2 = sigmas[step:] return (sigmas1, sigmas2) +class FlipSigmas: + @classmethod + def INPUT_TYPES(s): + return {"required": + {"sigmas": ("SIGMAS", ), + } + } + RETURN_TYPES = ("SIGMAS",) + CATEGORY = "sampling/custom_sampling/sigmas" + + FUNCTION = "get_sigmas" + + def get_sigmas(self, sigmas): + sigmas = sigmas.flip(0) + if sigmas[0] == 0: + sigmas[0] = 0.0001 + return (sigmas,) + class KSamplerSelect: @classmethod def INPUT_TYPES(s): @@ -126,12 +144,12 @@ class KSamplerSelect: } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" def get_sampler(self, sampler_name): - sampler = fcbh.samplers.sampler_class(sampler_name)() + sampler = fcbh.samplers.sampler_object(sampler_name) return (sampler, ) class SamplerDPMPP_2M_SDE: @@ -145,7 +163,7 @@ class SamplerDPMPP_2M_SDE: } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" @@ -154,7 +172,7 @@ class SamplerDPMPP_2M_SDE: sampler_name = "dpmpp_2m_sde" else: sampler_name = "dpmpp_2m_sde_gpu" - sampler = fcbh.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type})() + sampler = fcbh.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "solver_type": solver_type}) return (sampler, ) @@ -169,7 +187,7 @@ class SamplerDPMPP_SDE: } } RETURN_TYPES = ("SAMPLER",) - CATEGORY = "sampling/custom_sampling" + CATEGORY = "sampling/custom_sampling/samplers" FUNCTION = "get_sampler" @@ -178,7 +196,7 @@ class SamplerDPMPP_SDE: sampler_name = "dpmpp_sde" else: sampler_name = "dpmpp_sde_gpu" - sampler = fcbh.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})() + sampler = fcbh.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r}) return (sampler, ) class SamplerCustom: @@ -234,6 +252,7 @@ class SamplerCustom: NODE_CLASS_MAPPINGS = { "SamplerCustom": SamplerCustom, + "BasicScheduler": BasicScheduler, "KarrasScheduler": KarrasScheduler, "ExponentialScheduler": ExponentialScheduler, "PolyexponentialScheduler": PolyexponentialScheduler, @@ -241,6 +260,6 @@ NODE_CLASS_MAPPINGS = { "KSamplerSelect": KSamplerSelect, "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE, "SamplerDPMPP_SDE": SamplerDPMPP_SDE, - "BasicScheduler": BasicScheduler, "SplitSigmas": SplitSigmas, + "FlipSigmas": FlipSigmas, } diff --git a/fooocus_version.py b/fooocus_version.py index d1f7ccb9..2d26b9dd 100644 --- a/fooocus_version.py +++ b/fooocus_version.py @@ -1 +1 @@ -version = '2.1.810' +version = '2.1.811' diff --git a/modules/flags.py b/modules/flags.py index 48d5d08b..0392b397 100644 --- a/modules/flags.py +++ b/modules/flags.py @@ -10,7 +10,7 @@ uov_list = [ disabled, subtle_variation, strong_variation, upscale_15, upscale_2, upscale_fast ] -KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "dpm_2", "dpm_2_ancestral", +KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]