fix: correct sampling when gamma is 0 (#3093)

This commit is contained in:
Manuel Schmid 2024-06-04 21:03:37 +02:00 committed by GitHub
parent 2d55a5f257
commit b58bc7774e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 2 additions and 0 deletions

View File

@ -832,5 +832,7 @@ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, n
if eta > 0 and sigmas[i + 1] > 0: if eta > 0 and sigmas[i + 1] > 0:
noise = noise_sampler(sigmas[i], sigmas[i + 1]) noise = noise_sampler(sigmas[i], sigmas[i + 1])
x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt() x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt()
else:
x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2)
return x return x