This commit is contained in:
parent
0b72417588
commit
8807b2dff8
|
|
@ -985,7 +985,7 @@ class UNetModel(nn.Module):
|
|||
|
||||
if self.num_classes is not None:
|
||||
assert y.shape[0] == x.shape[0]
|
||||
emb = emb + self.label_emb(y)
|
||||
emb = emb + self.label_emb(y.to(x))
|
||||
|
||||
# h = x.type(self.dtype)
|
||||
h = x
|
||||
|
|
|
|||
Loading…
Reference in New Issue