This commit is contained in:
lvmin 2023-08-09 14:00:59 -07:00
parent 573dc98eff
commit 0b72417588
2 changed files with 2 additions and 2 deletions

View File

@ -118,7 +118,7 @@ torch.cuda.empty_cache()
torch.cuda.ipc_collect()
shape = (1, 4, 128, 128)
randn = torch.randn(shape).cuda()
randn = torch.randn(shape).to(torch.float16).cuda()
def denoiser(input, sigma, c):

View File

@ -980,7 +980,7 @@ class UNetModel(nn.Module):
self.num_classes is not None
), "must specify y if and only if the model is class-conditional"
hs = []
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x)
emb = self.time_embed(t_emb)
if self.num_classes is not None: