fix: correct sampling when gamma is 0 (#3093)
This commit is contained in:
parent
2d55a5f257
commit
b58bc7774e
|
|
@ -832,5 +832,7 @@ def sample_tcd(model, x, sigmas, extra_args=None, callback=None, disable=None, n
|
||||||
if eta > 0 and sigmas[i + 1] > 0:
|
if eta > 0 and sigmas[i + 1] > 0:
|
||||||
noise = noise_sampler(sigmas[i], sigmas[i + 1])
|
noise = noise_sampler(sigmas[i], sigmas[i + 1])
|
||||||
x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt()
|
x = x / alpha_prod_s[i+1].sqrt() + noise * (sigmas[i+1]**2 + 1 - 1/alpha_prod_s[i+1]).sqrt()
|
||||||
|
else:
|
||||||
|
x *= torch.sqrt(1.0 + sigmas[i + 1] ** 2)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
Loading…
Reference in New Issue