Merge branch 'feature/add-tcd-sampler'
# Conflicts: # modules/flags.py
This commit is contained in:
commit
4d46fb6bcb
|
|
@ -230,6 +230,25 @@ class SamplerDPMPP_SDE:
|
|||
sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
|
||||
return (sampler, )
|
||||
|
||||
|
||||
class SamplerTCD:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
"required": {
|
||||
"eta": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
|
||||
}
|
||||
}
|
||||
RETURN_TYPES = ("SAMPLER",)
|
||||
CATEGORY = "sampling/custom_sampling/samplers"
|
||||
|
||||
FUNCTION = "get_sampler"
|
||||
|
||||
def get_sampler(self, eta=0.3):
|
||||
sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": eta})
|
||||
return (sampler, )
|
||||
|
||||
|
||||
class SamplerCustom:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
|
|
@ -292,6 +311,7 @@ NODE_CLASS_MAPPINGS = {
|
|||
"KSamplerSelect": KSamplerSelect,
|
||||
"SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
|
||||
"SamplerDPMPP_SDE": SamplerDPMPP_SDE,
|
||||
"SamplerTCD": SamplerTCD,
|
||||
"SplitSigmas": SplitSigmas,
|
||||
"FlipSigmas": FlipSigmas,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ class ModelSamplingDiscrete:
|
|||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "model": ("MODEL",),
|
||||
"sampling": (["eps", "v_prediction", "lcm"],),
|
||||
"sampling": (["eps", "v_prediction", "lcm", "tcd"]),
|
||||
"zsnr": ("BOOLEAN", {"default": False}),
|
||||
}}
|
||||
|
||||
|
|
@ -90,6 +90,9 @@ class ModelSamplingDiscrete:
|
|||
elif sampling == "lcm":
|
||||
sampling_type = LCM
|
||||
sampling_base = ModelSamplingDiscreteDistilled
|
||||
elif sampling == "tcd":
|
||||
sampling_type = ldm_patched.modules.model_sampling.EPS
|
||||
sampling_base = ModelSamplingDiscreteDistilled
|
||||
|
||||
class ModelSamplingAdvanced(sampling_base, sampling_type):
|
||||
pass
|
||||
|
|
|
|||
|
|
@ -752,7 +752,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
|
|||
return x
|
||||
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
|
||||
# From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/
|
||||
|
|
@ -808,3 +807,30 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
|
|||
d_prime = w1 * d + w2 * d_2 + w3 * d_3
|
||||
x = x + d_prime * dt
|
||||
return x
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, eta=0.3):
|
||||
extra_args = {} if extra_args is None else extra_args
|
||||
noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
|
||||
s_in = x.new_ones([x.shape[0]])
|
||||
|
||||
model_sampling = model.inner_model.inner_model.model_sampling
|
||||
timesteps_s = torch.floor((1 - eta) * model_sampling.timestep(sigmas)).to(dtype=torch.long).detach().cpu()
|
||||
timesteps_s[-1] = 0
|
||||
alpha_prod_s = model_sampling.alphas_cumprod[timesteps_s]
|
||||
beta_prod_s = 1 - alpha_prod_s
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args) # predicted_original_sample
|
||||
eps = (x - denoised) / sigmas[i]
|
||||
denoised = alpha_prod_s[i + 1].sqrt() * denoised + beta_prod_s[i + 1].sqrt() * eps
|
||||
|
||||
if callback is not None:
|
||||
callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
|
||||
|
||||
x = denoised
|
||||
if eta > 0 and sigmas[i + 1] > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt()
|
||||
|
||||
return x
|
||||
|
|
@ -50,17 +50,17 @@ class ModelSamplingDiscrete(torch.nn.Module):
|
|||
self.linear_start = linear_start
|
||||
self.linear_end = linear_end
|
||||
|
||||
# self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
|
||||
# self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
|
||||
|
||||
sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
|
||||
self.set_sigmas(sigmas)
|
||||
self.set_alphas_cumprod(alphas_cumprod.float())
|
||||
|
||||
def set_sigmas(self, sigmas):
|
||||
self.register_buffer('sigmas', sigmas)
|
||||
self.register_buffer('log_sigmas', sigmas.log())
|
||||
|
||||
def set_alphas_cumprod(self, alphas_cumprod):
|
||||
self.register_buffer("alphas_cumprod", alphas_cumprod.float())
|
||||
|
||||
@property
|
||||
def sigma_min(self):
|
||||
return self.sigmas[0]
|
||||
|
|
|
|||
|
|
@ -523,7 +523,7 @@ class UNIPCBH2(Sampler):
|
|||
|
||||
KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
|
||||
"lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
|
||||
"dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd"]
|
||||
|
||||
class KSAMPLER(Sampler):
|
||||
def __init__(self, sampler_function, extra_options={}, inpaint_options={}):
|
||||
|
|
|
|||
|
|
@ -819,19 +819,19 @@ def worker():
|
|||
final_sampler_name = sampler_name
|
||||
final_scheduler_name = scheduler_name
|
||||
|
||||
if scheduler_name == 'lcm':
|
||||
if scheduler_name in ['lcm', 'tcd']:
|
||||
final_scheduler_name = 'sgm_uniform'
|
||||
if pipeline.final_unet is not None:
|
||||
pipeline.final_unet = core.opModelSamplingDiscrete.patch(
|
||||
pipeline.final_unet,
|
||||
sampling='lcm',
|
||||
sampling=scheduler_name,
|
||||
zsnr=False)[0]
|
||||
if pipeline.final_refiner_unet is not None:
|
||||
pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
|
||||
pipeline.final_refiner_unet,
|
||||
sampling='lcm',
|
||||
sampling=scheduler_name,
|
||||
zsnr=False)[0]
|
||||
print('Using lcm scheduler.')
|
||||
print(f'Using {scheduler_name} scheduler.')
|
||||
|
||||
async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)])
|
||||
|
||||
|
|
|
|||
|
|
@ -34,7 +34,8 @@ KSAMPLER = {
|
|||
"dpmpp_3m_sde": "",
|
||||
"dpmpp_3m_sde_gpu": "",
|
||||
"ddpm": "",
|
||||
"lcm": "LCM"
|
||||
"lcm": "LCM",
|
||||
"tcd": "TCD"
|
||||
}
|
||||
|
||||
SAMPLER_EXTRA = {
|
||||
|
|
@ -47,7 +48,7 @@ SAMPLERS = KSAMPLER | SAMPLER_EXTRA
|
|||
|
||||
KSAMPLER_NAMES = list(KSAMPLER.keys())
|
||||
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps"]
|
||||
SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps", "tcd"]
|
||||
SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
|
||||
|
||||
sampler_list = SAMPLER_NAMES
|
||||
|
|
|
|||
|
|
@ -51,6 +51,8 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti
|
|||
self.linear_end = linear_end
|
||||
sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
|
||||
self.set_sigmas(sigmas)
|
||||
alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32)
|
||||
self.set_alphas_cumprod(alphas_cumprod)
|
||||
return
|
||||
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue