diff --git a/sgm/modules/diffusionmodules/openaimodel.py b/sgm/modules/diffusionmodules/openaimodel.py index ce16fd21..1e81752b 100644 --- a/sgm/modules/diffusionmodules/openaimodel.py +++ b/sgm/modules/diffusionmodules/openaimodel.py @@ -985,7 +985,7 @@ class UNetModel(nn.Module): if self.num_classes is not None: assert y.shape[0] == x.shape[0] - emb = emb + self.label_emb(y) + emb = emb + self.label_emb(y.to(x)) # h = x.type(self.dtype) h = x