fix: correct sampling when gamma is 0 (#3093)
This commit is contained in:
parent
2d55a5f257
commit
b58bc7774e
|
|
@ -832,5 +832,7 @@ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, n
|
|||
if eta > 0 and sigmas[i + 1] > 0:
|
||||
noise = noise_sampler(sigmas[i], sigmas[i + 1])
|
||||
x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt()
|
||||
else:
|
||||
x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2)
|
||||
|
||||
return x
|
||||
Loading…
Reference in New Issue