Merge branch 'feature/add-tcd-sampler'

# Conflicts:
#	modules/flags.py
Manuel Schmid 2024-05-15 17:17:02 +02:00
commit 4d46fb6bcb
8 changed files with 65 additions and 13 deletions

@@ -230,6 +230,25 @@ class SamplerDPMPP_SDE:
         sampler = ldm_patched.modules.samplers.ksampler(sampler_name, {"eta": eta, "s_noise": s_noise, "r": r})
         return (sampler, )
 
+class SamplerTCD:
+    @classmethod
+    def INPUT_TYPES(s):
+        return {
+            "required": {
+                "eta": ("FLOAT", {"default": 0.3, "min": 0.0, "max": 1.0, "step": 0.01}),
+            }
+        }
+
+    RETURN_TYPES = ("SAMPLER",)
+    CATEGORY = "sampling/custom_sampling/samplers"
+    FUNCTION = "get_sampler"
+
+    def get_sampler(self, eta=0.3):
+        sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": eta})
+        return (sampler, )
+
 class SamplerCustom:
     @classmethod
     def INPUT_TYPES(s):

@@ -292,6 +311,7 @@ NODE_CLASS_MAPPINGS = {
     "KSamplerSelect": KSamplerSelect,
     "SamplerDPMPP_2M_SDE": SamplerDPMPP_2M_SDE,
     "SamplerDPMPP_SDE": SamplerDPMPP_SDE,
+    "SamplerTCD": SamplerTCD,
     "SplitSigmas": SplitSigmas,
     "FlipSigmas": FlipSigmas,
 }
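
Note: the node is a thin wrapper around the existing ksampler factory. Outside the node graph, the equivalent call is a one-liner; a minimal sketch, assuming the ldm_patched package above is importable:

import ldm_patched.modules.samplers

# Same call SamplerTCD.get_sampler makes, with the default eta
sampler = ldm_patched.modules.samplers.ksampler("tcd", {"eta": 0.3})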

@@ -70,7 +70,7 @@ class ModelSamplingDiscrete:
     @classmethod
     def INPUT_TYPES(s):
         return {"required": { "model": ("MODEL",),
-                              "sampling": (["eps", "v_prediction", "lcm"],),
+                              "sampling": (["eps", "v_prediction", "lcm", "tcd"],),
                               "zsnr": ("BOOLEAN", {"default": False}),
                               }}

@@ -90,6 +90,9 @@ class ModelSamplingDiscrete:
         elif sampling == "lcm":
             sampling_type = LCM
             sampling_base = ModelSamplingDiscreteDistilled
+        elif sampling == "tcd":
+            sampling_type = ldm_patched.modules.model_sampling.EPS
+            sampling_base = ModelSamplingDiscreteDistilled
 
         class ModelSamplingAdvanced(sampling_base, sampling_type):
             pass
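
The "tcd" branch reuses the distilled (LCM-style) timestep schedule but keeps plain epsilon prediction; the two halves are composed by the dynamic ModelSamplingAdvanced mixin directly below it. A toy sketch of that composition pattern, with illustrative class names:

class DistilledSchedule:              # stands in for ModelSamplingDiscreteDistilled
    steps = 50

class EpsPrediction:                  # stands in for model_sampling.EPS
    def predict(self):
        return "eps"

def build(sampling_base, sampling_type):
    class ModelSamplingAdvanced(sampling_base, sampling_type):
        pass
    return ModelSamplingAdvanced()

obj = build(DistilledSchedule, EpsPrediction)
print(obj.steps, obj.predict())       # -> 50 eps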

@@ -752,7 +752,6 @@ def sample_lcm(model, x, sigmas, extra_args=None, callback=None, disable=None, n
     return x
 
 
-
 @torch.no_grad()
 def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=None, s_churn=0., s_tmin=0., s_tmax=float('inf'), s_noise=1.):
     # From MIT licensed: https://github.com/Carzit/sd-webui-samplers-scheduler/

@@ -808,3 +807,30 @@ def sample_heunpp2(model, x, sigmas, extra_args=None, callback=None, disable=Non
         d_prime = w1 * d + w2 * d_2 + w3 * d_3
         x = x + d_prime * dt
     return x
+
+
+@torch.no_grad()
+def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, noise_sampler=None, eta=0.3):
+    extra_args = {} if extra_args is None else extra_args
+    noise_sampler = default_noise_sampler(x) if noise_sampler is None else noise_sampler
+    s_in = x.new_ones([x.shape[0]])
+
+    model_sampling = model.inner_model.inner_model.model_sampling
+    timesteps_s = torch.floor((1 - eta) * model_sampling.timestep(sigmas)).to(dtype=torch.long).detach().cpu()
+    timesteps_s[-1] = 0
+    alpha_prod_s = model_sampling.alphas_cumprod[timesteps_s]
+    beta_prod_s = 1 - alpha_prod_s
+
+    for i in trange(len(sigmas) - 1, disable=disable):
+        denoised = model(x, sigmas[i] * s_in, **extra_args)  # predicted_original_sample
+        eps = (x - denoised) / sigmas[i]
+        denoised = alpha_prod_s[i + 1].sqrt() * denoised + beta_prod_s[i + 1].sqrt() * eps
+        if callback is not None:
+            callback({"x": x, "i": i, "sigma": sigmas[i], "sigma_hat": sigmas[i], "denoised": denoised})
+
+        x = denoised
+        if eta > 0 and sigmas[i + 1] > 0:
+            noise = noise_sampler(sigmas[i], sigmas[i + 1])
+            x = x / alpha_prod_s[i + 1].sqrt() + noise * (sigmas[i + 1] ** 2 + 1 - 1 / alpha_prod_s[i + 1]).sqrt()
+
+    return x
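
Each iteration converts the model's x0 prediction into a sample at the earlier timestep s = floor((1 - eta) * t) by re-blending x0 and the recovered noise with the alphas_cumprod table; for eta > 0, fresh noise is injected afterwards. A self-contained toy of the deterministic half of one step, using a stand-in denoiser that predicts zero (all values illustrative):

import torch

x = torch.randn(1, 4)
sigma = torch.tensor(1.0)
alpha_prod_next = torch.tensor(0.9)   # plays the role of alpha_prod_s[i + 1]

denoised = torch.zeros_like(x)        # pretend x0 prediction
eps = (x - denoised) / sigma          # recover the noise estimate
beta_prod_next = 1 - alpha_prod_next
x_next = alpha_prod_next.sqrt() * denoised + beta_prod_next.sqrt() * eps
print(x_next)                         # sqrt(0.1) * x, since denoised is 0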

@@ -50,17 +50,17 @@ class ModelSamplingDiscrete(torch.nn.Module):
         self.linear_start = linear_start
         self.linear_end = linear_end
 
         # self.register_buffer('betas', torch.tensor(betas, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod', torch.tensor(alphas_cumprod, dtype=torch.float32))
         # self.register_buffer('alphas_cumprod_prev', torch.tensor(alphas_cumprod_prev, dtype=torch.float32))
 
         sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
         self.set_sigmas(sigmas)
+        self.set_alphas_cumprod(alphas_cumprod.float())
 
     def set_sigmas(self, sigmas):
         self.register_buffer('sigmas', sigmas)
         self.register_buffer('log_sigmas', sigmas.log())
 
+    def set_alphas_cumprod(self, alphas_cumprod):
+        self.register_buffer("alphas_cumprod", alphas_cumprod.float())
+
     @property
     def sigma_min(self):
         return self.sigmas[0]
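
Storing the table with register_buffer (rather than a plain attribute) means it follows .to(device) moves and shows up in the state_dict without becoming a trainable parameter, which is what lets sample_tcd index model_sampling.alphas_cumprod directly. A minimal demonstration:

import torch

class Schedule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("alphas_cumprod", torch.linspace(0.99, 0.01, 10))

m = Schedule()
print(list(m.state_dict().keys()))    # ['alphas_cumprod']
print(m.alphas_cumprod[3])            # indexable like any tensor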

@@ -523,7 +523,7 @@ class UNIPCBH2(Sampler):
 
 KSAMPLER_NAMES = ["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral",
                   "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu",
-                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm"]
+                  "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "tcd"]
 
 class KSAMPLER(Sampler):
     def __init__(self, sampler_function, extra_options={}, inpaint_options={}):

@@ -819,19 +819,19 @@ def worker():
         final_sampler_name = sampler_name
         final_scheduler_name = scheduler_name
 
-        if scheduler_name == 'lcm':
+        if scheduler_name in ['lcm', 'tcd']:
             final_scheduler_name = 'sgm_uniform'
             if pipeline.final_unet is not None:
                 pipeline.final_unet = core.opModelSamplingDiscrete.patch(
                     pipeline.final_unet,
-                    sampling='lcm',
+                    sampling=scheduler_name,
                     zsnr=False)[0]
             if pipeline.final_refiner_unet is not None:
                 pipeline.final_refiner_unet = core.opModelSamplingDiscrete.patch(
                     pipeline.final_refiner_unet,
-                    sampling='lcm',
+                    sampling=scheduler_name,
                     zsnr=False)[0]
-            print('Using lcm scheduler.')
+            print(f'Using {scheduler_name} scheduler.')
 
         async_task.yields.append(['preview', (13, 'Moving model to GPU ...', None)])
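
This generalizes the LCM-only branch: both distilled modes run on sgm_uniform step spacing, and the UNet patch now receives the scheduler name itself. A condensed sketch of the resulting dispatch (helper name is illustrative):

def resolve_scheduler(scheduler_name):
    # Both distilled schedulers share sgm_uniform spacing; the second value
    # is the sampling mode to patch into the UNet, if any.
    if scheduler_name in ['lcm', 'tcd']:
        return 'sgm_uniform', scheduler_name
    return scheduler_name, None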

@@ -34,7 +34,8 @@ KSAMPLER = {
     "dpmpp_3m_sde": "",
     "dpmpp_3m_sde_gpu": "",
     "ddpm": "",
-    "lcm": "LCM"
+    "lcm": "LCM",
+    "tcd": "TCD"
 }
 
 SAMPLER_EXTRA = {

@@ -47,7 +48,7 @@ SAMPLERS = KSAMPLER | SAMPLER_EXTRA
 KSAMPLER_NAMES = list(KSAMPLER.keys())
 
-SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps"]
+SCHEDULER_NAMES = ["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform", "lcm", "turbo", "align_your_steps", "tcd"]
 SAMPLER_NAMES = KSAMPLER_NAMES + list(SAMPLER_EXTRA.keys())
 
 sampler_list = SAMPLER_NAMES
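
With both lists updated, "tcd" is exposed to the UI as a sampler and as a scheduler. A quick sanity check, assuming this repo's modules package is on the path:

from modules import flags

assert "tcd" in flags.sampler_list
assert "tcd" in flags.SCHEDULER_NAMES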

@@ -51,6 +51,8 @@ def patched_register_schedule(self, given_betas=None, beta_schedule="linear", ti
     self.linear_end = linear_end
 
     sigmas = torch.tensor(((1 - alphas_cumprod) / alphas_cumprod) ** 0.5, dtype=torch.float32)
     self.set_sigmas(sigmas)
+    alphas_cumprod = torch.tensor(alphas_cumprod, dtype=torch.float32)
+    self.set_alphas_cumprod(alphas_cumprod)
     return
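
The explicit torch.tensor(...) conversion suggests the schedule math upstream yields a NumPy array, while register_buffer (used by set_alphas_cumprod) only accepts tensors. An isolated sketch of that conversion, with an illustrative schedule:

import numpy as np
import torch

alphas_cumprod = np.cumprod(1.0 - np.linspace(1e-4, 2e-2, 10))
alphas_cumprod_t = torch.tensor(alphas_cumprod, dtype=torch.float32)
print(alphas_cumprod_t.dtype)         # torch.float32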