This commit is contained in:
lvmin 2023-08-09 14:02:43 -07:00
parent 0b72417588
commit 8807b2dff8
1 changed files with 1 additions and 1 deletions

View File

@ -985,7 +985,7 @@ class UNetModel(nn.Module):
if self.num_classes is not None:
assert y.shape[0] == x.shape[0]
emb = emb + self.label_emb(y)
emb = emb + self.label_emb(y.to(x))
# h = x.type(self.dtype)
h = x