diff --git a/sgm/modules/diffusionmodules/util.py b/sgm/modules/diffusionmodules/util.py index 6c42a42e..6783480a 100644 --- a/sgm/modules/diffusionmodules/util.py +++ b/sgm/modules/diffusionmodules/util.py @@ -272,8 +272,7 @@ class SiLU(nn.Module): class GroupNorm32(nn.GroupNorm): def forward(self, x): - self.weight = self.weight.float() - self.bias = self.bias.float() + self.to(torch.float32) return super().forward(x.float()).type(x.dtype)